Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
e35ff5ec
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e35ff5ec
编写于
7月 28, 2020
作者:
K
kinghuin
提交者:
GitHub
7月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unitest, config, fix ernie_gen bugs and add ernie_tiny_couplet (#782)
上级
879383e8
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
380 addition
and
17 deletion
+380
-17
hub_module/modules/text/text_generation/ernie_gen_couplet/module.py
.../modules/text/text_generation/ernie_gen_couplet/module.py
+8
-6
hub_module/modules/text/text_generation/ernie_gen_poetry/README.md
...e/modules/text/text_generation/ernie_gen_poetry/README.md
+3
-3
hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie_gen.py
...t_generation/ernie_gen_poetry/model/modeling_ernie_gen.py
+2
-2
hub_module/modules/text/text_generation/ernie_gen_poetry/module.py
...e/modules/text/text_generation/ernie_gen_poetry/module.py
+9
-6
hub_module/modules/text/text_generation/ernie_tiny_couplet/README.md
...modules/text/text_generation/ernie_tiny_couplet/README.md
+93
-0
hub_module/modules/text/text_generation/ernie_tiny_couplet/__init__.py
...dules/text/text_generation/ernie_tiny_couplet/__init__.py
+0
-0
hub_module/modules/text/text_generation/ernie_tiny_couplet/module.py
...modules/text/text_generation/ernie_tiny_couplet/module.py
+144
-0
hub_module/scripts/configs/ernie_gen_couplet.yml
hub_module/scripts/configs/ernie_gen_couplet.yml
+9
-0
hub_module/scripts/configs/ernie_gen_poetry.yml
hub_module/scripts/configs/ernie_gen_poetry.yml
+9
-0
hub_module/scripts/configs/ernie_tiny_couplet.yml
hub_module/scripts/configs/ernie_tiny_couplet.yml
+9
-0
hub_module/tests/unittests/test_ernie_gen_couplet.py
hub_module/tests/unittests/test_ernie_gen_couplet.py
+32
-0
hub_module/tests/unittests/test_ernie_gen_poetry.py
hub_module/tests/unittests/test_ernie_gen_poetry.py
+30
-0
hub_module/tests/unittests/test_ernie_tiny_couplet.py
hub_module/tests/unittests/test_ernie_tiny_couplet.py
+32
-0
未找到文件。
hub_module/modules/text/text_generation/ernie_gen_couplet/module.py
浏览文件 @
e35ff5ec
...
@@ -50,11 +50,13 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -50,11 +50,13 @@ class ErnieGen(hub.NLPPredictionModule):
assets_path
=
os
.
path
.
join
(
self
.
directory
,
"assets"
)
assets_path
=
os
.
path
.
join
(
self
.
directory
,
"assets"
)
gen_checkpoint_path
=
os
.
path
.
join
(
assets_path
,
"ernie_gen_couplet"
)
gen_checkpoint_path
=
os
.
path
.
join
(
assets_path
,
"ernie_gen_couplet"
)
ernie_cfg_path
=
os
.
path
.
join
(
assets_path
,
'ernie_config.json'
)
ernie_cfg_path
=
os
.
path
.
join
(
assets_path
,
'ernie_config.json'
)
ernie_cfg
=
dict
(
json
.
loads
(
open
(
ernie_cfg_path
).
read
()))
with
open
(
ernie_cfg_path
)
as
ernie_cfg_file
:
ernie_cfg
=
dict
(
json
.
loads
(
ernie_cfg_file
.
read
()))
ernie_vocab_path
=
os
.
path
.
join
(
assets_path
,
'vocab.txt'
)
ernie_vocab_path
=
os
.
path
.
join
(
assets_path
,
'vocab.txt'
)
with
open
(
ernie_vocab_path
)
as
ernie_vocab_file
:
ernie_vocab
=
{
ernie_vocab
=
{
j
.
strip
().
split
(
'
\t
'
)[
0
]:
i
j
.
strip
().
split
(
'
\t
'
)[
0
]:
i
for
i
,
j
in
enumerate
(
open
(
ernie_vocab_path
)
.
readlines
())
for
i
,
j
in
enumerate
(
ernie_vocab_file
.
readlines
())
}
}
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
...
@@ -183,5 +185,5 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -183,5 +185,5 @@ class ErnieGen(hub.NLPPredictionModule):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
module
=
ErnieGen
()
module
=
ErnieGen
()
for
result
in
module
.
generate
([
'
人增福寿年增岁
'
,
'风吹云乱天垂泪'
],
beam_width
=
5
):
for
result
in
module
.
generate
([
'
上海自来水来自海上
'
,
'风吹云乱天垂泪'
],
beam_width
=
5
):
print
(
result
)
print
(
result
)
hub_module/modules/text/text_generation/ernie_gen_poetry/README.md
浏览文件 @
e35ff5ec
...
@@ -10,7 +10,7 @@ ERNIE-GEN 是面向生成任务的预训练-微调框架,首次在预训练阶
...
@@ -10,7 +10,7 @@ ERNIE-GEN 是面向生成任务的预训练-微调框架,首次在预训练阶
## 命令行预测
## 命令行预测
```
shell
```
shell
$
hub run ernie_gen_poetry
--input_text
=
"
宝积峰前露术香,使君行旆照晴阳
。"
--use_gpu
True
--beam_width
5
$
hub run ernie_gen_poetry
--input_text
=
"
昔年旅南服,始识王荆州
。"
--use_gpu
True
--beam_width
5
```
```
## API
## API
...
@@ -38,7 +38,7 @@ import paddlehub as hub
...
@@ -38,7 +38,7 @@ import paddlehub as hub
module
=
hub
.
Module
(
name
=
"ernie_gen_poetry"
)
module
=
hub
.
Module
(
name
=
"ernie_gen_poetry"
)
test_texts
=
[
"宝积峰前露术香,使君行旆照晴阳。"
]
test_texts
=
[
'昔年旅南服,始识王荆州。'
,
'高名出汉阴,禅阁跨香岑。'
]
results
=
module
.
genrate
(
texts
=
test_texts
,
use_gpu
=
True
,
beam_width
=
5
)
results
=
module
.
genrate
(
texts
=
test_texts
,
use_gpu
=
True
,
beam_width
=
5
)
for
result
in
results
:
for
result
in
results
:
print
(
result
)
print
(
result
)
...
@@ -69,7 +69,7 @@ import json
...
@@ -69,7 +69,7 @@ import json
# 发送HTTP请求
# 发送HTTP请求
data
=
{
'texts'
:[
"宝积峰前露术香,使君行旆照晴阳。"
],
data
=
{
'texts'
:[
'昔年旅南服,始识王荆州。'
,
'高名出汉阴,禅阁跨香岑。'
],
'use_gpu'
:
False
,
'beam_width'
:
5
}
'use_gpu'
:
False
,
'beam_width'
:
5
}
headers
=
{
"Content-type"
:
"application/json"
}
headers
=
{
"Content-type"
:
"application/json"
}
url
=
"http://127.0.0.1:8866/predict/ernie_gen_poetry"
url
=
"http://127.0.0.1:8866/predict/ernie_gen_poetry"
...
...
hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie_gen.py
浏览文件 @
e35ff5ec
...
@@ -15,8 +15,8 @@
...
@@ -15,8 +15,8 @@
import
paddle.fluid
as
F
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
import
paddle.fluid.layers
as
L
from
ernie_gen_
couplet
.model.modeling_ernie
import
ErnieModel
from
ernie_gen_
poetry
.model.modeling_ernie
import
ErnieModel
from
ernie_gen_
couplet
.model.modeling_ernie
import
_build_linear
,
_build_ln
,
append_name
from
ernie_gen_
poetry
.model.modeling_ernie
import
_build_linear
,
_build_ln
,
append_name
class
ErnieModelForGeneration
(
ErnieModel
):
class
ErnieModelForGeneration
(
ErnieModel
):
...
...
hub_module/modules/text/text_generation/ernie_gen_poetry/module.py
浏览文件 @
e35ff5ec
...
@@ -50,11 +50,13 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -50,11 +50,13 @@ class ErnieGen(hub.NLPPredictionModule):
assets_path
=
os
.
path
.
join
(
self
.
directory
,
"assets"
)
assets_path
=
os
.
path
.
join
(
self
.
directory
,
"assets"
)
gen_checkpoint_path
=
os
.
path
.
join
(
assets_path
,
"ernie_gen_poetry"
)
gen_checkpoint_path
=
os
.
path
.
join
(
assets_path
,
"ernie_gen_poetry"
)
ernie_cfg_path
=
os
.
path
.
join
(
assets_path
,
'ernie_config.json'
)
ernie_cfg_path
=
os
.
path
.
join
(
assets_path
,
'ernie_config.json'
)
ernie_cfg
=
dict
(
json
.
loads
(
open
(
ernie_cfg_path
).
read
()))
with
open
(
ernie_cfg_path
)
as
ernie_cfg_file
:
ernie_cfg
=
dict
(
json
.
loads
(
ernie_cfg_file
.
read
()))
ernie_vocab_path
=
os
.
path
.
join
(
assets_path
,
'vocab.txt'
)
ernie_vocab_path
=
os
.
path
.
join
(
assets_path
,
'vocab.txt'
)
with
open
(
ernie_vocab_path
)
as
ernie_vocab_file
:
ernie_vocab
=
{
ernie_vocab
=
{
j
.
strip
().
split
(
'
\t
'
)[
0
]:
i
j
.
strip
().
split
(
'
\t
'
)[
0
]:
i
for
i
,
j
in
enumerate
(
open
(
ernie_vocab_path
)
.
readlines
())
for
i
,
j
in
enumerate
(
ernie_vocab_file
.
readlines
())
}
}
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
...
@@ -183,5 +185,6 @@ class ErnieGen(hub.NLPPredictionModule):
...
@@ -183,5 +185,6 @@ class ErnieGen(hub.NLPPredictionModule):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
module
=
ErnieGen
()
module
=
ErnieGen
()
for
result
in
module
.
generate
([
'宝积峰前露术香,使君行旆照晴阳。'
],
beam_width
=
5
):
for
result
in
module
.
generate
([
'昔年旅南服,始识王荆州。'
,
'高名出汉阴,禅阁跨香岑。'
],
beam_width
=
5
):
print
(
result
)
print
(
result
)
hub_module/modules/text/text_generation/ernie_tiny_couplet/README.md
0 → 100644
浏览文件 @
e35ff5ec
```
shell
$
hub
install
ernie_tiny_couplet
==
1.0.0
```
<p
align=
"center"
>
<img
src=
"https://paddlehub.bj.bcebos.com/paddlehub-img%2Fernie_tiny_framework.PNG"
hspace=
'10'
/>
<br
/>
</p>
本预测module系由TextGenerationTask微调而来,转换方式可以参考
[
Fine-tune保存的模型如何转化为一个PaddleHub Module
](
https://github.com/PaddlePaddle/PaddleHub/blob/develop/docs/tutorial/finetuned_model_to_module.md
)
。
## 命令行预测
```
shell
$
hub run ernie_tiny_couplet
--input_text
'风吹云乱天垂泪'
```
命令行预测只支持使用CPU预测,如需使用GPU,请使用API方式预测。
## API
```
python
def
generate
(
texts
)
```
对联预测接口,输入上联文本,输出下联文本。该接口封装了上联文本使用
`hub.BertTokenizer`
编码的过程,因此它的调用方式比demo中提供的
[
predcit接口
](
https://github.com/PaddlePaddle/PaddleHub/blob/develop/demo/text_generation/predict.py#L83
)
简单。
**参数**
> texts(list[str]): 上联文本。
**返回**
> result(list[str]): 下联文本。每个上联会对应输出10个下联。
**代码示例**
```
python
import
paddlehub
as
hub
# Load ernie pretrained model
module
=
hub
.
Module
(
name
=
"ernie_tiny_couplet"
)
results
=
module
.
generate
([
"风吹云乱天垂泪"
,
"若有经心风过耳"
])
for
result
in
results
:
print
(
result
)
```
## 服务部署
PaddleHub Serving 可以部署在线服务。
### 第一步:启动PaddleHub Serving
运行启动命令:
```
shell
$
hub serving start
-m
ernie_tiny_couplet
```
这样就完成了一个服务化API的部署,默认端口号为8866。
**NOTE:**
服务部署只支持使用CPU,如需使用GPU,请使用API方式预测。
### 第二步:发送预测请求
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```
python
import
requests
import
json
# 发送HTTP请求
data
=
{
'texts'
:[
"风吹云乱天垂泪"
,
"若有经心风过耳"
]}
headers
=
{
"Content-type"
:
"application/json"
}
url
=
"http://127.0.0.1:8866/predict/ernie_tiny_couplet"
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
# 保存结果
results
=
r
.
json
()[
"results"
]
print
(
results
)
```
## 查看代码
https://github.com/PaddlePaddle/PaddleHub/blob/develop/demo/text_generation
## 依赖
paddlepaddle >= 1.8.2
paddlehub >= 1.8.0
## 更新历史
*
1.0.0
初始发布。
hub_module/modules/text/text_generation/ernie_tiny_couplet/__init__.py
0 → 100644
浏览文件 @
e35ff5ec
hub_module/modules/text/text_generation/ernie_tiny_couplet/module.py
0 → 100644
浏览文件 @
e35ff5ec
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
ast
import
argparse
import
paddlehub
as
hub
from
paddlehub.module.module
import
moduleinfo
,
serving
,
runnable
from
paddlehub.module.nlp_module
import
DataFormatError
@
moduleinfo
(
name
=
"ernie_tiny_couplet"
,
version
=
"1.0.0"
,
summary
=
"couplet generation model fine-tuned with ernie_tiny module"
,
author
=
"paddlehub"
,
author_email
=
""
,
type
=
"nlp/text_generation"
,
)
class
ErnieTinyCouplet
(
hub
.
NLPPredictionModule
):
def
_initialize
(
self
,
use_gpu
=
False
):
# Load Paddlehub ERNIE Tiny pretrained model
self
.
module
=
hub
.
Module
(
name
=
"ernie_tiny"
)
inputs
,
outputs
,
program
=
self
.
module
.
context
(
trainable
=
True
,
max_seq_len
=
128
)
# Download dataset and get its label list and label num
# If you just want labels information, you can omit its tokenizer parameter to avoid preprocessing the train set.
dataset
=
hub
.
dataset
.
Couplet
()
self
.
label_list
=
dataset
.
get_labels
()
# Setup RunConfig for PaddleHub Fine-tune API
config
=
hub
.
RunConfig
(
use_data_parallel
=
False
,
use_cuda
=
use_gpu
,
batch_size
=
1
,
checkpoint_dir
=
os
.
path
.
join
(
self
.
directory
,
"assets"
,
"ckpt"
),
strategy
=
hub
.
AdamWeightDecayStrategy
())
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
pooled_output
=
outputs
[
"pooled_output"
]
sequence_output
=
outputs
[
"sequence_output"
]
# Define a classfication fine-tune task by PaddleHub's API
self
.
gen_task
=
hub
.
TextGenerationTask
(
feature
=
pooled_output
,
token_feature
=
sequence_output
,
max_seq_len
=
128
,
num_classes
=
dataset
.
num_labels
,
config
=
config
,
metrics_choices
=
[
"bleu"
])
def
generate
(
self
,
texts
):
# Add 0x02 between characters to match the format of training data,
# otherwise the length of prediction results will not match the input string
# if the input string contains non-Chinese characters.
formatted_text_a
=
list
(
map
(
"
\002
"
.
join
,
texts
))
# Use the appropriate tokenizer to preprocess the data
# For ernie_tiny, it use BertTokenizer too.
tokenizer
=
hub
.
BertTokenizer
(
vocab_file
=
self
.
module
.
get_vocab_path
())
encoded_data
=
[
tokenizer
.
encode
(
text
=
text
,
max_seq_len
=
128
)
for
text
in
formatted_text_a
]
results
=
self
.
gen_task
.
generate
(
data
=
encoded_data
,
label_list
=
self
.
label_list
,
accelerate_mode
=
False
)
results
=
[[
""
.
join
(
sample_result
)
for
sample_result
in
sample_results
]
for
sample_results
in
results
]
return
results
def
add_module_config_arg
(
self
):
"""
Add the command config options
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU for prediction"
)
@
runnable
def
run_cmd
(
self
,
argvs
):
"""
Run as a command
"""
self
.
parser
=
argparse
.
ArgumentParser
(
description
=
'Run the %s module.'
%
self
.
name
,
prog
=
'hub run %s'
%
self
.
name
,
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
try
:
input_data
=
self
.
check_input_data
(
args
)
except
DataFormatError
and
RuntimeError
:
self
.
parser
.
print_help
()
return
None
results
=
self
.
generate
(
texts
=
input_data
)
return
results
@
serving
def
serving_method
(
self
,
texts
):
"""
Run as a service.
"""
results
=
self
.
generate
(
texts
)
return
results
if
__name__
==
'__main__'
:
module
=
ErnieTinyCouplet
()
results
=
module
.
generate
([
"风吹云乱天垂泪"
,
"若有经心风过耳"
])
for
result
in
results
:
print
(
result
)
hub_module/scripts/configs/ernie_gen_couplet.yml
0 → 100644
浏览文件 @
e35ff5ec
name
:
ernie_gen_couplet
dir
:
"
modules/text/text_generation/ernie_gen_couplet"
exclude
:
-
README.md
resources
:
-
url
:
https://paddlehub.bj.bcebos.com/model/nlp/ernie_gen_couplet/assets.tar.gz
dest
:
assets
uncompress
:
True
hub_module/scripts/configs/ernie_gen_poetry.yml
0 → 100644
浏览文件 @
e35ff5ec
name
:
ernie_gen_poetry
dir
:
"
modules/text/text_generation/ernie_gen_poetry"
exclude
:
-
README.md
resources
:
-
url
:
https://paddlehub.bj.bcebos.com/model/nlp/ernie_gen_poetry/assets.tar.gz
dest
:
assets
uncompress
:
True
hub_module/scripts/configs/ernie_tiny_couplet.yml
0 → 100644
浏览文件 @
e35ff5ec
name
:
ernie_tiny_couplet
dir
:
"
modules/text/text_generation/ernie_tiny_couplet"
exclude
:
-
README.md
resources
:
-
url
:
https://paddlehub.bj.bcebos.com/model/nlp/ernie_tiny_couplet/assets.tar.gz
dest
:
assets
uncompress
:
True
hub_module/tests/unittests/test_ernie_gen_couplet.py
0 → 100644
浏览文件 @
e35ff5ec
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
unittest
import
TestCase
,
main
import
paddlehub
as
hub
class
ErnieGenCoupletTestCase
(
TestCase
):
def
setUp
(
self
):
self
.
module
=
hub
.
Module
(
name
=
'ernie_gen_couplet'
)
self
.
left
=
[
"风吹云乱天垂泪"
,
"若有经心风过耳"
]
def
test_predict
(
self
):
rights
=
self
.
module
.
generate
(
self
.
left
)
self
.
assertEqual
(
len
(
rights
),
2
)
self
.
assertEqual
(
len
(
rights
[
0
]),
5
)
self
.
assertEqual
(
len
(
rights
[
0
][
0
]),
7
)
self
.
assertEqual
(
len
(
rights
[
1
][
0
]),
7
)
if
__name__
==
'__main__'
:
main
()
hub_module/tests/unittests/test_ernie_gen_poetry.py
0 → 100644
浏览文件 @
e35ff5ec
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
unittest
import
TestCase
,
main
import
paddlehub
as
hub
class
ErnieGenPoetryTestCase
(
TestCase
):
def
setUp
(
self
):
self
.
module
=
hub
.
Module
(
name
=
'ernie_gen_poetry'
)
self
.
left
=
[
"昔年旅南服,始识王荆州。"
,
"高名出汉阴,禅阁跨香岑。"
]
def
test_predict
(
self
):
rights
=
self
.
module
.
generate
(
self
.
left
)
self
.
assertEqual
(
len
(
rights
),
2
)
self
.
assertEqual
(
len
(
rights
[
0
]),
5
)
if
__name__
==
'__main__'
:
main
()
hub_module/tests/unittests/test_ernie_tiny_couplet.py
0 → 100644
浏览文件 @
e35ff5ec
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
unittest
import
TestCase
,
main
import
paddlehub
as
hub
class
ErnieTinyCoupletTestCase
(
TestCase
):
def
setUp
(
self
):
self
.
module
=
hub
.
Module
(
name
=
'ernie_tiny_couplet'
)
self
.
left
=
[
"风吹云乱天垂泪"
,
"若有经心风过耳"
]
def
test_predict
(
self
):
rights
=
self
.
module
.
predict
(
self
.
left
)
self
.
assertEqual
(
len
(
rights
),
2
)
self
.
assertEqual
(
len
(
rights
[
0
]),
10
)
self
.
assertEqual
(
len
(
rights
[
0
][
0
]),
7
)
self
.
assertEqual
(
len
(
rights
[
1
][
0
]),
7
)
if
__name__
==
'__main__'
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录