Commit 1f5d2987 authored by G guosheng

Fix some bugs of dygraph Transformer.

Parent 866d3e03
@@ -21,8 +21,6 @@ import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
-from config import word_emb_param_names, pos_enc_param_names
def position_encoding_init(n_position, d_pos_vec):
"""
@@ -138,7 +136,7 @@ class MultiHeadAttention(Layer):
q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
q = layers.transpose(x=q, perm=[0, 2, 1, 3])
-if cache is not None and static_kv and cache.has_key("static_k"):
+if cache is not None and static_kv and "static_k" in cache:
# for encoder-decoder attention in inference and has cached
k = cache["static_k"]
v = cache["static_v"]
@@ -151,7 +149,7 @@ class MultiHeadAttention(Layer):
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
if cache is not None:
-if static_kv and not cache.has_key("static_k"):
+if static_kv and not "static_k" in cache:
# for encoder-decoder attention in inference and has not cached
cache["static_k"], cache["static_v"] = k, v
elif not static_kv:
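`dict.has_key()` was removed in Python 3, so both membership tests above switch to the `in` operator. A minimal plain-Python sketch of the static key/value caching pattern these two hunks implement, with a hypothetical `project_kv` callable standing in for the layer's k/v projections (not the repo's Paddle code):

```python
def get_static_kv(cache, keys, values, project_kv):
    """Project encoder-side K/V once, then reuse them on later decode steps.

    `cache` is a plain dict shared across time steps; `project_kv` is a
    hypothetical stand-in for the attention layer's k/v projection layers.
    """
    if cache is not None and "static_k" in cache:     # Python 3: no dict.has_key
        return cache["static_k"], cache["static_v"]   # reuse cached projections
    k, v = project_kv(keys, values)
    if cache is not None and "static_k" not in cache:
        cache["static_k"], cache["static_v"] = k, v   # cache on the first step
    return k, v


if __name__ == "__main__":
    calls = []
    project = lambda k, v: (calls.append(1), (k, v))[1]
    cache = {}
    get_static_kv(cache, "keys", "values", project)
    get_static_kv(cache, "keys", "values", project)
    assert len(calls) == 1  # the projection ran only on the first call
```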
@@ -180,7 +178,7 @@ class MultiHeadAttention(Layer):
dropout_prob=self.dropout_rate,
is_test=False)
-            out = layers.matmul(weights, v)
+        out = layers.matmul(weights, v)
# combine heads
out = layers.transpose(out, perm=[0, 2, 1, 3])
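Assuming the intent of this hunk is that the weighted sum over `v` runs whether or not dropout is applied to the attention weights, here is a NumPy sketch of that scaled dot-product pattern (illustrative names only, not the repo's Paddle code):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, dropout_rate=0.0, train=True):
    """q, k, v: [batch, heads, seq, dim]; returns the attended values."""
    d_key = q.shape[-1]
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(d_key)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)             # softmax
    if dropout_rate and train:
        keep = np.random.binomial(1, 1.0 - dropout_rate, weights.shape)
        weights = weights * keep / (1.0 - dropout_rate)        # inverted dropout
    # The matmul with v sits outside the dropout branch so the output is
    # produced even when dropout_rate == 0.
    return np.matmul(weights, v)


if __name__ == "__main__":
    q = k = v = np.random.rand(2, 4, 5, 8)
    assert scaled_dot_product_attention(q, k, v, dropout_rate=0.0).shape == (2, 4, 5, 8)
```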
@@ -817,15 +815,15 @@ class Transformer(Layer):
return finished_seq, finished_scores
def beam_search(self,
-src_word,
-src_pos,
-src_slf_attn_bias,
-trg_word,
-trg_src_attn_bias,
-bos_id=0,
-eos_id=1,
-beam_size=4,
-max_len=256):
+src_word,
+src_pos,
+src_slf_attn_bias,
+trg_word,
+trg_src_attn_bias,
+bos_id=0,
+eos_id=1,
+beam_size=4,
+max_len=256):
if beam_size == 1:
return self._greedy_search(src_word,
src_pos,
@@ -1017,6 +1015,9 @@ class Transformer(Layer):
trg_word = layers.fill_constant(shape=[batch_size, 1],
dtype="int64",
value=bos_id)
+finished = layers.fill_constant(shape=[batch_size, 1],
+dtype="bool",
+value=0)
## init states (caches) for transformer
caches = [{
@@ -1045,7 +1046,8 @@ class Transformer(Layer):
scores = log_probs
topk_scores, topk_indices = layers.topk(input=scores, k=1)
-finished = layers.equal(topk_indices, end_token_tensor)
+finished = layers.logical_or(
+finished, layers.equal(topk_indices, end_token_tensor))
trg_word = topk_indices
log_probs = topk_scores
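Replacing the plain `equal` with `logical_or` makes the `finished` mask sticky: once a sequence emits `eos_id` it stays finished even if a later argmax token is not `eos_id`. A NumPy sketch of that greedy-decoding loop, with a hypothetical `step_logits_fn` in place of the decoder (not the repo's Paddle code):

```python
import numpy as np

def greedy_decode(step_logits_fn, batch_size, eos_id=1, max_len=8):
    """step_logits_fn(t) -> [batch_size, vocab] logits for decode step t."""
    finished = np.zeros((batch_size, 1), dtype=bool)
    tokens = []
    for t in range(max_len):
        next_token = step_logits_fn(t).argmax(axis=-1)[:, None]
        # Sticky flag: OR with the previous mask so sequences that already
        # produced EOS are not flipped back to "unfinished".
        finished = np.logical_or(finished, next_token == eos_id)
        tokens.append(next_token)
        if finished.all():      # stop once every sequence in the batch is done
            break
    return np.concatenate(tokens, axis=1), finished


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    seq, done = greedy_decode(lambda t: rng.random((2, 6)), batch_size=2)
    print(seq.shape, done.ravel())
```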
@@ -145,7 +145,6 @@ def do_train(args):
# train loop
for pass_id in range(args.epoch):
pass_start_time = time.time()
-avg_batch_time = time.time()
batch_id = 0
for input_data in train_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
@@ -178,6 +177,7 @@ def do_train(args):
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
+avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
@@ -186,6 +186,7 @@ def do_train(args):
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
+avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0:
@@ -228,7 +229,6 @@ def do_train(args):
optimizer.state_dict(),
os.path.join(model_dir, "transformer"))
-avg_batch_time = time.time()
batch_id += 1
step_idx += 1
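With the reset moved to right after each log line, the reported speed `args.print_step / (time.time() - avg_batch_time)` always covers exactly the last `print_step` steps instead of also counting time spent saving checkpoints. A standalone sketch of that windowed-throughput logging pattern, with `window_start` playing the role of `avg_batch_time` (illustrative, not the repo's train loop):

```python
import logging
import time

logging.basicConfig(level=logging.INFO)

def train_loop(num_steps=50, print_step=10):
    for step_idx in range(num_steps):
        time.sleep(0.001)                    # stand-in for one training step
        if step_idx % print_step == 0:
            if step_idx == 0:
                logging.info("step_idx: %d (no speed yet)", step_idx)
            else:
                speed = print_step / (time.time() - window_start)
                logging.info("step_idx: %d, speed: %.2f steps/s", step_idx, speed)
            # Reset the window right after reporting, so the next reading
            # measures only the upcoming print_step steps.
            window_start = time.time()

if __name__ == "__main__":
    train_loop()
```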