Commit dead21e4 authored by G guosheng

Add beam search decoder in Transformer

Parent a3ed9b00
...@@ -103,22 +103,23 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
            break
batch_size = -1
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set to pass the infer-shape at
# compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # The actual data shape of src_pos is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias":
    [(batch_size, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
      (ModelHyperParams.max_length + 1)), "float32"],
    # This shape input is used to reshape the output of the embedding layer.
    "src_data_shape": [(3L, ), "int32"],
...@@ -128,22 +129,22 @@ input_descs = {
    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
    # The actual data shape of trg_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # The actual data shape of trg_pos is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head,
                           (ModelHyperParams.max_length + 1),
                           (ModelHyperParams.max_length + 1)), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head,
                           (ModelHyperParams.max_length + 1),
                           (ModelHyperParams.max_length + 1)), "float32"],
    # This shape input is used to reshape the output of the embedding layer.
...@@ -170,6 +171,8 @@ input_descs = {
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
    # These two inputs are used for the beam search decoder.
    # "start_token": [(1 * 1, 1L), "int64"],
}
# Names of position encoding table which will be initialized externally.
...@@ -200,3 +203,7 @@ decoder_util_input_fields = (
label_data_input_fields = (
    "lbl_word",
    "lbl_weight", )
fast_decoder_data_fields = (
    "trg_word",
    # "start_token",
    "trg_src_attn_bias", )
...@@ -7,6 +7,7 @@ import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
...@@ -416,5 +417,15 @@ def infer(args):


if __name__ == "__main__":
    # Temporarily build the beam-search decoding program, print it for
    # inspection, and exit before running the usual inference path.
    fast_decoder(ModelHyperParams.src_vocab_size,
                 ModelHyperParams.trg_vocab_size,
                 ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
                 ModelHyperParams.n_head, ModelHyperParams.d_key,
                 ModelHyperParams.d_value, ModelHyperParams.d_model,
                 ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
                 InferTaskConfig.beam_size, InferTaskConfig.max_length,
                 ModelHyperParams.eos_idx)
    print(fluid.default_main_program())
    exit(0)
    args = parse_args()
    infer(args)
...@@ -30,7 +30,8 @@ def multi_head_attention(queries,
                         n_head=1,
                         dropout_rate=0.,
                         pre_softmax_shape=None,
                         post_softmax_shape=None,
                         cache=None):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing softmax activation to mask certain selected positions so that
...@@ -128,6 +129,12 @@ def multi_head_attention(queries,
    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        # Append the keys/values of the current step to the cached ones so
        # that only the new time step needs to be projected.
        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
        v = cache["v"] = layers.concat([cache["v"], v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)
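
The cache branch above is what makes step-by-step decoding cheap: each call projects only the current token and concatenates the result onto the keys/values accumulated so far. A rough NumPy sketch of that bookkeeping (illustrative only; the names and shapes here are assumptions, not the Fluid implementation):

import numpy as np

def cached_attention_step(q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: projections for the current step, shape [batch, 1, d]
    # cache: dict holding the keys/values of all previous steps
    cache["k"] = np.concatenate([cache["k"], k_t], axis=1)   # [batch, t, d]
    cache["v"] = np.concatenate([cache["v"], v_t], axis=1)
    scores = np.matmul(q_t, cache["k"].transpose(0, 2, 1))   # [batch, 1, t]
    scores /= np.sqrt(q_t.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)           # softmax over t
    return np.matmul(weights, cache["v"])                    # [batch, 1, d]

Starting from empty caches of shape [batch, 0, d], as the fill_constant_batch_size_like calls in fast_decode below do, calling this once per generated token reproduces full attention over the decoded prefix.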
...@@ -300,7 +307,8 @@ def decoder_layer(dec_input,
                  slf_attn_pre_softmax_shape=None,
                  slf_attn_post_softmax_shape=None,
                  src_attn_pre_softmax_shape=None,
                  src_attn_post_softmax_shape=None,
                  cache=None):
    """ The layer to be stacked in the decoder part.
    The structure of this module is similar to that in the encoder part except
    a multi-head attention is added to implement encoder-decoder attention.
...@@ -316,7 +324,8 @@ def decoder_layer(dec_input,
        n_head,
        dropout_rate,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        cache, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
...@@ -365,26 +374,18 @@ def decoder(dec_input,
            slf_attn_pre_softmax_shape=None,
            slf_attn_post_softmax_shape=None,
            src_attn_pre_softmax_shape=None,
            src_attn_post_softmax_shape=None,
            caches=None):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head,
            d_key, d_value, d_model, d_inner_hid, dropout_rate,
            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None
            if caches is None else caches[i])
        dec_input = dec_output
    return dec_output
...@@ -523,7 +524,8 @@ def wrap_decoder(trg_vocab_size,
                 d_inner_hid,
                 dropout_rate,
                 dec_inputs=None,
                 enc_output=None,
                 caches=None):
    """
    The wrapper assembles together all needed layers for the decoder.
    """
...@@ -563,7 +565,8 @@ def wrap_decoder(trg_vocab_size,
        slf_attn_pre_softmax_shape,
        slf_attn_post_softmax_shape,
        src_attn_pre_softmax_shape,
        src_attn_post_softmax_shape,
        caches, )
    # Return logits for training and probs for inference.
    predict = layers.reshape(
        x=layers.fc(input=dec_output,
...@@ -573,3 +576,112 @@ def wrap_decoder(trg_vocab_size,
        shape=[-1, trg_vocab_size],
        act="softmax" if dec_inputs is None else None)
    return predict


def fast_decode(
        src_vocab_size,
        trg_vocab_size,
        max_in_len,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        beam_size,
        max_out_len,
        eos_idx, ):
    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
                              d_key, d_value, d_model, d_inner_hid,
                              dropout_rate)
    start_tokens, trg_src_attn_bias, trg_data_shape, \
        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
        src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
        make_all_inputs(fast_decoder_data_fields + decoder_util_input_fields)

    def beam_search():
        cond = layers.create_tensor(dtype='bool')
        while_op = layers.While(cond)
        max_len = layers.fill_constant(
            shape=[1], dtype='int32', value=max_out_len)
        step_idx = layers.fill_constant(shape=[1], dtype='int32', value=0)
        # The loop condition must hold a value before the while block runs;
        # decoding starts as long as step 0 is below the length limit.
        layers.less_than(x=step_idx, y=max_len, cond=cond)
        init_scores = layers.fill_constant_batch_size_like(
            input=start_tokens, shape=[-1, 1], dtype="float32", value=0)
        # array states
        ids = layers.array_write(start_tokens, step_idx)
        scores = layers.array_write(init_scores, step_idx)
        # cell states (can be overwritten)
        caches = [{
            "k": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype="float32",
                value=0),
            "v": layers.fill_constant_batch_size_like(
                input=start_tokens,
                shape=[-1, 0, d_model],
                dtype="float32",
                value=0)
        } for i in range(n_layer)]
        with while_op.block():
            pre_ids = layers.array_read(array=ids, i=step_idx)
            pre_scores = layers.array_read(array=scores, i=step_idx)
            # All candidates decoded at this step share the same position.
            pre_pos = layers.elementwise_mul(
                x=layers.fill_constant_batch_size_like(
                    input=pre_ids, value=1, shape=[-1, 1], dtype='int32'),
                y=layers.increment(
                    x=step_idx, value=1.0, in_place=False))
            # Expand the per-sentence tensors to match the number of beam
            # candidates carried by pre_ids.
            pre_src_attn_bias = layers.sequence_expand(
                x=trg_src_attn_bias, y=pre_ids)
            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_ids)
            pre_caches = [{
                "k": layers.sequence_expand(
                    x=cache["k"], y=pre_ids),
                "v": layers.sequence_expand(
                    x=cache["v"], y=pre_ids),
            } for cache in caches]
            logits = wrap_decoder(
                trg_vocab_size,
                max_in_len,
                n_layer,
                n_head,
                d_key,
                d_value,
                d_model,
                d_inner_hid,
                dropout_rate,
                dec_inputs=(
                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                enc_output=pre_enc_output,
                caches=pre_caches)
            topk_scores, topk_indices = layers.topk(logits, k=beam_size)
            # Accumulate scores: broadcast the previous score of each beam
            # over its top-k candidate extensions.
            accu_scores = layers.elementwise_add(
                x=layers.log(x=layers.softmax(topk_scores)),
                y=layers.reshape(pre_scores, shape=[-1]),
                axis=0)
            selected_ids, selected_scores = layers.beam_search(
                pre_ids=pre_ids,
                ids=topk_indices,
                scores=accu_scores,
                beam_size=beam_size,
                end_id=eos_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            # update states
            layers.array_write(selected_ids, i=step_idx, array=ids)
            layers.array_write(selected_scores, i=step_idx, array=scores)
            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
            layers.assign(pre_enc_output, enc_output)
            for i in range(n_layer):
                layers.assign(pre_caches[i]["k"], caches[i]["k"])
                layers.assign(pre_caches[i]["v"], caches[i]["v"])
            # Keep looping while the length limit has not been reached and
            # some candidates are still unfinished.
            length_cond = layers.less_than(x=step_idx, y=max_len)
            not_finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
            layers.logical_and(x=length_cond, y=not_finish_cond, out=cond)

    beam_search()
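
For reference, the while_op body above performs one step of standard beam search: every live hypothesis is extended with its top-k next tokens, log-probabilities are accumulated, and only the best beam_size continuations survive. A plain-Python sketch of that single step, independent of the Fluid ops used above (finished hypotheses, i.e. those ending in eos_idx, are carried over unchanged):

def beam_search_step(hypotheses, next_log_probs, beam_size, eos_idx):
    # hypotheses: list of (token_list, accumulated_log_prob) for live beams
    # next_log_probs[i][w]: log probability of word w given hypotheses[i]
    candidates = []
    for (tokens, score), log_probs in zip(hypotheses, next_log_probs):
        if tokens and tokens[-1] == eos_idx:  # finished beams stay as they are
            candidates.append((tokens, score))
            continue
        for word, log_p in enumerate(log_probs):
            candidates.append((tokens + [word], score + log_p))
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    return candidates[:beam_size]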