未验证 提交 b2853121 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update porn_detection_cnn (#2102)

上级 96baa2b8
......@@ -22,7 +22,7 @@
- ### 1、环境依赖
- paddlepaddle >= 1.6.2
- paddlehub >= 1.6.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2、安装
......@@ -41,45 +41,45 @@
- ```shell
$ hub run porn_detection_cnn --input_text "黄片下载"
```
- 或者
- ```shell
$ hub run porn_detection_cnn --input_file test.txt
```
- 其中test.txt存放待审查文本,每行仅放置一段待审核文本
- 通过命令行方式实现hub模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、预测代码示例
- ```python
import paddlehub as hub
porn_detection_cnn = hub.Module(name="porn_detection_cnn")
test_text = ["黄片下载", "打击黄牛党"]
results = porn_detection_cnn.detection(texts=test_text, use_gpu=True, batch_size=1)
for index, text in enumerate(test_text):
results[index]["text"] = text
for index, result in enumerate(results):
print(results[index])
# 输出结果如下:
# {'text': '黄片下载', 'porn_detection_label': 1, 'porn_detection_key': 'porn', 'porn_probs': 0.9324, 'not_porn_probs': 0.0676}
# {'text': '打击黄牛党', 'porn_detection_label': 0, 'porn_detection_key': 'not_porn', 'porn_probs': 0.0004, 'not_porn_probs': 0.9996}
```
- ### 3、API
- ```python
def detection(texts=[], data={}, use_gpu=False, batch_size=1)
```
- porn_detection_cnn预测接口,鉴定输入句子是否包含色情文案
- **参数**
......@@ -145,7 +145,7 @@
- ```python
import requests
import json
# 待预测数据
text = ["黄片下载", "打击黄牛党"]
......@@ -177,8 +177,10 @@
大幅提升预测性能,同时简化接口使用
* 1.2.0
移除 fluid api
- ```shell
$ hub install porn_detection_cnn==1.1.0
$ hub install porn_detection_cnn==1.2.0
```
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import math
import os
import six
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.common.paddle_helper import get_variable_info
from paddlehub.module.module import moduleinfo, serving
from paddlehub.reader import tokenization
from porn_detection_cnn.processor import load_vocab, preprocess, postprocess
@moduleinfo(
name="porn_detection_cnn",
version="1.1.0",
summary="Baidu's open-source Porn Detection Model.",
author="baidu-nlp",
author_email="",
type="nlp/sentiment_analysis")
from .processor import load_vocab
from .processor import postprocess
from .processor import preprocess
from paddlehub.compat.task import tokenization
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import serving
@moduleinfo(name="porn_detection_cnn",
version="1.2.0",
summary="Baidu's open-source Porn Detection Model.",
author="baidu-nlp",
author_email="",
type="nlp/sentiment_analysis")
class PornDetectionCNN(hub.NLPPredictionModule):
def _initialize(self):
def __init__(self):
"""
initialize with the necessary elements
"""
......@@ -42,42 +35,6 @@ class PornDetectionCNN(hub.NLPPredictionModule):
self._set_config()
def context(self, trainable=False):
"""
Get the inputs, outputs and program of the pretrained porn_detection_cnn
Args:
trainable(bool): whether to fine-tune the pretrained parameters of porn_detection_cnn or not
Returns:
inputs(dict): the input variables of porn_detection_cnn, keyed by "words"
outputs(dict): the output variables of porn_detection_cnn, keyed by "class_probs"
(the first fetch target) and "sentence_feature" (the layer_norm_1.tmp_2 variable)
main_program(Program): the main_program of porn_detection_cnn with pretrained parameters
"""
# Load the saved inference program using a CPU executor.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_target_names, fetch_targets = fluid.io.load_inference_model(
dirname=self.pretrained_model_path, executor=exe)
# Each stripped line of self.param_file is used as a variable name in the program.
with open(self.param_file, 'r') as file:
params_list = file.readlines()
# Re-create every listed variable as a parameter with the same shape, dtype and name.
for param in params_list:
param = param.strip()
var = program.global_block().var(param)
var_info = get_variable_info(var)
program.global_block().create_parameter(
shape=var_info['shape'], dtype=var_info['dtype'], name=var_info['name'])
# Apply the caller's trainable flag to every parameter of the program.
for param in program.global_block().iter_parameters():
param.trainable = trainable
# Pick the input and feature variables out of the global block by name.
# NOTE(review): inputs/outputs are only assigned when a matching name is found;
# if either name is absent the return below would fail on an unbound local — confirm
# the saved model always contains both.
for name, var in program.global_block().vars.items():
if name == feed_target_names[0]:
inputs = {"words": var}
# output of the second layer from the end, i.e. before the prediction layer (fc-softmax)
if name == "@HUB_porn_detection_cnn@layer_norm_1.tmp_2":
outputs = {"class_probs": fetch_targets[0], "sentence_feature": var}
return inputs, outputs, program
@serving
def detection(self, texts=[], data={}, use_gpu=False, batch_size=1):
"""
......@@ -135,26 +92,3 @@ class PornDetectionCNN(hub.NLPPredictionModule):
"""
self.labels = {"porn": 1, "not_porn": 0}
return self.labels
# Demo entry point: runs detection on two sample texts, first through the
# `texts=` argument and then through the `data=` dict argument, printing each result.
if __name__ == "__main__":
porn_detection_cnn = PornDetectionCNN()
test_text = ["黄片下载", "打击黄牛党"]
# First call style: pass the texts directly (batch_size here exceeds the number of texts).
results = porn_detection_cnn.detection(texts=test_text, batch_size=9)
# Attach the source text to each result by position;
# assumes results come back in input order — TODO confirm.
for index, text in enumerate(test_text):
results[index]["text"] = text
for index, result in enumerate(results):
# On Python 2, print via json.dumps with ensure_ascii=False so the
# non-ASCII text is not escaped; on Python 3 print the dict directly.
if six.PY2:
print(json.dumps(results[index], encoding="utf8", ensure_ascii=False))
else:
print(results[index])
# Second call style: pass the same texts wrapped in a dict under the "text" key.
input_dict = {"text": test_text}
results = porn_detection_cnn.detection(data=input_dict)
for index, text in enumerate(test_text):
results[index]["text"] = text
for index, result in enumerate(results):
if six.PY2:
print(json.dumps(results[index], encoding="utf8", ensure_ascii=False))
else:
print(results[index])
# -*- coding: utf-8 -*-
import io
import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册