Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b67ce353
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b67ce353
编写于
5月 16, 2018
作者:
Y
Yang Yang(Tony)
提交者:
GitHub
5月 16, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
speed up test label semantic roles (#10718)
上级
14248a64
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
6 addition
and
21 deletion
+6
-21
python/paddle/fluid/tests/book/test_label_semantic_roles.py
python/paddle/fluid/tests/book/test_label_semantic_roles.py
+6
-21
未找到文件。
python/paddle/fluid/tests/book/test_label_semantic_roles.py
浏览文件 @
b67ce353
...
...
@@ -182,12 +182,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
crf_decode
=
fluid
.
layers
.
crf_decoding
(
input
=
feature_out
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'crfw'
))
chunk_evaluator
=
fluid
.
evaluator
.
ChunkEvaluator
(
input
=
crf_decode
,
label
=
target
,
chunk_scheme
=
"IOB"
,
num_chunk_types
=
int
(
math
.
ceil
((
label_dict_len
-
1
)
/
2.0
)))
train_data
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
conll05
.
test
(),
buf_size
=
8192
),
...
...
@@ -203,7 +197,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
def
train_loop
(
main_program
):
exe
.
run
(
fluid
.
default_startup_program
())
embedding_param
=
fluid
.
global_scope
().
find_var
(
embedding_name
).
get_tensor
()
embedding_param
.
set
(
...
...
@@ -213,27 +206,19 @@ def train(use_cuda, save_dirname=None, is_local=True):
start_time
=
time
.
time
()
batch_id
=
0
for
pass_id
in
xrange
(
PASS_NUM
):
chunk_evaluator
.
reset
(
exe
)
for
data
in
train_data
():
cost
,
precision
,
recall
,
f1_score
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
]
+
chunk_evaluator
.
metrics
)
pass_precision
,
pass_recall
,
pass_f1_score
=
chunk_evaluator
.
eval
(
exe
)
cost
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
])
cost
=
cost
[
0
]
if
batch_id
%
10
==
0
:
print
(
"avg_cost:"
+
str
(
cost
)
+
" precision:"
+
str
(
precision
)
+
" recall:"
+
str
(
recall
)
+
" f1_score:"
+
str
(
f1_score
)
+
" pass_precision:"
+
str
(
pass_precision
)
+
" pass_recall:"
+
str
(
pass_recall
)
+
" pass_f1_score:"
+
str
(
pass_f1_score
))
print
(
"avg_cost:"
+
str
(
cost
))
if
batch_id
!=
0
:
print
(
"second per batch: "
+
str
((
time
.
time
(
)
-
start_time
)
/
batch_id
))
# Set the threshold low to speed up the CI test
if
float
(
pass_precision
)
>
0.01
:
if
float
(
cost
)
<
60.0
:
if
save_dirname
is
not
None
:
# TODO(liuyiqun): Change the target to crf_decode
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录