Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
660489ac
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
660489ac
编写于
4月 07, 2020
作者:
L
liym27
提交者:
GitHub
4月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add log and check predicted scores. test=develop (#23506)
上级
9bc223c8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
55 addition
and
14 deletion
+55
-14
python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py
...uid/tests/unittests/dygraph_to_static/test_transformer.py
+55
-14
未找到文件。
python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py
浏览文件 @
660489ac
...
...
@@ -29,6 +29,7 @@ trainer_count = 1
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
()
else
fluid
.
CPUPlace
(
)
SEED
=
10
step_num
=
10
def
train_static
(
args
,
batch_generator
):
...
...
@@ -117,7 +118,7 @@ def train_static(args, batch_generator):
batch_id
+=
1
step_idx
+=
1
total_batch_num
=
total_batch_num
+
1
if
step_idx
==
10
:
if
step_idx
==
step_num
:
if
args
.
save_dygraph_model_path
:
model_path
=
os
.
path
.
join
(
args
.
save_static_model_path
,
"transformer"
)
...
...
@@ -201,7 +202,7 @@ def train_dygraph(args, batch_generator):
avg_batch_time
=
time
.
time
()
batch_id
+=
1
step_idx
+=
1
if
step_idx
==
10
:
if
step_idx
==
step_num
:
if
args
.
save_dygraph_model_path
:
model_dir
=
os
.
path
.
join
(
args
.
save_dygraph_model_path
)
if
not
os
.
path
.
exists
(
model_dir
):
...
...
@@ -250,10 +251,11 @@ def predict_dygraph(args, batch_generator):
transformer
.
eval
()
step_idx
=
0
speed_list
=
[]
for
input_data
in
test_loader
():
(
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_src_attn_bias
)
=
input_data
finished_seq
,
finished
_scores
=
transformer
.
beam_search
(
seq_ids
,
seq
_scores
=
transformer
.
beam_search
(
src_word
,
src_pos
,
src_slf_attn_bias
,
...
...
@@ -263,12 +265,28 @@ def predict_dygraph(args, batch_generator):
eos_id
=
args
.
eos_idx
,
beam_size
=
args
.
beam_size
,
max_len
=
args
.
max_out_len
)
finished_seq
=
finished_seq
.
numpy
()
finished_scores
=
finished_scores
.
numpy
()
seq_ids
=
seq_ids
.
numpy
()
seq_scores
=
seq_scores
.
numpy
()
if
step_idx
%
args
.
print_step
==
0
:
if
step_idx
==
0
:
logging
.
info
(
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f"
%
(
step_idx
,
seq_ids
[
0
][
0
][
0
],
seq_scores
[
0
][
0
]))
avg_batch_time
=
time
.
time
()
else
:
speed
=
args
.
print_step
/
(
time
.
time
()
-
avg_batch_time
)
speed_list
.
append
(
speed
)
logging
.
info
(
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
%
(
step_idx
,
seq_ids
[
0
][
0
][
0
],
seq_scores
[
0
][
0
],
speed
))
avg_batch_time
=
time
.
time
()
step_idx
+=
1
if
step_idx
==
10
:
if
step_idx
==
step_num
:
break
return
finished_seq
logging
.
info
(
"Dygraph Predict: avg_speed: %.4f step/s"
%
(
np
.
mean
(
speed_list
)))
return
seq_ids
,
seq_scores
def
predict_static
(
args
,
batch_generator
):
...
...
@@ -318,16 +336,34 @@ def predict_static(args, batch_generator):
loader
.
set_batch_generator
(
batch_generator
,
places
=
place
)
step_idx
=
0
speed_list
=
[]
for
feed_dict
in
loader
:
seq_ids
,
seq_scores
=
exe
.
run
(
test_prog
,
feed
=
feed_dict
,
fetch_list
=
[
out_ids
.
name
,
out_scores
.
name
],
return_numpy
=
True
)
if
step_idx
%
args
.
print_step
==
0
:
if
step_idx
==
0
:
logging
.
info
(
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f,"
%
(
step_idx
,
seq_ids
[
0
][
0
][
0
],
seq_scores
[
0
][
0
]))
avg_batch_time
=
time
.
time
()
else
:
speed
=
args
.
print_step
/
(
time
.
time
()
-
avg_batch_time
)
speed_list
.
append
(
speed
)
logging
.
info
(
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
%
(
step_idx
,
seq_ids
[
0
][
0
][
0
],
seq_scores
[
0
][
0
],
speed
))
avg_batch_time
=
time
.
time
()
step_idx
+=
1
if
step_idx
==
10
:
if
step_idx
==
step_num
:
break
return
seq_ids
logging
.
info
(
"Static Predict: avg_speed: %.4f step/s"
%
(
np
.
mean
(
speed_list
)))
return
seq_ids
,
seq_scores
class
TestTransformer
(
unittest
.
TestCase
):
...
...
@@ -344,12 +380,17 @@ class TestTransformer(unittest.TestCase):
def
_test_predict
(
self
):
args
,
batch_generator
=
self
.
prepare
(
mode
=
'test'
)
static_res
=
predict_static
(
args
,
batch_generator
)
dygraph_res
=
predict_dygraph
(
args
,
batch_generator
)
static_seq_ids
,
static_scores
=
predict_static
(
args
,
batch_generator
)
dygraph_seq_ids
,
dygraph_scores
=
predict_dygraph
(
args
,
batch_generator
)
self
.
assertTrue
(
np
.
allclose
(
static_seq_ids
,
static_seq_ids
),
msg
=
"static_seq_ids: {}
\n
dygraph_seq_ids: {}"
.
format
(
static_seq_ids
,
dygraph_seq_ids
))
self
.
assertTrue
(
np
.
allclose
(
static_
res
,
dygraph_
res
),
msg
=
"static_
res: {}
\n
dygraph_res: {}"
.
format
(
static_res
,
dygraph_
res
))
np
.
allclose
(
static_
scores
,
dygraph_sco
res
),
msg
=
"static_
scores: {}
\n
dygraph_scores: {}"
.
format
(
static_scores
,
dygraph_sco
res
))
def
test_check_result
(
self
):
self
.
_test_train
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录