PaddlePaddle / PaddleHub

Commit 673a3cf7 (unverified)
Authored on Jun 15, 2021 by 骑马小猫; committed via GitHub on Jun 15, 2021
Add return_prob in text classification module
Parent: 9c1fb388

Showing 3 changed files with 27 additions and 10 deletions (+27 -10)
README.md                              +9 -2
demo/text_classification/predict.py    +1 -1
paddlehub/module/nlp_module.py         +17 -7
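In short, `predict()` gains a `return_prob` flag: when set, it returns the per-example class probabilities alongside the predicted labels, and the default behavior is unchanged. A minimal usage sketch (the `ernie_tiny` module name, label map, and sample data below are illustrative, not part of this commit):

```python
import paddlehub as hub

# Illustrative label map and module choice; any seq-cls Transformer module works.
label_map = {0: 'negative', 1: 'positive'}
model = hub.Module(name='ernie_tiny', task='seq-cls', label_map=label_map)

data = [['这家餐厅味道很好'], ['服务态度太差了']]

# New in this commit: return_prob=True also returns the probability vectors.
labels, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)

for text, label, prob in zip(data, labels, probs):
    print(text[0], label, prob)
```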
README.md (view file @ 673a3cf7)

@@ -124,8 +124,15 @@ please add WeChat above and send "Hub" to the robot, the robot will invite you t
 ## QuickStart

 ```python
-!pip install --upgrade paddlepaddle
-!pip install --upgrade paddlehub
+# install paddlepaddle with gpu
+# !pip install --upgrade paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
+
+# or install paddlepaddle with cpu
+!pip install --upgrade paddlepaddle -i https://mirror.baidu.com/pypi/simple
+
+# install paddlehub
+!pip install --upgrade paddlehub -i https://mirror.baidu.com/pypi/simple

 import paddlehub as hub
...
demo/text_classification/predict.py (view file @ 673a3cf7)

@@ -28,6 +28,6 @@ if __name__ == '__main__':
         task='seq-cls',
         load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams',
         label_map=label_map)
-    results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+    results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)
     for idx, text in enumerate(data):
         print('Data: {}\tLable: {}'.format(text[0], results[idx]))
...
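With `return_prob=True`, `probs` is aligned one-to-one with `results`, holding the class-probability vector for each input. A small, illustrative extension of the demo's print loop (not part of the commit) that also reports the confidence of the predicted label:

```python
for idx, text in enumerate(data):
    # probs[idx] is the list of class probabilities for this example;
    # its largest entry is the confidence of the predicted label.
    print('Data: {}\tLabel: {}\tConfidence: {:.4f}'.format(text[0], results[idx], max(probs[idx])))
```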
paddlehub/module/nlp_module.py (view file @ 673a3cf7)

@@ -18,7 +18,7 @@ import io
 import json
 import os
 import six
-from typing import List, Tuple
+from typing import List, Tuple, Union

 import paddle
 import paddle.nn as nn
...
@@ -552,7 +552,8 @@ class TransformerModule(RunModule, TextServing):
                 max_seq_len: int = 128,
                 split_char: str = '\002',
                 batch_size: int = 1,
-                use_gpu: bool = False):
+                use_gpu: bool = False,
+                return_prob: bool = False):
         """
         Predicts the data labels.
...
@@ -563,6 +564,7 @@ class TransformerModule(RunModule, TextServing):
             split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
             batch_size(obj:`int`, defaults to 1): The number of batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
+            return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.

         Returns:
             results(obj:`list`): All the predictions labels.
...
@@ -579,6 +581,8 @@ class TransformerModule(RunModule, TextServing):
         batches = self._batchify(data, max_seq_len, batch_size, split_char)
         results = []
+        batch_probs = []

         self.eval()
         for batch in batches:
             if self.task == 'text-matching':
...
@@ -589,32 +593,38 @@ class TransformerModule(RunModule, TextServing):
                 title_segment_ids = paddle.to_tensor(title_segment_ids)
                 probs = self(query_input_ids=query_input_ids, query_token_type_ids=query_segment_ids, \
                              title_input_ids=title_input_ids, title_token_type_ids=title_segment_ids)
                 idx = paddle.argmax(probs, axis=1).numpy()
                 idx = idx.tolist()
                 labels = [self.label_map[i] for i in idx]
-                results.extend(labels)
             else:
                 input_ids, segment_ids = batch
                 input_ids = paddle.to_tensor(input_ids)
                 segment_ids = paddle.to_tensor(segment_ids)
                 if self.task == 'seq-cls':
                     probs = self(input_ids, segment_ids)
                     idx = paddle.argmax(probs, axis=1).numpy()
                     idx = idx.tolist()
                     labels = [self.label_map[i] for i in idx]
-                    results.extend(labels)
                 elif self.task == 'token-cls':
                     probs = self(input_ids, segment_ids)
                     batch_ids = paddle.argmax(probs, axis=2).numpy()  # (batch_size, max_seq_len)
                     batch_ids = batch_ids.tolist()
-                    token_labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
-                    results.extend(token_labels)
+                    # token labels
+                    labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
                 elif self.task == None:
                     sequence_output, pooled_output = self(input_ids, segment_ids)
                     results.append(
                         [pooled_output.squeeze(0).numpy().tolist(), sequence_output.squeeze(0).numpy().tolist()])
+            if self.task:
+                # save probs only when return prob
+                if return_prob:
+                    batch_probs.extend(probs.numpy().tolist())
+                results.extend(labels)
+
+        if self.task and return_prob:
+            return results, batch_probs
         return results
...
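Structurally, each task branch now only computes `labels` for the current batch; a shared block at the end of the loop extends `results` and, when `return_prob` is set, collects that batch's probability vectors, so the method returns either `results` or `(results, batch_probs)`. A minimal, self-contained sketch of that accumulation pattern (plain NumPy, not the library code; all names here are illustrative):

```python
import numpy as np

def predict_sketch(batches, label_map, return_prob=False):
    """Simplified illustration of the new control flow: per batch, derive labels
    from probabilities and only collect the probabilities when asked."""
    results, batch_probs = [], []
    for probs in batches:                       # probs: (batch_size, num_classes) array
        idx = np.argmax(probs, axis=1).tolist()
        labels = [label_map[i] for i in idx]
        if return_prob:
            batch_probs.extend(probs.tolist())  # one probability vector per example
        results.extend(labels)
    if return_prob:
        return results, batch_probs
    return results

# Hypothetical probabilities for two batches of two examples each.
batches = [np.array([[0.9, 0.1], [0.3, 0.7]]), np.array([[0.2, 0.8], [0.6, 0.4]])]
label_map = {0: 'negative', 1: 'positive'}
print(predict_sketch(batches, label_map))                    # labels only
print(predict_sketch(batches, label_map, return_prob=True))  # (labels, probs)
```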