Commit 99437ee9 authored by G guosheng

Fix weight sharing in Transformer inference

Parent 34601005
@@ -98,7 +98,7 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
             if hasattr(g_cfg, key):
                 try:
                     value = eval(value)
-                except SyntaxError:  # for file path
+                except Exception:  # for file path
                     pass
                 setattr(g_cfg, key, value)
                 break
......
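The broader `except Exception` matters because `eval` on config override strings can fail in more than one way: a value like `data/train.tok.txt` parses as valid Python (a division plus attribute access) and raises `NameError`, while something like `/abs/path` raises `SyntaxError`. A minimal, self-contained sketch of this fallback behaviour (the `coerce` helper and the sample values are hypothetical, not from this repo):

```python
# Hypothetical illustration of the eval-with-fallback pattern used by
# merge_cfg_from_list: literals become Python objects, anything eval cannot
# resolve (e.g. file paths) stays a plain string.
def coerce(value):
    try:
        value = eval(value)  # "0.1" -> 0.1, "True" -> True, "[1, 2]" -> [1, 2]
    except Exception:  # NameError, SyntaxError, ... for file paths and the like
        pass
    return value

assert coerce("0.1") == 0.1
assert coerce("True") is True
# Parses as `data / train.tok.txt`, raises NameError, so it stays a string.
assert coerce("data/train.tok.txt") == "data/train.tok.txt"
# Raises SyntaxError, also stays a string.
assert coerce("/abs/path") == "/abs/path"
```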
@@ -308,7 +308,7 @@ def infer(args):
         ModelHyperParams.n_layer, ModelHyperParams.n_head,
         ModelHyperParams.d_key, ModelHyperParams.d_value,
         ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
-        ModelHyperParams.dropout)
+        ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
@@ -317,7 +317,7 @@ def infer(args):
         ModelHyperParams.n_layer, ModelHyperParams.n_head,
         ModelHyperParams.d_key, ModelHyperParams.d_value,
         ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
-        ModelHyperParams.dropout)
+        ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
     # Load model parameters of encoder and decoder separately from the saved
     # transformer model.
......
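Why the extra argument matters: `ModelHyperParams.weight_sharing` controls whether the model ties the source/target word embeddings and the pre-softmax output projection to a single parameter. If the inference encoder/decoder programs are built without forwarding this flag while training used sharing, the inference graph declares parameters that do not exist in the saved checkpoint (or vice versa), so loading the trained weights fails or produces wrong predictions. Below is a deliberately simplified NumPy sketch of the tying idea, not the repo's Fluid implementation; the class and parameter names are hypothetical:

```python
import numpy as np

class TinyTransformerParams:
    """Hypothetical sketch: which parameters exist depends on weight_sharing,
    so the inference graph must be built with the same flag as training."""

    def __init__(self, vocab_size, d_model, weight_sharing):
        rng = np.random.default_rng(0)
        self.src_emb = rng.standard_normal((vocab_size, d_model))
        if weight_sharing:
            # One matrix serves as source embedding, target embedding and
            # (transposed) pre-softmax projection.
            self.trg_emb = self.src_emb
            self.out_proj = self.src_emb.T
        else:
            # Three independent parameters to create and to load.
            self.trg_emb = rng.standard_normal((vocab_size, d_model))
            self.out_proj = rng.standard_normal((d_model, vocab_size))

# Training with sharing but building inference without it leaves trg_emb and
# out_proj unmatched by the checkpoint -- the mismatch this commit avoids by
# forwarding ModelHyperParams.weight_sharing to the inference programs.
trained = TinyTransformerParams(100, 8, weight_sharing=True)
inference = TinyTransformerParams(100, 8, weight_sharing=False)
assert trained.trg_emb is trained.src_emb
assert inference.trg_emb is not inference.src_emb
```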