Commit 07126083 authored by guosheng

Adapt to fill_constant_batch_size_like and align the outputs of fast_infer with the original Python inference in Transformer
Parent 3e9fccea
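The gist of the change, visible in the fast_decode hunks below: pre_pos (the position input for the current decoding step) is now derived from pre_enc_output instead of pre_ids, since pre_ids carries LoD, and softmax is applied to the logits before topk so the beam scores match the original Python inference. A minimal NumPy sketch of the pre_pos arithmetic (names and shapes are illustrative, not from the commit):

import numpy as np

def make_pre_pos(batch_like, step_idx):
    # Mirrors fill_constant_batch_size_like + increment + elementwise_mul:
    # a batch-sized column of ones scaled by the incremented step index
    # yields the current position for every live beam entry.
    ones = np.ones((batch_like.shape[0], 1), dtype=np.int64)
    return ones * (step_idx + 1)

beams = np.zeros((4, 8))       # stand-in for pre_enc_output's batch dimension
print(make_pre_pos(beams, 2))  # -> [[3], [3], [3], [3]] at decoding step 2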
@@ -115,6 +115,12 @@ def multi_head_attention(queries,
         """
         scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+        # global FLAG
+        # if FLAG and attn_bias:
+        # print "hehehehehe"
+        # layers.Print(product, message="product")
+        # layers.Print(attn_bias, message="bias")
+        # FLAG = False
         weights = layers.reshape(
             x=layers.elementwise_add(
                 x=product, y=attn_bias) if attn_bias else product,
@@ -598,7 +604,7 @@ def wrap_decoder(trg_vocab_size,
                     bias_attr=False,
                     num_flatten_dims=2),
         shape=[-1, trg_vocab_size],
-        act="softmax")  # if dec_inputs is None else None)
+        act="softmax" if dec_inputs is None else None)
     return predict
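A note on the hunk above (my reading of the diff, not stated in the commit): the softmax activation on the final projection is now applied only when wrap_decoder builds its own input layers (dec_inputs is None); when fast_decode passes dec_inputs, raw logits come back and are normalized at the topk call in the last hunk instead. A hedged sketch of the projection in the same fluid-era API, with dec_output and trg_vocab_size as assumed names from the surrounding code:

import paddle.fluid.layers as layers  # fluid-era API, as used in this file

def project_to_vocab(dec_output, trg_vocab_size, dec_inputs=None):
    # Final projection to the vocabulary; softmax is fused into the reshape
    # only when the decoder owns its inputs, otherwise the caller normalizes.
    return layers.reshape(
        x=layers.fc(input=dec_output,
                    size=trg_vocab_size,
                    bias_attr=False,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax" if dec_inputs is None else None)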
@@ -656,14 +662,9 @@ def fast_decode(
     with while_op.block():
         pre_ids = layers.array_read(array=ids, i=step_idx)
         pre_scores = layers.array_read(array=scores, i=step_idx)
-        pre_pos = layers.elementwise_mul(
-            x=layers.fill_constant_batch_size_like(
-                input=pre_ids, value=1, shape=[-1, 1], dtype=pre_ids.dtype),
-            y=layers.increment(
-                x=step_idx, value=1.0, in_place=False),
-            axis=0)
         pre_src_attn_bias = layers.sequence_expand(
             x=trg_src_attn_bias, y=pre_scores)
+        # layers.Print(pre_src_attn_bias)
         pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
         pre_caches = [{
             "k": layers.sequence_expand(
@@ -671,12 +672,21 @@ def fast_decode(
             "v": layers.sequence_expand(
                 x=cache["v"], y=pre_scores),
         } for cache in caches]
-        # layers.Print(pre_ids)
-        # layers.Print(pre_pos)
-        # layers.Print(pre_enc_output)
-        # layers.Print(pre_src_attn_bias)
-        # layers.Print(pre_caches[0]["k"])
-        # layers.Print(pre_caches[0]["v"])
+        pre_pos = layers.elementwise_mul(
+            x=layers.fill_constant_batch_size_like(
+                input=pre_enc_output,  # can't use pre_ids here since it has LoD
+                value=1,
+                shape=[-1, 1],
+                dtype=pre_ids.dtype),
+            y=layers.increment(
+                x=step_idx, value=1.0, in_place=False),
+            axis=0)
+        # layers.Print(pre_ids, summarize=10)
+        # layers.Print(pre_pos, summarize=10)
+        # layers.Print(pre_enc_output, summarize=10)
+        # layers.Print(pre_src_attn_bias, summarize=10)
+        # layers.Print(pre_caches[0]["k"], summarize=10)
+        # layers.Print(pre_caches[0]["v"], summarize=10)
         # layers.Print(slf_attn_post_softmax_shape)
         logits = wrap_decoder(
             trg_vocab_size,
@@ -695,7 +705,8 @@ def fast_decode(
             enc_output=pre_enc_output,
             caches=pre_caches)
         # layers.Print(logits)
-        topk_scores, topk_indices = layers.topk(logits, k=beam_size)
+        topk_scores, topk_indices = layers.topk(
+            input=layers.softmax(logits), k=beam_size)
         # layers.Print(topk_scores)
         # layers.Print(topk_indices)
         accu_scores = layers.elementwise_add(
...
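Why applying softmax before topk aligns the outputs (a NumPy sketch, assuming row-wise normalization over the vocabulary): softmax is monotonic per row, so the selected indices are unchanged, but the returned scores become probabilities, matching what the original Python inference accumulates.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 0.5, 1.0], [0.1, 3.0, -1.0]])
probs = softmax(logits)
k = 2  # stand-in for beam_size
idx_logits = np.argsort(-logits, axis=-1)[:, :k]
idx_probs = np.argsort(-probs, axis=-1)[:, :k]
assert (idx_logits == idx_probs).all()        # same beam candidates either way
print(np.take_along_axis(probs, idx_probs, axis=-1))  # normalized scores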