Unverified commit 6fc865e1, authored by Guo Sheng, committed by GitHub

Merge pull request #1013 from guoshengCS/add-transformer-BeamsearchDecoder-clean

Add transformer beamsearch decoder
@@ -38,7 +38,7 @@ class InferTaskConfig(object):
     batch_size = 10
     # the parameters for beam search.
     beam_size = 5
-    max_length = 256
+    max_out_len = 256
     # the number of decoded sentences to output.
     n_best = 1
     # the flags indicating whether to output the special tokens.
@@ -104,23 +104,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
             break
+# The placeholder for batch_size in compile time. Must be -1 currently to be
+# consistent with some ops' infer-shape output in compile time, such as the
+# sequence_expand op used in beamsearch decoder.
+batch_size = -1
+# The placeholder for sequence length in compile time.
+seq_len = ModelHyperParams.max_length
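A minimal sketch (not part of the commit) of what the -1 placeholder buys: the shapes in input_descs only have to pass compile-time infer-shape, while the real sizes come from the tensors fed at run time.

import paddle.fluid as fluid

# hypothetical, stand-alone illustration of a -1 batch dimension:
init_score = fluid.layers.data(
    name="init_score",
    shape=[-1, 1],  # -1 is resolved to the actual batch size when fed
    dtype="float32",
    append_batch_size=False)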
 # Here list the data shapes and data types of all inputs.
 # The shapes here act as placeholder and are set to pass the infer-shape in
 # compile time.
 input_descs = {
     # The actual data shape of src_word is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_word": [(batch_size * seq_len, 1L), "int64", 2],
     # The actual data shape of src_pos is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_pos": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
     # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
-    "src_slf_attn_bias":
-    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
-      (ModelHyperParams.max_length + 1)), "float32"],
+    "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "src_data_shape": [(3L, ), "int32"],
     # This shape input is used to reshape before softmax in self attention.
@@ -129,24 +134,23 @@ input_descs = {
     "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
     # The actual data shape of trg_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_word": [(batch_size * seq_len, 1L), "int64",
+                 2],  # lod_level is only used in fast decoder.
     # The actual data shape of trg_pos is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_pos": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
-    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
-                           (ModelHyperParams.max_length + 1),
-                           (ModelHyperParams.max_length + 1)), "float32"],
+    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This input is used to remove attention weights on paddings of the source
     # input in the encoder-decoder attention.
     # The actual data shape of trg_src_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
-    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
-                           (ModelHyperParams.max_length + 1),
-                           (ModelHyperParams.max_length + 1)), "float32"],
+    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "trg_data_shape": [(3L, ), "int32"],
     # This shape input is used to reshape before softmax in self attention.
@@ -162,15 +166,18 @@ input_descs = {
     # This input is used in independent decoder program for inference.
     # The actual data shape of enc_output is:
     # [batch_size, max_src_len_in_batch, d_model]
-    "enc_output": [(1, (ModelHyperParams.max_length + 1),
-                    ModelHyperParams.d_model), "float32"],
+    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
     # The actual data shape of label_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "lbl_word": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to mask out the loss of padding tokens.
     # The actual data shape of label_weight is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+    "lbl_weight": [(batch_size * seq_len, 1L), "float32"],
+    # These inputs are used to change the shape tensor in beam-search decoder.
+    "trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
+    "trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
+    "init_score": [(batch_size, 1L), "float32"],
 }
 # Names of word embedding table which might be reused for weight sharing.
@@ -205,3 +212,12 @@ decoder_util_input_fields = (
 label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )
+# In fast decoder, trg_pos (only containing the current time step) is generated
+# by ops and trg_slf_attn_bias is not needed.
+fast_decoder_data_input_fields = (
+    "trg_word",
+    "init_score",
+    "trg_src_attn_bias", )
+fast_decoder_util_input_fields = decoder_util_input_fields + (
+    "trg_slf_attn_pre_softmax_shape_delta",
+    "trg_slf_attn_post_softmax_shape_delta", )
@@ -7,6 +7,7 @@ import paddle.fluid as fluid
 import model
 from model import wrap_encoder as encoder
 from model import wrap_decoder as decoder
+from model import fast_decode as fast_decoder
 from config import *
 from train import pad_batch_data
 import reader
@@ -87,7 +88,8 @@ def translate_batch(exe,
                     output_unk=True):
     """
     Run the encoder program once and run the decoder program multiple times to
-    implement beam search externally.
+    implement beam search externally. This is deprecated since a faster beam
+    search decoder based solely on Fluid operators has been added.
     """
     # Prepare data for encoder and run the encoder.
     enc_in_data = pad_batch_data(
@@ -297,7 +299,32 @@ def translate_batch(exe,
     return seqs, scores[:, :n_best].tolist()


-def infer(args):
+def post_process_seq(seq,
+                     bos_idx=ModelHyperParams.bos_idx,
+                     eos_idx=ModelHyperParams.eos_idx,
+                     output_bos=InferTaskConfig.output_bos,
+                     output_eos=InferTaskConfig.output_eos):
+    """
+    Post-process the beam-search decoded sequence. Truncate at the first
+    <eos> and optionally remove the <bos> and <eos> tokens.
+    """
+    eos_pos = len(seq) - 1
+    for i, idx in enumerate(seq):
+        if idx == eos_idx:
+            eos_pos = i
+            break
+    seq = seq[:eos_pos + 1]
+    return filter(
+        lambda idx: (output_bos or idx != bos_idx) and \
+            (output_eos or idx != eos_idx),
+        seq)
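A quick check of post_process_seq, assuming a hypothetical vocabulary with bos_idx=0 and eos_idx=1:

# the list is cut at the first <eos>, then <bos>/<eos> are filtered out
post_process_seq([0, 5, 7, 1, 9, 1], bos_idx=0, eos_idx=1,
                 output_bos=False, output_eos=False)  # -> [5, 7]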
+
+def py_infer(test_data, trg_idx2word):
+    """
+    Inference by beam search implemented in Python, while the calculations
+    from symbols to probabilities are executed by Fluid operators.
+    """
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
@@ -341,49 +368,8 @@ def infer(args):
     fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)

     # This is used here to set dropout to the test mode.
-    encoder_program = fluid.io.get_inference_program(
-        target_vars=[enc_output], main_program=encoder_program)
-    decoder_program = fluid.io.get_inference_program(
-        target_vars=[predict], main_program=decoder_program)
-
-    test_data = reader.DataReader(
-        src_vocab_fpath=args.src_vocab_fpath,
-        trg_vocab_fpath=args.trg_vocab_fpath,
-        fpattern=args.test_file_pattern,
-        batch_size=args.batch_size,
-        use_token_batch=False,
-        pool_size=args.pool_size,
-        sort_type=reader.SortType.NONE,
-        shuffle=False,
-        shuffle_batch=False,
-        start_mark=args.special_token[0],
-        end_mark=args.special_token[1],
-        unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
-        clip_last_batch=False)
-
-    trg_idx2word = test_data.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
-
-    def post_process_seq(seq,
-                         bos_idx=ModelHyperParams.bos_idx,
-                         eos_idx=ModelHyperParams.eos_idx,
-                         output_bos=InferTaskConfig.output_bos,
-                         output_eos=InferTaskConfig.output_eos):
-        """
-        Post-process the beam-search decoded sequence. Truncate from the first
-        <eos> and remove the <bos> and <eos> tokens currently.
-        """
-        eos_pos = len(seq) - 1
-        for i, idx in enumerate(seq):
-            if idx == eos_idx:
-                eos_pos = i
-                break
-        seq = seq[:eos_pos + 1]
-        return filter(
-            lambda idx: (output_bos or idx != bos_idx) and \
-                (output_eos or idx != eos_idx),
-            seq)
+    encoder_program = encoder_program.inference_optimize()
+    decoder_program = decoder_program.inference_optimize()
     for batch_id, data in enumerate(test_data.batch_generator()):
         batch_seqs, batch_scores = translate_batch(
@@ -397,7 +383,7 @@ def infer(args):
             (decoder_data_input_fields[-1], ),
             [predict.name],
             InferTaskConfig.beam_size,
-            InferTaskConfig.max_length,
+            InferTaskConfig.max_out_len,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
@@ -416,6 +402,154 @@ def infer(args):
             print(" ".join([trg_idx2word[idx] for idx in seq]))
+
+
+def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
+                        bos_idx, n_head, d_model, place):
+    """
+    Put all padded data needed by beam search decoder into a dict.
+    """
+    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
+        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+    # start tokens
+    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
+    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
+                                [1, 1, 1, 1]).astype("float32")
+
+    # These shape tensors are used in reshape_op.
+    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
+    trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
+    src_slf_attn_pre_softmax_shape = np.array(
+        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
+    src_slf_attn_post_softmax_shape = np.array(
+        [-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
+    trg_slf_attn_pre_softmax_shape = np.array(
+        [-1, 1], dtype="int32")  # only the first time step
+    trg_slf_attn_post_softmax_shape = np.array(
+        [-1, n_head, 1, 1], dtype="int32")  # only the first time step
+    trg_src_attn_pre_softmax_shape = np.array(
+        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
+    trg_src_attn_post_softmax_shape = np.array(
+        [-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
+    # These inputs are used to change the shapes in the loop of while op.
+    attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
+    attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
+
+    def to_lodtensor(data, place, lod=None):
+        data_tensor = fluid.LoDTensor()
+        data_tensor.set(data, place)
+        if lod is not None:
+            data_tensor.set_lod(lod)
+        return data_tensor
+
+    # beamsearch_op must use tensors with lod
+    init_score = to_lodtensor(
+        np.zeros_like(
+            trg_word, dtype="float32"),
+        place, [range(trg_word.shape[0] + 1)] * 2)
+    trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
+
+    data_input_dict = dict(
+        zip(data_input_names, [
+            src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
+            trg_src_attn_bias
+        ]))
+    util_input_dict = dict(
+        zip(util_input_names, [
+            src_data_shape, src_slf_attn_pre_softmax_shape,
+            src_slf_attn_post_softmax_shape, trg_data_shape,
+            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
+            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
+            attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
+        ]))
+
+    input_dict = dict(data_input_dict.items() + util_input_dict.items())
+    return input_dict
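The `[range(trg_word.shape[0] + 1)] * 2` above builds the 2-level LoD that the beam-search op expects. A worked example, assuming a batch of 3 source sentences:

# range(4) == [0, 1, 2, 3], so the lod is [[0, 1, 2, 3], [0, 1, 2, 3]]:
# level 0: each source sentence owns exactly one candidate so far;
# level 1: each candidate is one token long (the <bos> start token).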
+
+
+def fast_infer(test_data, trg_idx2word):
+    """
+    Inference by beam search decoder based solely on Fluid operators.
+    """
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    out_ids, out_scores = fast_decoder(
+        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+        ModelHyperParams.n_head, ModelHyperParams.d_key,
+        ModelHyperParams.d_value, ModelHyperParams.d_model,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+        ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
+        InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
+
+    fluid.io.load_vars(
+        exe,
+        InferTaskConfig.model_path,
+        vars=filter(lambda var: isinstance(var, fluid.framework.Parameter),
+                    fluid.default_main_program().list_vars()))
+
+    # This is used here to set dropout to the test mode.
+    infer_program = fluid.default_main_program().inference_optimize()
+
+    for batch_id, data in enumerate(test_data.batch_generator()):
+        data_input = prepare_batch_input(
+            data, encoder_data_input_fields + fast_decoder_data_input_fields,
+            encoder_util_input_fields + fast_decoder_util_input_fields,
+            ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
+            ModelHyperParams.n_head, ModelHyperParams.d_model, place)
+        seq_ids, seq_scores = exe.run(infer_program,
+                                      feed=data_input,
+                                      fetch_list=[out_ids, out_scores],
+                                      return_numpy=False)
+        # How to parse the results:
+        # Suppose the lod of seq_ids is:
+        #   [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
+        # then from lod[0]:
+        #   there are 2 source sentences, beam width is 3.
+        # from lod[1]:
+        #   the first source sentence has 3 hyps; the lengths are 12, 12, 16
+        #   the second source sentence has 3 hyps; the lengths are 14, 13, 15
+        hyps = [[] for i in range(len(data))]
+        scores = [[] for i in range(len(data))]
+        for i in range(len(seq_ids.lod()[0]) - 1):  # for each source sentence
+            start = seq_ids.lod()[0][i]
+            end = seq_ids.lod()[0][i + 1]
+            for j in range(end - start):  # for each candidate
+                sub_start = seq_ids.lod()[1][start + j]
+                sub_end = seq_ids.lod()[1][start + j + 1]
+                hyps[i].append(" ".join([
+                    trg_idx2word[idx]
+                    for idx in post_process_seq(
+                        np.array(seq_ids)[sub_start:sub_end])
+                ]))
+                scores[i].append(np.array(seq_scores)[sub_end - 1])
+                print hyps[i][-1]
+                if len(hyps[i]) >= InferTaskConfig.n_best:
+                    break
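The LoD walk above as a stand-alone sketch, using the numbers from the comment (runnable without Fluid):

lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
for i in range(len(lod[0]) - 1):               # 2 source sentences
    start, end = lod[0][i], lod[0][i + 1]      # beam width = end - start = 3
    for j in range(end - start):               # hypotheses of sentence i
        sub_start, sub_end = lod[1][start + j], lod[1][start + j + 1]
        print i, j, sub_end - sub_start        # lengths: 12 12 16 / 14 13 15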
+
+
+def infer(args, inferencer=fast_infer):
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    test_data = reader.DataReader(
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        fpattern=args.test_file_pattern,
+        batch_size=args.batch_size,
+        use_token_batch=False,
+        pool_size=args.pool_size,
+        sort_type=reader.SortType.NONE,
+        shuffle=False,
+        shuffle_batch=False,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
+        clip_last_batch=False)
+    trg_idx2word = test_data.load_dict(
+        dict_path=args.trg_vocab_fpath, reverse=True)
+
+    inferencer(test_data, trg_idx2word)
 if __name__ == "__main__":
     args = parse_args()
     infer(args)
@@ -30,7 +30,8 @@ def multi_head_attention(queries,
                          n_head=1,
                          dropout_rate=0.,
                          pre_softmax_shape=None,
-                         post_softmax_shape=None):
+                         post_softmax_shape=None,
+                         cache=None):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
     computing softmax activation to mask certain selected positions so that
@@ -116,6 +117,10 @@ def multi_head_attention(queries,

     q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

+    if cache is not None:  # use cache and concat time steps
+        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
+        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
+
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
     v = __split_heads(v, n_head)
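A shape-only sketch of the cache concat, with hypothetical sizes (batch * beam = 4, d_model = 512):

import numpy as np
cache_k = np.zeros((4, 2, 512))  # keys of steps 0-1: [batch * beam, t, d_model]
step_k = np.zeros((4, 1, 512))   # key projected from the current step only
k = np.concatenate([cache_k, step_k], axis=1)  # (4, 3, 512): the history is
                                               # reused, never recomputed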
@@ -203,7 +208,7 @@ def prepare_encoder(src_word,
     enc_input = src_word_emb + src_pos_enc
     enc_input = layers.reshape(
         x=enc_input,
-        shape=[-1, src_max_len, src_emb_dim],
+        shape=[batch_size, seq_len, src_emb_dim],
         actual_shape=src_data_shape)
     return layers.dropout(
         enc_input, dropout_prob=dropout_rate,
@@ -285,7 +290,8 @@ def decoder_layer(dec_input,
                   slf_attn_pre_softmax_shape=None,
                   slf_attn_post_softmax_shape=None,
                   src_attn_pre_softmax_shape=None,
-                  src_attn_post_softmax_shape=None):
+                  src_attn_post_softmax_shape=None,
+                  cache=None):
     """ The layer to be stacked in decoder part.
     The structure of this module is similar to that in the encoder part except
     a multi-head attention is added to implement encoder-decoder attention.
@@ -301,7 +307,8 @@ def decoder_layer(dec_input,
         n_head,
         dropout_rate,
         slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape, )
+        slf_attn_post_softmax_shape,
+        cache, )
     slf_attn_output = post_process_layer(
         dec_input,
         slf_attn_output,
@@ -350,7 +357,8 @@ def decoder(dec_input,
             slf_attn_pre_softmax_shape=None,
             slf_attn_post_softmax_shape=None,
             src_attn_pre_softmax_shape=None,
-            src_attn_post_softmax_shape=None):
+            src_attn_post_softmax_shape=None,
+            caches=None):
     """
     The decoder is composed of a stack of identical decoder_layer layers.
     """
@@ -369,7 +377,8 @@ def decoder(dec_input,
             slf_attn_pre_softmax_shape,
             slf_attn_post_softmax_shape,
             src_attn_pre_softmax_shape,
-            src_attn_post_softmax_shape, )
+            src_attn_post_softmax_shape,
+            None if caches is None else caches[i], )
         dec_input = dec_output
     return dec_output
@@ -384,6 +393,8 @@ def make_all_inputs(input_fields):
             name=input_field,
             shape=input_descs[input_field][0],
             dtype=input_descs[input_field][1],
+            lod_level=input_descs[input_field][2]
+            if len(input_descs[input_field]) == 3 else 0,
             append_batch_size=False)
         inputs.append(input_var)
     return inputs
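With the optional third desc element, "trg_word" now becomes a 2-level LoD variable. A hedged sketch of the equivalent direct call (illustration only, mirroring what make_all_inputs does):

trg_word = fluid.layers.data(
    name="trg_word",
    shape=input_descs["trg_word"][0],  # (batch_size * seq_len, 1L)
    dtype="int64",
    lod_level=2,  # sentence -> candidate -> token, as beam search requires
    append_batch_size=False)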
@@ -517,7 +528,8 @@ def wrap_decoder(trg_vocab_size,
                  dropout_rate,
                  weight_sharing,
                  dec_inputs=None,
-                 enc_output=None):
+                 enc_output=None,
+                 caches=None):
     """
     The wrapper assembles together all needed layers for the decoder.
     """
@@ -559,7 +571,8 @@ def wrap_decoder(trg_vocab_size,
         slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape,
         src_attn_pre_softmax_shape,
-        src_attn_post_softmax_shape, )
+        src_attn_post_softmax_shape,
+        caches, )
     # Return logits for training and probs for inference.
     if weight_sharing:
         predict = layers.reshape(
@@ -578,3 +591,145 @@ def wrap_decoder(trg_vocab_size,
         shape=[-1, trg_vocab_size],
         act="softmax" if dec_inputs is None else None)
     return predict
+
+
+def fast_decode(
+        src_vocab_size,
+        trg_vocab_size,
+        max_in_len,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        weight_sharing,
+        beam_size,
+        max_out_len,
+        eos_idx, ):
+    """
+    Use beam search to decode. Caches will be used to store states of history
+    steps which can make the decoding faster.
+    """
+    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
+                              d_key, d_value, d_model, d_inner_hid,
+                              dropout_rate, weight_sharing)
+    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
+        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
+        src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
+        attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
+        make_all_inputs(fast_decoder_data_input_fields +
+                        fast_decoder_util_input_fields)
+
+    def beam_search():
+        max_len = layers.fill_constant(
+            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
+        step_idx = layers.fill_constant(
+            shape=[1], dtype=start_tokens.dtype, value=0)
+        cond = layers.less_than(x=step_idx, y=max_len)
+        while_op = layers.While(cond)
+        # array states will be stored for each step.
+        ids = layers.array_write(start_tokens, step_idx)
+        scores = layers.array_write(init_scores, step_idx)
+        # cell states will be overwritten at each step.
+        # caches contains states of history steps to reduce redundant
+        # computation in decoder.
+        caches = [{
+            "k": layers.fill_constant_batch_size_like(
+                input=start_tokens,
+                shape=[-1, 0, d_model],
+                dtype=enc_output.dtype,
+                value=0),
+            "v": layers.fill_constant_batch_size_like(
+                input=start_tokens,
+                shape=[-1, 0, d_model],
+                dtype=enc_output.dtype,
+                value=0)
+        } for i in range(n_layer)]
+        with while_op.block():
+            pre_ids = layers.array_read(array=ids, i=step_idx)
+            pre_scores = layers.array_read(array=scores, i=step_idx)
+            # sequence_expand can gather sequences according to lod and thus
+            # can be used in beam search to select the states corresponding
+            # to the candidate ids kept by the previous step.
+            pre_src_attn_bias = layers.sequence_expand(
+                x=trg_src_attn_bias, y=pre_scores)
+            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
+            pre_caches = [{
+                "k": layers.sequence_expand(
+                    x=cache["k"], y=pre_scores),
+                "v": layers.sequence_expand(
+                    x=cache["v"], y=pre_scores),
+            } for cache in caches]
+            pre_pos = layers.elementwise_mul(
+                x=layers.fill_constant_batch_size_like(
+                    input=pre_enc_output,  # can't use pre_ids here since it has lod
+                    value=1,
+                    shape=[-1, 1],
+                    dtype=pre_ids.dtype),
+                y=layers.increment(
+                    x=step_idx, value=1.0, in_place=False),
+                axis=0)
+            logits = wrap_decoder(
+                trg_vocab_size,
+                max_in_len,
+                n_layer,
+                n_head,
+                d_key,
+                d_value,
+                d_model,
+                d_inner_hid,
+                dropout_rate,
+                weight_sharing,
+                dec_inputs=(
+                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
+                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
+                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
+                enc_output=pre_enc_output,
+                caches=pre_caches)
+            topk_scores, topk_indices = layers.topk(
+                input=layers.softmax(logits), k=beam_size)
+            accu_scores = layers.elementwise_add(
+                x=layers.log(topk_scores),
+                y=layers.reshape(
+                    pre_scores, shape=[-1]),
+                axis=0)
+            # beam_search op uses lod to distinguish branches.
+            topk_indices = layers.lod_reset(topk_indices, pre_ids)
+            selected_ids, selected_scores = layers.beam_search(
+                pre_ids=pre_ids,
+                pre_scores=pre_scores,
+                ids=topk_indices,
+                scores=accu_scores,
+                beam_size=beam_size,
+                end_id=eos_idx)
+
+            layers.increment(x=step_idx, value=1.0, in_place=True)
+            # update states
+            layers.array_write(selected_ids, i=step_idx, array=ids)
+            layers.array_write(selected_scores, i=step_idx, array=scores)
+            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
+            layers.assign(pre_enc_output, enc_output)
+            for i in range(n_layer):
+                layers.assign(pre_caches[i]["k"], caches[i]["k"])
+                layers.assign(pre_caches[i]["v"], caches[i]["v"])
+            layers.assign(
+                layers.elementwise_add(
+                    x=slf_attn_pre_softmax_shape,
+                    y=attn_pre_softmax_shape_delta),
+                slf_attn_pre_softmax_shape)
+            layers.assign(
+                layers.elementwise_add(
+                    x=slf_attn_post_softmax_shape,
+                    y=attn_post_softmax_shape_delta),
+                slf_attn_post_softmax_shape)
+
+            length_cond = layers.less_than(x=step_idx, y=max_len)
+            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
+            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
+
+        finished_ids, finished_scores = layers.beam_search_decode(
+            ids, scores, beam_size=beam_size, end_id=eos_idx)
+        return finished_ids, finished_scores
+
+    finished_ids, finished_scores = beam_search()
+    return finished_ids, finished_scores
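A brief sketch of the state-gathering trick used inside the loop, with hypothetical numbers: sequence_expand copies each row of x according to the sequence lengths recorded in y's LoD, so the states of pruned branches are dropped and each surviving branch gets its own copy.

# suppose beam_search kept 2 branches of sentence 0 and 1 branch of
# sentence 1, giving pre_scores the lod [[0, 2, 3]]; then for the rows
# x = [enc0, enc1]:
#     sequence_expand(x=enc_output, y=pre_scores) -> [enc0, enc0, enc1]
# the same expansion is applied to trg_src_attn_bias and the k/v caches,
# keeping every tensor aligned with the surviving beam branches.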
@@ -198,7 +198,8 @@ class DataReader(object):
             for line in f_obj:
                 fields = line.strip().split(self._delimiter)

-                if len(fields) != 2 or (self._only_src and len(fields) != 1):
+                if (not self._only_src and len(fields) != 2) or (
+                        self._only_src and len(fields) != 1):
                     continue

                 sample_words = []
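A quick truth check of the corrected guard, on hypothetical inputs:

# _only_src=False, fields=["src", "trg"] -> kept (two columns expected)
# _only_src=True,  fields=["src"]        -> kept; the old guard's leading
#     `len(fields) != 2` wrongly skipped every single-column line in
#     source-only mode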
@@ -275,7 +276,7 @@ class DataReader(object):
         for sample_idx in self._sample_idxs:
             if self._only_src:
-                yield (self._src_seq_ids[sample_idx])
+                yield (self._src_seq_ids[sample_idx], )
             else:
                 yield (self._src_seq_ids[sample_idx],
                        self._trg_seq_ids[sample_idx][:-1],
...