Commit 31186c41 authored by hong, committed by Aurelius84

update transformer support remove build once (#4115)

* update transformer support remove build once; test=develop

* fix optimizer; test=develop
Parent e4047478
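Context for the two commit bullets: this is the PaddlePaddle 1.6-era dygraph migration in which Layer subclasses drop the leading name_scope constructor argument and the lazy build_once parameter-creation hook, creating parameters eagerly in __init__ instead. A minimal before/after sketch, assuming the paddle.fluid 1.6 dygraph API; OldBlock, NewBlock, and the layer sizes are illustrative and not taken from this commit:

```python
# A minimal before/after sketch of the dygraph migration this commit applies,
# assuming the paddle.fluid 1.6-style API. OldBlock, NewBlock, and the sizes
# are illustrative, not taken from this commit.
import paddle.fluid as fluid

# Before: a dygraph Layer took a name_scope string and typically created its
# parameters lazily in a build_once() hook, run on the first forward() call:
#
#     class OldBlock(fluid.dygraph.Layer):
#         def __init__(self, name_scope):
#             super(OldBlock, self).__init__(name_scope)
#
#         def build_once(self, input):
#             self._fc = fluid.dygraph.FC(self.full_name(), size=32)

# After: no name_scope argument and no build_once; parameters are created
# eagerly in __init__, so input sizes are passed explicitly. This is why the
# 'transformer' string disappears from the TransFormer(...) calls below.
class NewBlock(fluid.dygraph.Layer):
    def __init__(self, in_dim, out_dim):
        super(NewBlock, self).__init__()
        self._fc = fluid.dygraph.Linear(in_dim, out_dim)

    def forward(self, x):
        return self._fc(x)
```

Because parameters now exist as soon as the layer is constructed, the model no longer needs a first forward pass before its parameters can be handed to an optimizer, which is what the parameter_list hunk in train.py relies on (see the sketch after the diff).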
@@ -62,9 +62,9 @@ def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
     trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, 1, 1]).astype("float32")
-    trg_word = trg_word.reshape(-1, 1, 1)
-    src_word = src_word.reshape(-1, src_max_len, 1)
-    src_pos = src_pos.reshape(-1, src_max_len, 1)
+    trg_word = trg_word.reshape(-1, 1, 1 )
+    src_word = src_word.reshape(-1, src_max_len, 1 )
+    src_pos = src_pos.reshape(-1, src_max_len,1 )
     data_inputs = [
         src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
@@ -101,7 +101,7 @@ def infer(args):
         if args.use_data_parallel else fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
         transformer = TransFormer(
-            'transformer', ModelHyperParams.src_vocab_size,
+            ModelHyperParams.src_vocab_size,
             ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
@@ -129,7 +129,8 @@ def infer(args):
             enc_inputs, dec_inputs = prepare_infer_input(
                 batch, ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
                 ModelHyperParams.n_head)
+            print( "enc inputs", enc_inputs[0].shape )
             finished_seq, finished_scores = transformer.beam_search(
                 enc_inputs,
                 dec_inputs,
@@ -110,7 +110,7 @@ def train(args):
         # define model
         transformer = TransFormer(
-            'transformer', ModelHyperParams.src_vocab_size,
+            ModelHyperParams.src_vocab_size,
             ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
@@ -123,6 +123,7 @@ def train(args):
         optimizer = fluid.optimizer.Adam(learning_rate=NoamDecay(
             ModelHyperParams.d_model, TrainTaskConfig.warmup_steps,
             TrainTaskConfig.learning_rate),
+            parameter_list = transformer.parameters(),
             beta1=TrainTaskConfig.beta1,
             beta2=TrainTaskConfig.beta2,
             epsilon=TrainTaskConfig.eps)
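The "fix optimizer" bullet matches the last hunk above: a 1.6-style dygraph optimizer no longer picks its parameters up from a global scope and must be given them explicitly via parameter_list, otherwise minimize() has nothing to update. A hedged usage sketch; the constant learning rate and the NewBlock model are illustrative stand-ins for the commit's NoamDecay schedule and TransFormer:

```python
# Usage sketch for the parameter_list change, assuming the paddle.fluid 1.6
# dygraph API. NewBlock is the illustrative layer from the sketch above.
import paddle.fluid as fluid

with fluid.dygraph.guard():
    model = NewBlock(16, 32)
    # Without parameter_list a dygraph optimizer has no parameters to update;
    # passing model.parameters() mirrors the hunk above.
    optimizer = fluid.optimizer.Adam(
        learning_rate=1e-3,  # the commit uses a NoamDecay schedule here
        parameter_list=model.parameters())
```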