Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a3dbba0d
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a3dbba0d
编写于
4月 14, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add inference program clone and update README.md
上级
dbb29416
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
22 addition
and
10 deletion
+22
-10
demo/ernie-classification/README.md
demo/ernie-classification/README.md
+1
-1
demo/ernie-classification/cls_predict.py
demo/ernie-classification/cls_predict.py
+20
-8
demo/ernie-classification/run_predict.sh
demo/ernie-classification/run_predict.sh
+1
-1
未找到文件。
demo/ernie-classification/README.md
浏览文件 @
a3dbba0d
...
...
@@ -120,5 +120,5 @@ python -u cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
参数配置正确后,请执行脚本
`sh run_predict.sh`
,即可看到以下文本分类预测结果。如需了解更多预测步骤,请参考
`cls_predict.py`
```
text=
风扇确实够响的,尤其是到晚上周围安静下来。风扇频频开启,发热量有些惊人 label=0 predict=[0.99244046 0.00755955]
text=
键盘缝隙大进灰,装系统自己不会装~~屏幕有点窄玩游戏人物有点变形 label=0 predict=0
```
demo/ernie-classification/cls_predict.py
浏览文件 @
a3dbba0d
...
...
@@ -57,6 +57,11 @@ if __name__ == '__main__':
# Setup feed list for data feeder
# Must feed all the tensor of ERNIE's module need
feed_list
=
[
input_dict
[
"input_ids"
].
name
,
input_dict
[
"position_ids"
].
name
,
input_dict
[
"segment_ids"
].
name
,
input_dict
[
"input_mask"
].
name
,
label
.
name
]
# Define a classfication finetune task by PaddleHub's API
cls_task
=
hub
.
create_text_classification_task
(
...
...
@@ -65,19 +70,26 @@ if __name__ == '__main__':
# classificatin probability tensor
probs
=
cls_task
.
variable
(
"probs"
)
pred
=
fluid
.
layers
.
argmax
(
probs
,
axis
=
1
)
# load best model checkpoint
fluid
.
io
.
load_persistables
(
exe
,
args
.
checkpoint_dir
)
feed_list
=
[
input_dict
[
"input_ids"
].
name
,
input_dict
[
"position_ids"
].
name
,
input_dict
[
"segment_ids"
].
name
,
input_dict
[
"input_mask"
].
name
,
label
.
name
]
inference_program
=
program
.
clone
(
for_test
=
True
)
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
test_reader
=
reader
.
data_generator
(
phase
=
'test'
,
shuffle
=
False
)
test_examples
=
dataset
.
get_test_examples
()
total
=
0
correct
=
0
for
index
,
batch
in
enumerate
(
test_reader
()):
probs_v
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
probs
.
name
])
print
(
"%s
\t
predict=%s"
%
(
test_examples
[
index
],
probs_v
[
0
][
0
]))
pred_v
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
pred
.
name
],
program
=
inference_program
)
total
+=
1
if
(
pred_v
[
0
][
0
]
==
int
(
test_examples
[
index
].
label
)):
correct
+=
1
acc
=
1.0
*
correct
/
total
print
(
"%s
\t
predict=%s"
%
(
test_examples
[
index
],
pred_v
[
0
][
0
]))
print
(
"accuracy = %f"
%
acc
)
demo/ernie-classification/run_predict.sh
浏览文件 @
a3dbba0d
export
CUDA_VISIBLE_DEVICES
=
1
export
CUDA_VISIBLE_DEVICES
=
5
CKPT_DIR
=
"./ckpt_sentiment_cls/best_model"
python
-u
cls_predict.py
--checkpoint_dir
$CKPT_DIR
--max_seq_len
128
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录