Add NSFW detection
This commit is contained in:
parent
def5d9802f
commit
7bbeb2d144
11
README.rst
11
README.rst
|
@ -25,6 +25,17 @@ now and then.
|
|||
Before running the service for the first time, run ``./fhost.py db upgrade``.
|
||||
|
||||
|
||||
NSFW Detection
|
||||
--------------
|
||||
|
||||
0x0 supports classification of NSFW content via Yahoo’s open_nsfw Caffe
|
||||
neural network model. This works for images and video files and requires
|
||||
the following:
|
||||
|
||||
* Caffe Python module (built for Python 3)
|
||||
* ``ffmpegthumbnailer`` executable in ``$PATH``
|
||||
|
||||
|
||||
FAQ
|
||||
---
|
||||
|
||||
|
|
78
fhost.py
78
fhost.py
|
@ -44,6 +44,13 @@ app.config["FHOST_MIME_BLACKLIST"] = [
|
|||
|
||||
app.config["FHOST_UPLOAD_BLACKLIST"] = "tornodes.txt"
|
||||
|
||||
app.config["NSFW_DETECT"] = True
|
||||
app.config["NSFW_THRESHOLD"] = 0.7
|
||||
|
||||
if app.config["NSFW_DETECT"]:
|
||||
from nsfw_detect import NSFWDetector
|
||||
nsfw = NSFWDetector()
|
||||
|
||||
try:
|
||||
mimedetect = Magic(mime=True, mime_encoding=False)
|
||||
except:
|
||||
|
@ -72,6 +79,9 @@ class URL(db.Model):
|
|||
def getname(self):
|
||||
return su.enbase(self.id, 1)
|
||||
|
||||
def geturl(self):
|
||||
return url_for("get", path=self.getname(), _external=True) + "\n"
|
||||
|
||||
class File(db.Model):
|
||||
id = db.Column(db.Integer, primary_key = True)
|
||||
sha256 = db.Column(db.String, unique = True)
|
||||
|
@ -79,23 +89,29 @@ class File(db.Model):
|
|||
mime = db.Column(db.UnicodeText)
|
||||
addr = db.Column(db.UnicodeText)
|
||||
removed = db.Column(db.Boolean, default=False)
|
||||
nsfw_score = db.Column(db.Float)
|
||||
|
||||
def __init__(self, sha256, ext, mime, addr):
|
||||
def __init__(self, sha256, ext, mime, addr, nsfw_score):
|
||||
self.sha256 = sha256
|
||||
self.ext = ext
|
||||
self.mime = mime
|
||||
self.addr = addr
|
||||
self.nsfw_score = nsfw_score
|
||||
|
||||
def getname(self):
|
||||
return u"{0}{1}".format(su.enbase(self.id, 1), self.ext)
|
||||
|
||||
def geturl(self):
|
||||
n = self.getname()
|
||||
|
||||
if self.nsfw_score and self.nsfw_score > app.config["NSFW_THRESHOLD"]:
|
||||
return url_for("get", path=n, _external=True, _anchor="nsfw") + "\n"
|
||||
else:
|
||||
return url_for("get", path=n, _external=True) + "\n"
|
||||
|
||||
def getpath(fn):
|
||||
return os.path.join(app.config["FHOST_STORAGE_PATH"], fn)
|
||||
|
||||
def geturl(p):
|
||||
return url_for("get", path=p, _external=True) + "\n"
|
||||
|
||||
def fhost_url(scheme=None):
|
||||
if not scheme:
|
||||
return url_for(".fhost", _external=True).rstrip("/")
|
||||
|
@ -115,13 +131,13 @@ def shorten(url):
|
|||
existing = URL.query.filter_by(url=url).first()
|
||||
|
||||
if existing:
|
||||
return geturl(existing.getname())
|
||||
return existing.geturl()
|
||||
else:
|
||||
u = URL(url)
|
||||
db.session.add(u)
|
||||
db.session.commit()
|
||||
|
||||
return geturl(u.getname())
|
||||
return u.geturl()
|
||||
|
||||
def in_upload_bl(addr):
|
||||
if os.path.isfile(app.config["FHOST_UPLOAD_BLACKLIST"]):
|
||||
|
@ -152,11 +168,15 @@ def store_file(f, addr):
|
|||
with open(epath, "wb") as of:
|
||||
of.write(data)
|
||||
|
||||
if existing.nsfw_score == None:
|
||||
if app.config["NSFW_DETECT"]:
|
||||
existing.nsfw_score = nsfw.detect(epath)
|
||||
|
||||
os.utime(epath, None)
|
||||
existing.addr = addr
|
||||
db.session.commit()
|
||||
|
||||
return geturl(existing.getname())
|
||||
return existing.geturl()
|
||||
else:
|
||||
guessmime = mimedetect.from_buffer(data)
|
||||
|
||||
|
@ -186,14 +206,21 @@ def store_file(f, addr):
|
|||
if not ext:
|
||||
ext = ".bin"
|
||||
|
||||
with open(getpath(digest), "wb") as of:
|
||||
spath = getpath(digest)
|
||||
|
||||
with open(spath, "wb") as of:
|
||||
of.write(data)
|
||||
|
||||
sf = File(digest, ext, mime, addr)
|
||||
if app.config["NSFW_DETECT"]:
|
||||
nsfw_score = nsfw.detect(spath)
|
||||
else:
|
||||
nsfw_score = None
|
||||
|
||||
sf = File(digest, ext, mime, addr, nsfw_score)
|
||||
db.session.add(sf)
|
||||
db.session.commit()
|
||||
|
||||
return geturl(sf.getname())
|
||||
return sf.geturl()
|
||||
|
||||
def store_url(url, addr):
|
||||
if is_fhost_url(url):
|
||||
|
@ -438,6 +465,37 @@ def queryaddr(a):
|
|||
for f in res:
|
||||
query(su.enbase(f.id, 1))
|
||||
|
||||
def nsfw_detect(f):
|
||||
try:
|
||||
open(f["path"], 'r').close()
|
||||
f["nsfw_score"] = nsfw.detect(f["path"])
|
||||
return f
|
||||
except:
|
||||
return None
|
||||
|
||||
@manager.command
|
||||
def update_nsfw():
|
||||
if not app.config["NSFW_DETECT"]:
|
||||
print("NSFW detection is disabled in app config")
|
||||
return 1
|
||||
|
||||
from multiprocessing import Pool
|
||||
import tqdm
|
||||
|
||||
res = File.query.filter_by(nsfw_score=None, removed=False)
|
||||
|
||||
with Pool() as p:
|
||||
results = []
|
||||
work = [{ "path" : getpath(f.sha256), "id" : f.id} for f in res]
|
||||
|
||||
for r in tqdm.tqdm(p.imap_unordered(nsfw_detect, work), total=len(work)):
|
||||
if r:
|
||||
results.append({"id": r["id"], "nsfw_score" : r["nsfw_score"]})
|
||||
|
||||
db.session.bulk_update_mappings(File, results)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@manager.command
|
||||
def querybl():
|
||||
if os.path.isfile(app.config["FHOST_UPLOAD_BLACKLIST"]):
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
"""add NSFW score
|
||||
|
||||
Revision ID: 7e246705da6a
|
||||
Revises: 0cd36ecdd937
|
||||
Create Date: 2017-10-27 03:07:48.179290
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '7e246705da6a'
|
||||
down_revision = '0cd36ecdd937'
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('file', sa.Column('nsfw_score', sa.Float(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('file', 'nsfw_score')
|
||||
# ### end Alembic commands ###
|
|
@ -0,0 +1,62 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from subprocess import run, PIPE, DEVNULL
|
||||
|
||||
os.environ["GLOG_minloglevel"] = "2" # seriously :|
|
||||
import caffe
|
||||
|
||||
class NSFWDetector:
|
||||
def __init__(self):
|
||||
|
||||
npath = os.path.join(os.path.dirname(__file__), "nsfw_model")
|
||||
self.nsfw_net = caffe.Net(os.path.join(npath, "deploy.prototxt"),
|
||||
os.path.join(npath, "resnet_50_1by2_nsfw.caffemodel"),
|
||||
caffe.TEST)
|
||||
self.caffe_transformer = caffe.io.Transformer({'data': self.nsfw_net.blobs['data'].data.shape})
|
||||
self.caffe_transformer.set_transpose('data', (2, 0, 1)) # move image channels to outermost
|
||||
self.caffe_transformer.set_mean('data', np.array([104, 117, 123])) # subtract the dataset-mean value in each channel
|
||||
self.caffe_transformer.set_raw_scale('data', 255) # rescale from [0, 1] to [0, 255]
|
||||
self.caffe_transformer.set_channel_swap('data', (2, 1, 0)) # swap channels from RGB to BGR
|
||||
|
||||
def _compute(self, img):
|
||||
image = caffe.io.load_image(BytesIO(img))
|
||||
|
||||
H, W, _ = image.shape
|
||||
_, _, h, w = self.nsfw_net.blobs["data"].data.shape
|
||||
h_off = int(max((H - h) / 2, 0))
|
||||
w_off = int(max((W - w) / 2, 0))
|
||||
crop = image[h_off:h_off + h, w_off:w_off + w, :]
|
||||
|
||||
transformed_image = self.caffe_transformer.preprocess('data', crop)
|
||||
transformed_image.shape = (1,) + transformed_image.shape
|
||||
|
||||
input_name = self.nsfw_net.inputs[0]
|
||||
output_layers = ["prob"]
|
||||
all_outputs = self.nsfw_net.forward_all(blobs=output_layers,
|
||||
**{input_name: transformed_image})
|
||||
|
||||
outputs = all_outputs[output_layers[0]][0].astype(float)
|
||||
|
||||
return outputs
|
||||
|
||||
def detect(self, fpath):
|
||||
try:
|
||||
ff = run(["ffmpegthumbnailer", "-m", "-o-", "-s256", "-t50%", "-a", "-cpng", "-i", fpath], stdout=PIPE, stderr=DEVNULL, check=True)
|
||||
image_data = ff.stdout
|
||||
except:
|
||||
return -1.0
|
||||
|
||||
scores = self._compute(image_data)
|
||||
|
||||
return scores[1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
n = NSFWDetector()
|
||||
|
||||
for inf in sys.argv[1:]:
|
||||
score = n.detect(inf)
|
||||
print(inf, score)
|
|
@ -0,0 +1,11 @@
|
|||
|
||||
Copyright 2016, Yahoo Inc.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Loading…
Reference in New Issue