From 1f5d29873424147204ac392a63b37b89479871c6 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Mon, 3 Feb 2020 10:05:28 +0800
Subject: [PATCH] Fix some bugs of dygraph Transformer.

---
 dygraph/transformer/model.py | 32 +++++++++++++++++---------------
 dygraph/transformer/train.py |  4 ++--
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/dygraph/transformer/model.py b/dygraph/transformer/model.py
index 25614f15..b4ae428e 100644
--- a/dygraph/transformer/model.py
+++ b/dygraph/transformer/model.py
@@ -21,8 +21,6 @@ import paddle.fluid.layers as layers
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
 from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
 
-from config import word_emb_param_names, pos_enc_param_names
-
 
 def position_encoding_init(n_position, d_pos_vec):
     """
@@ -138,7 +136,7 @@ class MultiHeadAttention(Layer):
         q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
         q = layers.transpose(x=q, perm=[0, 2, 1, 3])
 
-        if cache is not None and static_kv and cache.has_key("static_k"):
+        if cache is not None and static_kv and "static_k" in cache:
             # for encoder-decoder attention in inference and has cached
             k = cache["static_k"]
             v = cache["static_v"]
@@ -151,7 +149,7 @@ class MultiHeadAttention(Layer):
             v = layers.transpose(x=v, perm=[0, 2, 1, 3])
 
         if cache is not None:
-            if static_kv and not cache.has_key("static_k"):
+            if static_kv and not "static_k" in cache:
                 # for encoder-decoder attention in inference and has not cached
                 cache["static_k"], cache["static_v"] = k, v
             elif not static_kv:
@@ -180,7 +178,7 @@ class MultiHeadAttention(Layer):
                                      dropout_prob=self.dropout_rate,
                                      is_test=False)
 
-        out = layers.matmul(weights, v)
+        out = layers.matmul(weights, v)
 
         # combine heads
         out = layers.transpose(out, perm=[0, 2, 1, 3])
@@ -817,15 +815,15 @@ class Transformer(Layer):
         return finished_seq, finished_scores
 
     def beam_search(self,
-                src_word,
-                src_pos,
-                src_slf_attn_bias,
-                trg_word,
-                trg_src_attn_bias,
-                bos_id=0,
-                eos_id=1,
-                beam_size=4,
-                max_len=256):
+                    src_word,
+                    src_pos,
+                    src_slf_attn_bias,
+                    trg_word,
+                    trg_src_attn_bias,
+                    bos_id=0,
+                    eos_id=1,
+                    beam_size=4,
+                    max_len=256):
         if beam_size == 1:
             return self._greedy_search(src_word,
                                        src_pos,
@@ -1017,6 +1015,9 @@ class Transformer(Layer):
         trg_word = layers.fill_constant(shape=[batch_size, 1],
                                         dtype="int64",
                                         value=bos_id)
+        finished = layers.fill_constant(shape=[batch_size, 1],
+                                        dtype="bool",
+                                        value=0)
 
         ## init states (caches) for transformer
         caches = [{
@@ -1045,7 +1046,8 @@ class Transformer(Layer):
             scores = log_probs
             topk_scores, topk_indices = layers.topk(input=scores, k=1)
 
-            finished = layers.equal(topk_indices, end_token_tensor)
+            finished = layers.logical_or(
+                finished, layers.equal(topk_indices, end_token_tensor))
             trg_word = topk_indices
             log_probs = topk_scores
 
diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py
index 40d675b4..7f48f10a 100644
--- a/dygraph/transformer/train.py
+++ b/dygraph/transformer/train.py
@@ -145,7 +145,6 @@ def do_train(args):
         # train loop
         for pass_id in range(args.epoch):
             pass_start_time = time.time()
-            avg_batch_time = time.time()
             batch_id = 0
             for input_data in train_loader():
                 (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
@@ -178,6 +177,7 @@ def do_train(args):
                             (step_idx, pass_id, batch_id, total_avg_cost,
                              total_avg_cost - loss_normalizer,
                              np.exp([min(total_avg_cost, 100)])))
+                        avg_batch_time = time.time()
                     else:
                         logging.info(
                             "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
@@ -186,6 +186,7 @@ def do_train(args):
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             args.print_step / (time.time() - avg_batch_time)))
+                        avg_batch_time = time.time()
 
 
                 if step_idx % args.save_step == 0 and step_idx != 0:
@@ -228,7 +229,6 @@ def do_train(args):
                             optimizer.state_dict(),
                             os.path.join(model_dir, "transformer"))
 
-                avg_batch_time = time.time()
                 batch_id += 1
                 step_idx += 1
 
-- 
GitLab
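
A note on the membership-check change in MultiHeadAttention: dict.has_key() only
exists in Python 2 and was removed in Python 3, so the cache lookups now use the
"in" operator, which behaves the same on both. The sketch below is only an
illustration of the reuse-or-compute pattern those branches implement; cached_kv
and compute_kv are hypothetical stand-ins, not functions from model.py.

# Minimal sketch (plain Python) of the static key/value caching pattern;
# cached_kv and compute_kv are hypothetical stand-ins, not the model's API.
def cached_kv(cache, keys, compute_kv):
    # dict.has_key() is Python 2 only; '"static_k" in cache' works everywhere.
    if cache is not None and "static_k" in cache:
        return cache["static_k"], cache["static_v"]    # reuse cached projections
    k, v = compute_kv(keys)                            # first step: project once
    if cache is not None:
        cache["static_k"], cache["static_v"] = k, v    # remember for later steps
    return k, v

cache = {}
first = cached_kv(cache, [1, 2, 3], lambda x: (x, [2 * i for i in x]))
again = cached_kv(cache, None, None)                   # served from the cache
assert first == again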
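
A note on the greedy-search change: the old code overwrote finished every step
with equal(topk_indices, end_token_tensor), so a sequence that had already
emitted the end token could be flipped back to unfinished when a later top-1
token differed from it. Initializing finished once and accumulating it with
logical_or keeps the flag sticky. A small NumPy sketch of the difference, using
toy token ids rather than the model's decode loop:

import numpy as np

eos_id = 1
# Top-1 token per decode step for a batch of two sequences (toy values).
steps = [np.array([1, 5]), np.array([7, 5]), np.array([9, 1])]

finished_old = np.zeros(2, dtype=bool)   # overwritten every step (old behaviour)
finished_new = np.zeros(2, dtype=bool)   # accumulated with logical_or (patched behaviour)
for tok in steps:
    finished_old = tok == eos_id
    finished_new = np.logical_or(finished_new, tok == eos_id)

print(finished_old)   # [False  True]  sequence 0 lost its finished status
print(finished_new)   # [ True  True]  once finished, a sequence stays finished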
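
A note on the train.py change: avg_batch_time is now reset right after each log
line instead of once per batch, so the logged speed is args.print_step steps
divided by the wall time since the previous log. A minimal sketch of that
interval-based throughput logging; the loop body and counts are placeholders,
not the actual training loop.

import time

print_step = 100
avg_batch_time = time.time()              # start of the first measured interval
for step_idx in range(1, 501):
    time.sleep(0.001)                     # placeholder for one training step
    if step_idx % print_step == 0:
        speed = print_step / (time.time() - avg_batch_time)
        print("step_idx: %d, speed: %.2f step/s" % (step_idx, speed))
        avg_batch_time = time.time()      # reset so the next interval starts now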