未验证 提交 b2853121 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update porn_detection_cnn (#2102)

上级 96baa2b8
...@@ -177,8 +177,10 @@ ...@@ -177,8 +177,10 @@
大幅提升预测性能,同时简化接口使用 大幅提升预测性能,同时简化接口使用
- ```shell * 1.2.0
$ hub install porn_detection_cnn==1.1.0
```
移除 fluid api
- ```shell
$ hub install porn_detection_cnn==1.2.0
```
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import math import math
import os import os
import six
import paddle.fluid as fluid
import paddlehub as hub import paddlehub as hub
from paddlehub.common.paddle_helper import get_variable_info from .processor import load_vocab
from paddlehub.module.module import moduleinfo, serving from .processor import postprocess
from paddlehub.reader import tokenization from .processor import preprocess
from paddlehub.compat.task import tokenization
from porn_detection_cnn.processor import load_vocab, preprocess, postprocess from paddlehub.module.module import moduleinfo
from paddlehub.module.module import serving
@moduleinfo( @moduleinfo(name="porn_detection_cnn",
name="porn_detection_cnn", version="1.2.0",
version="1.1.0",
summary="Baidu's open-source Porn Detection Model.", summary="Baidu's open-source Porn Detection Model.",
author="baidu-nlp", author="baidu-nlp",
author_email="", author_email="",
type="nlp/sentiment_analysis") type="nlp/sentiment_analysis")
class PornDetectionCNN(hub.NLPPredictionModule): class PornDetectionCNN(hub.NLPPredictionModule):
def _initialize(self):
def __init__(self):
""" """
initialize with the necessary elements initialize with the necessary elements
""" """
...@@ -42,42 +35,6 @@ class PornDetectionCNN(hub.NLPPredictionModule): ...@@ -42,42 +35,6 @@ class PornDetectionCNN(hub.NLPPredictionModule):
self._set_config() self._set_config()
def context(self, trainable=False):
"""
Get the input ,output and program of the pretrained porn_detection_cnn
Args:
trainable(bool): whether fine-tune the pretrained parameters of porn_detection_cnn or not
Returns:
inputs(dict): the input variables of porn_detection_cnn (words)
outputs(dict): the output variables of porn_detection_cnn (the sentiment prediction results)
main_program(Program): the main_program of porn_detection_cnn with pretrained prameters
"""
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_target_names, fetch_targets = fluid.io.load_inference_model(
dirname=self.pretrained_model_path, executor=exe)
with open(self.param_file, 'r') as file:
params_list = file.readlines()
for param in params_list:
param = param.strip()
var = program.global_block().var(param)
var_info = get_variable_info(var)
program.global_block().create_parameter(
shape=var_info['shape'], dtype=var_info['dtype'], name=var_info['name'])
for param in program.global_block().iter_parameters():
param.trainable = trainable
for name, var in program.global_block().vars.items():
if name == feed_target_names[0]:
inputs = {"words": var}
# output of sencond layer from the end prediction layer (fc-softmax)
if name == "@HUB_porn_detection_cnn@layer_norm_1.tmp_2":
outputs = {"class_probs": fetch_targets[0], "sentence_feature": var}
return inputs, outputs, program
@serving @serving
def detection(self, texts=[], data={}, use_gpu=False, batch_size=1): def detection(self, texts=[], data={}, use_gpu=False, batch_size=1):
""" """
...@@ -135,26 +92,3 @@ class PornDetectionCNN(hub.NLPPredictionModule): ...@@ -135,26 +92,3 @@ class PornDetectionCNN(hub.NLPPredictionModule):
""" """
self.labels = {"porn": 1, "not_porn": 0} self.labels = {"porn": 1, "not_porn": 0}
return self.labels return self.labels
if __name__ == "__main__":
porn_detection_cnn = PornDetectionCNN()
test_text = ["黄片下载", "打击黄牛党"]
results = porn_detection_cnn.detection(texts=test_text, batch_size=9)
for index, text in enumerate(test_text):
results[index]["text"] = text
for index, result in enumerate(results):
if six.PY2:
print(json.dumps(results[index], encoding="utf8", ensure_ascii=False))
else:
print(results[index])
input_dict = {"text": test_text}
results = porn_detection_cnn.detection(data=input_dict)
for index, text in enumerate(test_text):
results[index]["text"] = text
for index, result in enumerate(results):
if six.PY2:
print(json.dumps(results[index], encoding="utf8", ensure_ascii=False))
else:
print(results[index])
# -*- coding: utf-8 -*-
import io import io
import numpy as np import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册