From eb0b1d2f6992740e3b09bf383d8c0fa442faca55 Mon Sep 17 00:00:00 2001 From: Mia Herkt Date: Tue, 29 Nov 2022 21:46:33 +0100 Subject: [PATCH] nsfw_detect: Use PyAV instead of ffmpegthumbnailer --- README.rst | 2 +- nsfw_detect.py | 27 +++++++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index e84f2b1..5b34b50 100644 --- a/README.rst +++ b/README.rst @@ -56,7 +56,7 @@ neural network model. This works for images and video files and requires the following: * Caffe Python module (built for Python 3) -* ``ffmpegthumbnailer`` executable in ``$PATH`` +* `PyAV `_ Network Security Considerations diff --git a/nsfw_detect.py b/nsfw_detect.py index eddf9cb..032f7e4 100755 --- a/nsfw_detect.py +++ b/nsfw_detect.py @@ -22,11 +22,12 @@ import numpy as np import os import sys from io import BytesIO -from subprocess import run, PIPE, DEVNULL from pathlib import Path os.environ["GLOG_minloglevel"] = "2" # seriously :| import caffe +import av +av.logging.set_level(av.logging.PANIC) class NSFWDetector: def __init__(self): @@ -49,7 +50,7 @@ class NSFWDetector: self.caffe_transformer.set_channel_swap('data', (2, 1, 0)) def _compute(self, img): - image = caffe.io.load_image(BytesIO(img)) + image = caffe.io.load_image(img) H, W, _ = image.shape _, _, h, w = self.nsfw_net.blobs["data"].data.shape @@ -71,13 +72,23 @@ class NSFWDetector: def detect(self, fpath): try: - ff = run([ - "ffmpegthumbnailer", "-m", "-o-", "-s256", "-t50%", "-a", - "-cpng", "-i", fpath - ], stdout=PIPE, stderr=DEVNULL, check=True) - image_data = ff.stdout + with av.open(fpath) as container: + try: container.seek(int(container.duration / 2)) + except: container.seek(0) - scores = self._compute(image_data) + frame = next(container.decode(video=0)) + + if frame.width >= frame.height: + w = 256 + h = int(frame.height * (256 / frame.width)) + else: + w = int(frame.width * (256 / frame.height)) + h = 256 + frame = frame.reformat(width=w, height=h, format="rgb24") + img = BytesIO() + frame.to_image().save(img, format="ppm") + + scores = self._compute(img) except: return -1.0