Commit dead21e4 authored by guosheng

Add beam search decoder in Transformer

Parent a3ed9b00
@@ -103,22 +103,23 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
break
batch_size = -1
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set only to pass the infer-shape
# check at compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"src_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"src_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
[(batch_size, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This shape input is used to reshape the output of the embedding layer.
"src_data_shape": [(3L, ), "int32"],
@@ -128,22 +129,22 @@ input_descs = {
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"trg_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
"trg_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(1, ModelHyperParams.n_head,
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This shape input is used to reshape the output of the embedding layer.
@@ -170,6 +171,8 @@ input_descs = {
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
# These inputs are used by the beam search decoder.
# "start_token": [(1 * 1, 1L), "int64"],
}
# Names of position encoding table which will be initialized externally.
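Each entry in input_descs above pairs a placeholder shape with a dtype; because batch_size is -1, the leading dimension is only a compile-time placeholder and the real shape comes from the data fed at run time. As a point of reference, a minimal sketch of how one such descriptor could be turned into a fluid data layer (the make_all_inputs helper used later in this diff presumably does something along these lines; make_input here is a hypothetical name) is:

import paddle.fluid as fluid

def make_input(name, descs=input_descs):
    # Hypothetical helper: build one data layer from its (shape, dtype) descriptor.
    shape, dtype = descs[name][0], descs[name][1]
    return fluid.layers.data(
        name=name, shape=shape, dtype=dtype, append_batch_size=False)

# e.g. "src_word" -> shape (batch_size * (max_length + 1), 1), dtype "int64"
src_word = make_input("src_word")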
@@ -200,3 +203,7 @@ decoder_util_input_fields = (
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
fast_decoder_data_fields = (
"trg_word",
# "start_token",
"trg_src_attn_bias", )
@@ -7,6 +7,7 @@ import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
@@ -416,5 +417,15 @@ def infer(args):
if __name__ == "__main__":
fast_decoder(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
InferTaskConfig.beam_size, InferTaskConfig.max_length,
ModelHyperParams.eos_idx)
print(fluid.default_main_program())
exit(0)
args = parse_args()
infer(args)
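The __main__ block above is a temporary smoke test: it builds the fast-decode graph into the default main program, prints that program, and exits before the normal inference path runs. If such a check should not pollute the program that infer(args) later builds on, a hedged sketch (same imports as infer.py; the scratch program names are made up) would be:

import paddle.fluid as fluid

# Hypothetical scratch programs keep the debug graph separate from the
# default program used by the real inference path.
debug_prog = fluid.Program()
debug_startup = fluid.Program()
with fluid.program_guard(debug_prog, debug_startup):
    fast_decoder(ModelHyperParams.src_vocab_size,
                 ModelHyperParams.trg_vocab_size,
                 ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
                 ModelHyperParams.n_head, ModelHyperParams.d_key,
                 ModelHyperParams.d_value, ModelHyperParams.d_model,
                 ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
                 InferTaskConfig.beam_size, InferTaskConfig.max_length,
                 ModelHyperParams.eos_idx)
print(debug_prog)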
@@ -30,7 +30,8 @@ def multi_head_attention(queries,
n_head=1,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
post_softmax_shape=None,
cache=None):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing the softmax activation to mask certain selected positions so that
@@ -128,6 +129,12 @@
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
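# Temporary debug output: compare the cached and newly computed key shapes.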
print cache["k"].shape
print k.shape
k = cache["k"] = layers.concat([cache["k"], k], axis=1)
v = cache["v"] = layers.concat([cache["v"], v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
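The cache branch above implements incremental decoding for the self-attention: at every new time step only the current query is fresh, while its keys and values are concatenated onto the per-layer cache along the time axis so earlier positions are not recomputed. A standalone sketch of that growth pattern (plain numpy, with made-up shapes, just to illustrate the concatenation) is:

import numpy as np

d_model = 8
batch = 2
cache = {"k": np.zeros((batch, 0, d_model)),  # zero time steps cached so far
         "v": np.zeros((batch, 0, d_model))}

for step in range(3):
    k_new = np.random.rand(batch, 1, d_model)  # keys for the current step only
    v_new = np.random.rand(batch, 1, d_model)
    cache["k"] = np.concatenate([cache["k"], k_new], axis=1)
    cache["v"] = np.concatenate([cache["v"], v_new], axis=1)

print(cache["k"].shape)  # (2, 3, 8): one time slice appended per decoded step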
@@ -300,7 +307,8 @@ def decoder_layer(dec_input,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
src_attn_post_softmax_shape=None,
cache=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part, except
that a multi-head attention is added to implement encoder-decoder attention.
@@ -316,7 +324,8 @@
n_head,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
slf_attn_post_softmax_shape,
cache, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
@@ -365,26 +374,18 @@ def decoder(dec_input,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
src_attn_post_softmax_shape=None,
caches=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head,
d_key, d_value, d_model, d_inner_hid, dropout_rate,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape,
None if caches is None else caches[i])
dec_input = dec_output
return dec_output
@@ -523,7 +524,8 @@ def wrap_decoder(trg_vocab_size,
d_inner_hid,
dropout_rate,
dec_inputs=None,
enc_output=None):
enc_output=None,
caches=None):
"""
The wrapper assembles all the layers needed for the decoder.
"""
@@ -563,7 +565,8 @@
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
src_attn_post_softmax_shape,
caches, )
# Return logits for training and probs for inference.
predict = layers.reshape(
x=layers.fc(input=dec_output,
@@ -573,3 +576,112 @@
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
return predict
def fast_decode(
src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
beam_size,
max_out_len,
eos_idx, ):
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
dropout_rate)
start_tokens, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
make_all_inputs(fast_decoder_data_fields + decoder_util_input_fields)
def beam_search():
max_len = layers.fill_constant(
shape=[1], dtype='int32', value=max_out_len)
step_idx = layers.fill_constant(shape=[1], dtype='int32', value=0)
# The loop condition must hold a value before While first evaluates it.
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
init_scores = layers.fill_constant_batch_size_like(
input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
# array states
ids = layers.array_write(start_tokens, step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states (can be overwritten)
caches = [{
"k": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype="float32",
value=0),
"v": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype="float32",
value=0)
} for i in range(n_layer)]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_scores = layers.array_read(array=scores, i=step_idx)
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_ids, value=1, shape=[-1, 1], dtype='int32'),
y=layers.increment(
x=step_idx, value=1.0, in_place=False))
pre_src_attn_bias = layers.sequence_expand(
x=trg_src_attn_bias, y=pre_ids)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_ids)
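# Temporary debug output: cache shapes before and after beam expansion.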
print caches[0]["k"].shape
pre_caches = [{
"k": layers.sequence_expand(
x=cache["k"], y=pre_ids),
"v": layers.sequence_expand(
x=cache["v"], y=pre_ids),
} for cache in caches]
print pre_caches[0]["k"].shape
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
dec_inputs=(
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
enc_output=pre_enc_output,
caches=pre_caches)
topk_scores, topk_indices = layers.topk(logits, k=beam_size)
accu_scores = layers.elementwise_add(
x=pre_scores, y=layers.log(x=layers.softmax(topk_scores)))
selected_ids, selected_scores = layers.beam_search(
pre_ids=pre_ids,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx)
layers.increment(x=step_idx, value=1.0, in_place=True)
# update states
layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(pre_src_attn_bias, trg_src_attn_bias)
layers.assign(pre_enc_output, enc_output)
for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"])
max_len_cond = layers.less_than(x=step_idx, y=max_len)
# Stop early once beam_search leaves no unfinished candidates, i.e. when
# selected_ids comes back empty.
unfinished_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=max_len_cond, y=unfinished_cond, out=cond)
beam_search()
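To recap the scoring step inside the loop above: at each step the decoder's top-k logits per beam are renormalized with a softmax, turned into log probabilities, and added to the accumulated scores of their parent beams before layers.beam_search prunes the candidates back to beam_size. A small numpy sketch of that accumulation (toy numbers, not this model's outputs) is:

import numpy as np

beam_size = 2
pre_scores = np.array([[0.0], [-1.2]])            # accumulated log-prob per beam
topk_logits = np.array([[2.0, 1.0], [0.5, 0.3]])  # top beam_size logits per beam

# softmax over the kept logits, then log, matching the accu_scores computation
probs = np.exp(topk_logits) / np.exp(topk_logits).sum(axis=1, keepdims=True)
accu_scores = pre_scores + np.log(probs)
print(accu_scores)  # the highest-scoring candidates survive the pruning

Between iterations, layers.sequence_expand and the trailing layers.assign calls keep enc_output, trg_src_attn_bias and the per-layer caches aligned with whichever candidates survived, so each beam carries its own copy of the states into the next step.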