未验证 提交 e35ff5ec 编写于 作者: K kinghuin 提交者: GitHub

add unitest, config, fix ernie_gen bugs and add ernie_tiny_couplet (#782)

上级 879383e8
......@@ -50,12 +50,14 @@ class ErnieGen(hub.NLPPredictionModule):
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_couplet")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
ernie_cfg = dict(json.loads(open(ernie_cfg_path).read()))
with open(ernie_cfg_path) as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(open(ernie_vocab_path).readlines())
}
with open(ernie_vocab_path) as ernie_vocab_file:
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(ernie_vocab_file.readlines())
}
with fluid.dygraph.guard(fluid.CPUPlace()):
with fluid.unique_name.guard():
......@@ -183,5 +185,5 @@ class ErnieGen(hub.NLPPredictionModule):
if __name__ == "__main__":
module = ErnieGen()
for result in module.generate(['人增福寿年增岁', '风吹云乱天垂泪'], beam_width=5):
for result in module.generate(['上海自来水来自海上', '风吹云乱天垂泪'], beam_width=5):
print(result)
......@@ -10,7 +10,7 @@ ERNIE-GEN 是面向生成任务的预训练-微调框架,首次在预训练阶
## 命令行预测
```shell
$ hub run ernie_gen_poetry --input_text="宝积峰前露术香,使君行旆照晴阳。" --use_gpu True --beam_width 5
$ hub run ernie_gen_poetry --input_text="昔年旅南服,始识王荆州。" --use_gpu True --beam_width 5
```
## API
......@@ -38,7 +38,7 @@ import paddlehub as hub
module = hub.Module(name="ernie_gen_poetry")
test_texts = ["宝积峰前露术香,使君行旆照晴阳。"]
test_texts = ['昔年旅南服,始识王荆州。', '高名出汉阴,禅阁跨香岑。']
results = module.genrate(texts=test_texts, use_gpu=True, beam_width=5)
for result in results:
print(result)
......@@ -69,7 +69,7 @@ import json
# 发送HTTP请求
data = {'texts':["宝积峰前露术香,使君行旆照晴阳。"],
data = {'texts':['昔年旅南服,始识王荆州。', '高名出汉阴,禅阁跨香岑。'],
'use_gpu':False, 'beam_width':5}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ernie_gen_poetry"
......
......@@ -15,8 +15,8 @@
import paddle.fluid as F
import paddle.fluid.layers as L
from ernie_gen_couplet.model.modeling_ernie import ErnieModel
from ernie_gen_couplet.model.modeling_ernie import _build_linear, _build_ln, append_name
from ernie_gen_poetry.model.modeling_ernie import ErnieModel
from ernie_gen_poetry.model.modeling_ernie import _build_linear, _build_ln, append_name
class ErnieModelForGeneration(ErnieModel):
......
......@@ -50,12 +50,14 @@ class ErnieGen(hub.NLPPredictionModule):
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_poetry")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
ernie_cfg = dict(json.loads(open(ernie_cfg_path).read()))
with open(ernie_cfg_path) as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(open(ernie_vocab_path).readlines())
}
with open(ernie_vocab_path) as ernie_vocab_file:
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(ernie_vocab_file.readlines())
}
with fluid.dygraph.guard(fluid.CPUPlace()):
with fluid.unique_name.guard():
......@@ -183,5 +185,6 @@ class ErnieGen(hub.NLPPredictionModule):
if __name__ == "__main__":
module = ErnieGen()
for result in module.generate(['宝积峰前露术香,使君行旆照晴阳。'], beam_width=5):
for result in module.generate(['昔年旅南服,始识王荆州。', '高名出汉阴,禅阁跨香岑。'],
beam_width=5):
print(result)
```shell
$ hub install ernie_tiny_couplet==1.0.0
```
<p align="center">
<img src="https://paddlehub.bj.bcebos.com/paddlehub-img%2Fernie_tiny_framework.PNG" hspace='10'/> <br />
</p>
本预测module系由TextGenerationTask微调而来,转换方式可以参考[Fine-tune保存的模型如何转化为一个PaddleHub Module](https://github.com/PaddlePaddle/PaddleHub/blob/develop/docs/tutorial/finetuned_model_to_module.md)
## 命令行预测
```shell
$ hub run ernie_tiny_couplet --input_text '风吹云乱天垂泪'
```
命令行预测只支持使用CPU预测,如需使用GPU,请使用API方式预测。
## API
```python
def generate(texts)
```
对联预测接口,输入上联文本,输出下联文本。该接口封装了上联文本使用`hub.BertTokenizer`编码的过程,因此它的调用方式比demo中提供的[predcit接口](https://github.com/PaddlePaddle/PaddleHub/blob/develop/demo/text_generation/predict.py#L83)简单。
**参数**
> texts(list[str]): 上联文本。
**返回**
> result(list[str]): 下联文本。每个上联会对应输出10个下联。
**代码示例**
```python
import paddlehub as hub
# Load ernie pretrained model
module = hub.Module(name="ernie_tiny_couplet")
results = module.generate(["风吹云乱天垂泪", "若有经心风过耳"])
for result in results:
print(result)
```
## 服务部署
PaddleHub Serving 可以部署在线服务。
### 第一步:启动PaddleHub Serving
运行启动命令:
```shell
$ hub serving start -m ernie_tiny_couplet
```
这样就完成了一个服务化API的部署,默认端口号为8866。
**NOTE:** 服务部署只支持使用CPU,如需使用GPU,请使用API方式预测。
### 第二步:发送预测请求
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```python
import requests
import json
# 发送HTTP请求
data = {'texts':["风吹云乱天垂泪", "若有经心风过耳"]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ernie_tiny_couplet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 保存结果
results = r.json()["results"]
print(results)
```
## 查看代码
https://github.com/PaddlePaddle/PaddleHub/blob/develop/demo/text_generation
## 依赖
paddlepaddle >= 1.8.2
paddlehub >= 1.8.0
## 更新历史
* 1.0.0
初始发布。
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ast
import argparse
import paddlehub as hub
from paddlehub.module.module import moduleinfo, serving, runnable
from paddlehub.module.nlp_module import DataFormatError
@moduleinfo(
name="ernie_tiny_couplet",
version="1.0.0",
summary="couplet generation model fine-tuned with ernie_tiny module",
author="paddlehub",
author_email="",
type="nlp/text_generation",
)
class ErnieTinyCouplet(hub.NLPPredictionModule):
def _initialize(self, use_gpu=False):
# Load Paddlehub ERNIE Tiny pretrained model
self.module = hub.Module(name="ernie_tiny")
inputs, outputs, program = self.module.context(
trainable=True, max_seq_len=128)
# Download dataset and get its label list and label num
# If you just want labels information, you can omit its tokenizer parameter to avoid preprocessing the train set.
dataset = hub.dataset.Couplet()
self.label_list = dataset.get_labels()
# Setup RunConfig for PaddleHub Fine-tune API
config = hub.RunConfig(
use_data_parallel=False,
use_cuda=use_gpu,
batch_size=1,
checkpoint_dir=os.path.join(self.directory, "assets", "ckpt"),
strategy=hub.AdamWeightDecayStrategy())
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
pooled_output = outputs["pooled_output"]
sequence_output = outputs["sequence_output"]
# Define a classfication fine-tune task by PaddleHub's API
self.gen_task = hub.TextGenerationTask(
feature=pooled_output,
token_feature=sequence_output,
max_seq_len=128,
num_classes=dataset.num_labels,
config=config,
metrics_choices=["bleu"])
def generate(self, texts):
# Add 0x02 between characters to match the format of training data,
# otherwise the length of prediction results will not match the input string
# if the input string contains non-Chinese characters.
formatted_text_a = list(map("\002".join, texts))
# Use the appropriate tokenizer to preprocess the data
# For ernie_tiny, it use BertTokenizer too.
tokenizer = hub.BertTokenizer(vocab_file=self.module.get_vocab_path())
encoded_data = [
tokenizer.encode(text=text, max_seq_len=128)
for text in formatted_text_a
]
results = self.gen_task.generate(
data=encoded_data,
label_list=self.label_list,
accelerate_mode=False)
results = [["".join(sample_result) for sample_result in sample_results]
for sample_results in results]
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU for prediction")
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(
description='Run the %s module.' % self.name,
prog='hub run %s' % self.name,
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
try:
input_data = self.check_input_data(args)
except DataFormatError and RuntimeError:
self.parser.print_help()
return None
results = self.generate(texts=input_data)
return results
@serving
def serving_method(self, texts):
"""
Run as a service.
"""
results = self.generate(texts)
return results
if __name__ == '__main__':
module = ErnieTinyCouplet()
results = module.generate(["风吹云乱天垂泪", "若有经心风过耳"])
for result in results:
print(result)
name: ernie_gen_couplet
dir: "modules/text/text_generation/ernie_gen_couplet"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/nlp/ernie_gen_couplet/assets.tar.gz
dest: assets
uncompress: True
name: ernie_gen_poetry
dir: "modules/text/text_generation/ernie_gen_poetry"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/nlp/ernie_gen_poetry/assets.tar.gz
dest: assets
uncompress: True
name: ernie_tiny_couplet
dir: "modules/text/text_generation/ernie_tiny_couplet"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/nlp/ernie_tiny_couplet/assets.tar.gz
dest: assets
uncompress: True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, main
import paddlehub as hub
class ErnieGenCoupletTestCase(TestCase):
def setUp(self):
self.module = hub.Module(name='ernie_gen_couplet')
self.left = ["风吹云乱天垂泪", "若有经心风过耳"]
def test_predict(self):
rights = self.module.generate(self.left)
self.assertEqual(len(rights), 2)
self.assertEqual(len(rights[0]), 5)
self.assertEqual(len(rights[0][0]), 7)
self.assertEqual(len(rights[1][0]), 7)
if __name__ == '__main__':
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, main
import paddlehub as hub
class ErnieGenPoetryTestCase(TestCase):
def setUp(self):
self.module = hub.Module(name='ernie_gen_poetry')
self.left = ["昔年旅南服,始识王荆州。", "高名出汉阴,禅阁跨香岑。"]
def test_predict(self):
rights = self.module.generate(self.left)
self.assertEqual(len(rights), 2)
self.assertEqual(len(rights[0]), 5)
if __name__ == '__main__':
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, main
import paddlehub as hub
class ErnieTinyCoupletTestCase(TestCase):
def setUp(self):
self.module = hub.Module(name='ernie_tiny_couplet')
self.left = ["风吹云乱天垂泪", "若有经心风过耳"]
def test_predict(self):
rights = self.module.predict(self.left)
self.assertEqual(len(rights), 2)
self.assertEqual(len(rights[0]), 10)
self.assertEqual(len(rights[0][0]), 7)
self.assertEqual(len(rights[1][0]), 7)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册