Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
10d572a7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
10d572a7
编写于
7月 14, 2020
作者:
L
liym27
提交者:
GitHub
7月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2stat] Add Seq2Seq Attention model as ProgramTranslator Unit Test (#25422)
上级
25029254
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
377 addition
and
51 deletion
+377
-51
python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
...ests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
+292
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py
.../fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py
+4
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py
...e/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py
+81
-47
未找到文件。
python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
浏览文件 @
10d572a7
...
...
@@ -329,8 +329,8 @@ class BaseModel(fluid.dygraph.Layer):
# beam search
batch_beam_shape
=
(
self
.
batch_size
,
self
.
beam_size
)
vocab_size_tensor
=
to_variable
(
np
.
full
((
1
),
self
.
tar_vocab_size
).
astype
(
"int64"
)
)
vocab_size_tensor
=
to_variable
(
np
.
full
((
1
),
self
.
tar_vocab_size
)).
astype
(
"int64"
)
start_token_tensor
=
to_variable
(
np
.
full
(
batch_beam_shape
,
self
.
beam_start_token
,
dtype
=
'int64'
))
...
...
@@ -448,3 +448,293 @@ class BaseModel(fluid.dygraph.Layer):
predicted_ids
=
fluid
.
layers
.
gather_tree
(
predicted_ids
,
parent_ids
)
predicted_ids
=
self
.
_transpose_batch_time
(
predicted_ids
)
return
predicted_ids
class
AttentionModel
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
hidden_size
,
src_vocab_size
,
tar_vocab_size
,
batch_size
,
num_layers
=
1
,
init_scale
=
0.1
,
dropout
=
None
,
beam_size
=
1
,
beam_start_token
=
1
,
beam_end_token
=
2
,
beam_max_step_num
=
2
,
mode
=
'train'
):
super
(
AttentionModel
,
self
).
__init__
()
self
.
hidden_size
=
hidden_size
self
.
src_vocab_size
=
src_vocab_size
self
.
tar_vocab_size
=
tar_vocab_size
self
.
batch_size
=
batch_size
self
.
num_layers
=
num_layers
self
.
init_scale
=
init_scale
self
.
dropout
=
dropout
self
.
beam_size
=
beam_size
self
.
beam_start_token
=
beam_start_token
self
.
beam_end_token
=
beam_end_token
self
.
beam_max_step_num
=
beam_max_step_num
self
.
mode
=
mode
self
.
kinf
=
1e9
param_attr
=
ParamAttr
(
initializer
=
uniform_initializer
(
self
.
init_scale
))
bias_attr
=
ParamAttr
(
initializer
=
zero_constant
)
forget_bias
=
1.0
self
.
src_embeder
=
Embedding
(
size
=
[
self
.
src_vocab_size
,
self
.
hidden_size
],
param_attr
=
fluid
.
ParamAttr
(
name
=
'source_embedding'
,
initializer
=
uniform_initializer
(
init_scale
)))
self
.
tar_embeder
=
Embedding
(
size
=
[
self
.
tar_vocab_size
,
self
.
hidden_size
],
is_sparse
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'target_embedding'
,
initializer
=
uniform_initializer
(
init_scale
)))
self
.
enc_units
=
[]
for
i
in
range
(
num_layers
):
self
.
enc_units
.
append
(
self
.
add_sublayer
(
"enc_units_%d"
%
i
,
BasicLSTMUnit
(
hidden_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
forget_bias
=
forget_bias
)))
self
.
dec_units
=
[]
for
i
in
range
(
num_layers
):
if
i
==
0
:
self
.
dec_units
.
append
(
self
.
add_sublayer
(
"dec_units_%d"
%
i
,
BasicLSTMUnit
(
hidden_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
*
2
,
param_attr
=
ParamAttr
(
name
=
"dec_units_%d"
%
i
,
initializer
=
uniform_initializer
(
self
.
init_scale
)),
bias_attr
=
bias_attr
,
forget_bias
=
forget_bias
)))
else
:
self
.
dec_units
.
append
(
self
.
add_sublayer
(
"dec_units_%d"
%
i
,
BasicLSTMUnit
(
hidden_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
param_attr
=
ParamAttr
(
name
=
"dec_units_%d"
%
i
,
initializer
=
uniform_initializer
(
self
.
init_scale
)),
bias_attr
=
bias_attr
,
forget_bias
=
forget_bias
)))
self
.
attn_fc
=
fluid
.
dygraph
.
nn
.
Linear
(
self
.
hidden_size
,
self
.
hidden_size
,
param_attr
=
ParamAttr
(
name
=
"self_attn_fc"
,
initializer
=
uniform_initializer
(
self
.
init_scale
)),
bias_attr
=
False
)
self
.
concat_fc
=
fluid
.
dygraph
.
nn
.
Linear
(
2
*
self
.
hidden_size
,
self
.
hidden_size
,
param_attr
=
ParamAttr
(
name
=
"self_concat_fc"
,
initializer
=
uniform_initializer
(
self
.
init_scale
)),
bias_attr
=
False
)
self
.
fc
=
fluid
.
dygraph
.
nn
.
Linear
(
self
.
hidden_size
,
self
.
tar_vocab_size
,
param_attr
=
ParamAttr
(
name
=
"self_fc"
,
initializer
=
uniform_initializer
(
self
.
init_scale
)),
bias_attr
=
False
)
def
_transpose_batch_time
(
self
,
x
):
return
fluid
.
layers
.
transpose
(
x
,
[
1
,
0
]
+
list
(
range
(
2
,
len
(
x
.
shape
))))
def
_merge_batch_beams
(
self
,
x
):
return
fluid
.
layers
.
reshape
(
x
,
shape
=
(
-
1
,
x
.
shape
[
2
]))
def
tile_beam_merge_with_batch
(
self
,
x
):
x
=
fluid
.
layers
.
unsqueeze
(
x
,
[
1
])
# [batch_size, 1, ...]
expand_times
=
[
1
]
*
len
(
x
.
shape
)
expand_times
[
1
]
=
self
.
beam_size
x
=
fluid
.
layers
.
expand
(
x
,
expand_times
)
# [batch_size, beam_size, ...]
x
=
fluid
.
layers
.
transpose
(
x
,
list
(
range
(
2
,
len
(
x
.
shape
)))
+
[
0
,
1
])
# [..., batch_size, beam_size]
# use 0 to copy to avoid wrong shape
x
=
fluid
.
layers
.
reshape
(
x
,
shape
=
[
0
]
*
(
len
(
x
.
shape
)
-
2
)
+
[
-
1
])
# [..., batch_size * beam_size]
x
=
fluid
.
layers
.
transpose
(
x
,
[
len
(
x
.
shape
)
-
1
]
+
list
(
range
(
0
,
len
(
x
.
shape
)
-
1
)))
# [batch_size * beam_size, ...]
return
x
def
_split_batch_beams
(
self
,
x
):
return
fluid
.
layers
.
reshape
(
x
,
shape
=
(
-
1
,
self
.
beam_size
,
x
.
shape
[
1
]))
def
_expand_to_beam_size
(
self
,
x
):
x
=
fluid
.
layers
.
unsqueeze
(
x
,
[
1
])
expand_times
=
[
1
]
*
len
(
x
.
shape
)
expand_times
[
1
]
=
self
.
beam_size
x
=
fluid
.
layers
.
expand
(
x
,
expand_times
)
return
x
def
_real_state
(
self
,
state
,
new_state
,
step_mask
):
new_state
=
fluid
.
layers
.
elementwise_mul
(
new_state
,
step_mask
,
axis
=
0
)
-
\
fluid
.
layers
.
elementwise_mul
(
state
,
(
step_mask
-
1
),
axis
=
0
)
return
new_state
def
_gather
(
self
,
x
,
indices
,
batch_pos
):
topk_coordinates
=
fluid
.
layers
.
stack
([
batch_pos
,
indices
],
axis
=
2
)
return
fluid
.
layers
.
gather_nd
(
x
,
topk_coordinates
)
def
attention
(
self
,
query
,
enc_output
,
mask
=
None
):
query
=
fluid
.
layers
.
unsqueeze
(
query
,
[
1
])
memory
=
self
.
attn_fc
(
enc_output
)
attn
=
fluid
.
layers
.
matmul
(
query
,
memory
,
transpose_y
=
True
)
if
mask
is
not
None
:
attn
=
fluid
.
layers
.
transpose
(
attn
,
[
1
,
0
,
2
])
attn
=
fluid
.
layers
.
elementwise_add
(
attn
,
mask
*
1000000000
,
-
1
)
attn
=
fluid
.
layers
.
transpose
(
attn
,
[
1
,
0
,
2
])
weight
=
fluid
.
layers
.
softmax
(
attn
)
weight_memory
=
fluid
.
layers
.
matmul
(
weight
,
memory
)
return
weight_memory
def
_change_size_for_array
(
self
,
func
,
array
):
print
(
" ^"
*
10
,
"_change_size_for_array"
)
print
(
"array : "
,
array
)
for
i
,
state
in
enumerate
(
array
):
fluid
.
layers
.
array_write
(
func
(
state
),
i
,
array
)
return
array
@
declarative
def
forward
(
self
,
inputs
):
src
,
tar
,
label
,
src_sequence_length
,
tar_sequence_length
=
inputs
if
src
.
shape
[
0
]
<
self
.
batch_size
:
self
.
batch_size
=
src
.
shape
[
0
]
src_emb
=
self
.
src_embeder
(
self
.
_transpose_batch_time
(
src
))
# NOTE: modify model code about `enc_hidden` and `enc_cell` to transforme dygraph code successfully.
# Because nested list can't be transformed now.
enc_hidden_0
=
to_variable
(
np
.
zeros
(
(
self
.
batch_size
,
self
.
hidden_size
),
dtype
=
'float32'
))
enc_hidden_0
.
stop_gradient
=
True
enc_cell_0
=
to_variable
(
np
.
zeros
(
(
self
.
batch_size
,
self
.
hidden_size
),
dtype
=
'float32'
))
enc_hidden_0
.
stop_gradient
=
True
zero
=
fluid
.
layers
.
zeros
(
shape
=
[
1
],
dtype
=
"int64"
)
enc_hidden
=
fluid
.
layers
.
create_array
(
dtype
=
"float32"
)
enc_cell
=
fluid
.
layers
.
create_array
(
dtype
=
"float32"
)
for
i
in
range
(
self
.
num_layers
):
index
=
zero
+
i
enc_hidden
=
fluid
.
layers
.
array_write
(
enc_hidden_0
,
index
,
array
=
enc_hidden
)
enc_cell
=
fluid
.
layers
.
array_write
(
enc_cell_0
,
index
,
array
=
enc_cell
)
max_seq_len
=
src_emb
.
shape
[
0
]
enc_len_mask
=
fluid
.
layers
.
sequence_mask
(
src_sequence_length
,
maxlen
=
max_seq_len
,
dtype
=
"float32"
)
enc_padding_mask
=
(
enc_len_mask
-
1.0
)
enc_len_mask
=
fluid
.
layers
.
transpose
(
enc_len_mask
,
[
1
,
0
])
enc_outputs
=
[]
# TODO: Because diff exits if call while_loop in static graph.
# In while block, a Variable created in parent block participates in the calculation of gradient,
# the gradient is wrong because each step scope always returns the same value generated by last step.
for
p
in
range
(
max_seq_len
):
k
=
0
+
p
enc_step_input
=
src_emb
[
k
]
step_mask
=
enc_len_mask
[
k
]
new_enc_hidden
,
new_enc_cell
=
[],
[]
for
i
in
range
(
self
.
num_layers
):
enc_new_hidden
,
enc_new_cell
=
self
.
enc_units
[
i
](
enc_step_input
,
enc_hidden
[
i
],
enc_cell
[
i
])
if
self
.
dropout
!=
None
and
self
.
dropout
>
0.0
:
enc_step_input
=
fluid
.
layers
.
dropout
(
enc_new_hidden
,
dropout_prob
=
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
else
:
enc_step_input
=
enc_new_hidden
new_enc_hidden
.
append
(
self
.
_real_state
(
enc_hidden
[
i
],
enc_new_hidden
,
step_mask
))
new_enc_cell
.
append
(
self
.
_real_state
(
enc_cell
[
i
],
enc_new_cell
,
step_mask
))
enc_outputs
.
append
(
enc_step_input
)
enc_hidden
,
enc_cell
=
new_enc_hidden
,
new_enc_cell
enc_outputs
=
fluid
.
layers
.
stack
(
enc_outputs
)
enc_outputs
=
self
.
_transpose_batch_time
(
enc_outputs
)
# train
input_feed
=
to_variable
(
np
.
zeros
(
(
self
.
batch_size
,
self
.
hidden_size
),
dtype
=
'float32'
))
# NOTE: set stop_gradient here, otherwise grad var is null
input_feed
.
stop_gradient
=
True
dec_hidden
,
dec_cell
=
enc_hidden
,
enc_cell
tar_emb
=
self
.
tar_embeder
(
self
.
_transpose_batch_time
(
tar
))
max_seq_len
=
tar_emb
.
shape
[
0
]
dec_output
=
[]
for
step_idx
in
range
(
max_seq_len
):
j
=
step_idx
+
0
step_input
=
tar_emb
[
j
]
step_input
=
fluid
.
layers
.
concat
([
step_input
,
input_feed
],
1
)
new_dec_hidden
,
new_dec_cell
=
[],
[]
for
i
in
range
(
self
.
num_layers
):
new_hidden
,
new_cell
=
self
.
dec_units
[
i
](
step_input
,
dec_hidden
[
i
],
dec_cell
[
i
])
new_dec_hidden
.
append
(
new_hidden
)
new_dec_cell
.
append
(
new_cell
)
if
self
.
dropout
!=
None
and
self
.
dropout
>
0.0
:
step_input
=
fluid
.
layers
.
dropout
(
new_hidden
,
dropout_prob
=
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
else
:
step_input
=
new_hidden
dec_att
=
self
.
attention
(
step_input
,
enc_outputs
,
enc_padding_mask
)
dec_att
=
fluid
.
layers
.
squeeze
(
dec_att
,
[
1
])
concat_att_out
=
fluid
.
layers
.
concat
([
dec_att
,
step_input
],
1
)
out
=
self
.
concat_fc
(
concat_att_out
)
input_feed
=
out
dec_output
.
append
(
out
)
dec_hidden
,
dec_cell
=
new_dec_hidden
,
new_dec_cell
dec_output
=
fluid
.
layers
.
stack
(
dec_output
)
dec_output
=
self
.
fc
(
self
.
_transpose_batch_time
(
dec_output
))
loss
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logits
=
dec_output
,
label
=
label
,
soft_label
=
False
)
loss
=
fluid
.
layers
.
squeeze
(
loss
,
axes
=
[
2
])
max_tar_seq_len
=
fluid
.
layers
.
shape
(
tar
)[
1
]
tar_mask
=
fluid
.
layers
.
sequence_mask
(
tar_sequence_length
,
maxlen
=
max_tar_seq_len
,
dtype
=
'float32'
)
loss
=
loss
*
tar_mask
loss
=
fluid
.
layers
.
reduce_mean
(
loss
,
dim
=
[
0
])
loss
=
fluid
.
layers
.
reduce_sum
(
loss
)
return
loss
python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py
浏览文件 @
10d572a7
...
...
@@ -125,11 +125,13 @@ class Seq2SeqModelHyperParams(object):
max_grad_norm
=
5.0
# model path for model to save
model_path
=
"dy2stat/model/seq2seq"
base_model_path
=
"dy2stat/model/base_seq2seq"
attn_model_path
=
"dy2stat/model/attn_seq2seq"
# reload model to inference
reload_model
=
"model/epoch_0.pdparams"
beam_size
=
10
beam_size
=
4
max_seq_len
=
3
python/paddle/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py
浏览文件 @
10d572a7
...
...
@@ -21,7 +21,7 @@ import paddle.fluid as fluid
from
paddle.fluid.clip
import
GradientClipByGlobalNorm
from
paddle.fluid.dygraph.dygraph_to_static
import
ProgramTranslator
from
seq2seq_dygraph_model
import
BaseModel
from
seq2seq_dygraph_model
import
BaseModel
,
AttentionModel
from
seq2seq_utils
import
Seq2SeqModelHyperParams
as
args
from
seq2seq_utils
import
get_data_iter
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
()
else
fluid
.
CPUPlace
(
...
...
@@ -43,11 +43,21 @@ def prepare_input(batch):
return
inputs
,
np
.
sum
(
tar_mask
)
def
train
():
def
train
(
attn_model
=
False
):
with
fluid
.
dygraph
.
guard
(
place
):
fluid
.
default_startup_program
().
random_seed
=
2020
fluid
.
default_main_program
().
random_seed
=
2020
if
attn_model
:
model
=
AttentionModel
(
args
.
hidden_size
,
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
args
.
batch_size
,
num_layers
=
args
.
num_layers
,
init_scale
=
args
.
init_scale
,
dropout
=
args
.
dropout
)
else
:
model
=
BaseModel
(
args
.
hidden_size
,
args
.
src_vocab_size
,
...
...
@@ -88,17 +98,40 @@ def train():
"Batch:[%d]; Time: %.5f s; loss: %.5f; total_loss: %.5f; word num: %.5f; ppl: %.5f"
%
(
batch_id
,
batch_time
,
loss
.
numpy
(),
total_loss
.
numpy
(),
word_count
,
np
.
exp
(
total_loss
.
numpy
()
/
word_count
)))
if
attn_model
:
# NOTE: Please see code of AttentionModel.
# Because diff exits if call while_loop in static graph, only run 4 batches to pass the test temporarily.
if
batch_id
+
1
>=
4
:
break
else
:
if
batch_id
+
1
>=
STEP_NUM
:
break
model_dir
=
os
.
path
.
join
(
args
.
model_path
)
model_path
=
args
.
attn_model_path
if
attn_model
else
args
.
base_model_path
model_dir
=
os
.
path
.
join
(
model_path
)
if
not
os
.
path
.
exists
(
model_dir
):
os
.
makedirs
(
model_dir
)
fluid
.
save_dygraph
(
model
.
state_dict
(),
model_dir
)
return
loss
.
numpy
()
def
infer
():
def
infer
(
attn_model
=
False
):
with
fluid
.
dygraph
.
guard
(
place
):
if
attn_model
:
model
=
AttentionModel
(
args
.
hidden_size
,
args
.
src_vocab_size
,
args
.
tar_vocab_size
,
args
.
batch_size
,
beam_size
=
args
.
beam_size
,
num_layers
=
args
.
num_layers
,
init_scale
=
args
.
init_scale
,
dropout
=
0.0
,
mode
=
'beam_search'
)
else
:
model
=
BaseModel
(
args
.
hidden_size
,
args
.
src_vocab_size
,
...
...
@@ -109,63 +142,64 @@ def infer():
init_scale
=
args
.
init_scale
,
dropout
=
0.0
,
mode
=
'beam_search'
)
state_dict
,
_
=
fluid
.
dygraph
.
load_dygraph
(
args
.
model_path
)
model_path
=
args
.
attn_model_path
if
attn_model
else
args
.
base_model_path
state_dict
,
_
=
fluid
.
dygraph
.
load_dygraph
(
model_path
)
model
.
set_dict
(
state_dict
)
model
.
eval
()
train_data_iter
=
get_data_iter
(
args
.
batch_size
,
mode
=
'infer'
)
batch_times
=
[]
for
batch_id
,
batch
in
enumerate
(
train_data_iter
):
batch_start_time
=
time
.
time
()
input_data_feed
,
word_num
=
prepare_input
(
batch
)
input_data_feed
=
[
fluid
.
dygraph
.
to_variable
(
np_inp
)
for
np_inp
in
input_data_feed
]
outputs
=
model
.
beam_search
(
input_data_feed
)
batch_end_time
=
time
.
time
()
batch_time
=
batch_end_time
-
batch_start_time
batch_times
.
append
(
batch_time
)
if
batch_id
>
STEP_NUM
:
break
return
outputs
.
numpy
()
class
TestSeq2seq
(
unittest
.
TestCase
):
def
run_dygraph
(
self
,
mode
=
"train"
):
def
run_dygraph
(
self
,
mode
=
"train"
,
attn_model
=
False
):
program_translator
.
enable
(
False
)
if
mode
==
"train"
:
return
train
()
return
train
(
attn_model
)
else
:
return
infer
()
return
infer
(
attn_model
)
def
run_static
(
self
,
mode
=
"train"
):
def
run_static
(
self
,
mode
=
"train"
,
attn_model
=
False
):
program_translator
.
enable
(
True
)
if
mode
==
"train"
:
return
train
()
return
train
(
attn_model
)
else
:
return
infer
()
return
infer
(
attn_model
)
def
_test_train
(
self
):
dygraph_loss
=
self
.
run_dygraph
(
mode
=
"train"
)
static_loss
=
self
.
run_static
(
mode
=
"train"
)
def
_test_train
(
self
,
attn_model
=
False
):
dygraph_loss
=
self
.
run_dygraph
(
mode
=
"train"
,
attn_model
=
attn_model
)
static_loss
=
self
.
run_static
(
mode
=
"train"
,
attn_model
=
attn_model
)
result
=
np
.
allclose
(
dygraph_loss
,
static_loss
)
self
.
assertTrue
(
result
,
msg
=
"
\n
dygraph_loss = {}
\n
static_loss = {}"
.
format
(
dygraph_loss
,
static_loss
))
def
_test_predict
(
self
):
pred_dygraph
=
self
.
run_dygraph
(
mode
=
"test"
)
pred_static
=
self
.
run_static
(
mode
=
"test"
)
def
_test_predict
(
self
,
attn_model
=
False
):
pred_dygraph
=
self
.
run_dygraph
(
mode
=
"test"
,
attn_model
=
attn_model
)
pred_static
=
self
.
run_static
(
mode
=
"test"
,
attn_model
=
attn_model
)
result
=
np
.
allclose
(
pred_static
,
pred_dygraph
)
self
.
assertTrue
(
result
,
msg
=
"
\n
pred_dygraph = {}
\n
pred_static = {}"
.
format
(
pred_dygraph
,
pred_static
))
def
test_check_result
(
self
):
self
.
_test_train
()
self
.
_test_predict
()
def
test_base_model
(
self
):
self
.
_test_train
(
attn_model
=
False
)
self
.
_test_predict
(
attn_model
=
False
)
def
test_attn_model
(
self
):
self
.
_test_train
(
attn_model
=
True
)
# TODO(liym27): add predict
# self._test_predict(attn_model=True)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录