diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 8ab6b2de4ff7a44e807002c79b8be46f4a912920..a4e588c620f21c4f38eb1906f55d68ddf93214b6 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -38,7 +38,7 @@ class InferTaskConfig(object):
     batch_size = 10
     # the parameters for beam search.
     beam_size = 5
-    max_length = 256
+    max_out_len = 256
     # the number of decoded sentences to output.
     n_best = 1
     # the flags indicating whether to output the special tokens.
@@ -104,23 +104,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
                 break
 
 
+# The placeholder for batch_size in compile time. Must be -1 currently to be
+# consistent with some ops' infer-shape output in compile time, such as the
+# sequence_expand op used in beamsearch decoder.
+batch_size = -1
+# The placeholder for sequence length in compile time.
+seq_len = ModelHyperParams.max_length
 # Here list the data shapes and data types of all inputs.
 # The shapes here act as placeholder and are set to pass the infer-shape in
 # compile time.
 input_descs = {
     # The actual data shape of src_word is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_word": [(batch_size * seq_len, 1L), "int64", 2],
     # The actual data shape of src_pos is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_pos": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
     # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
-    "src_slf_attn_bias":
-    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
-      (ModelHyperParams.max_length + 1)), "float32"],
+    "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "src_data_shape": [(3L, ), "int32"],
     # This shape input is used to reshape before softmax in self attention.
@@ -129,24 +134,23 @@ input_descs = {
     "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
     # The actual data shape of trg_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_word": [(batch_size * seq_len, 1L), "int64",
+                 2],  # lod_level is only used in fast decoder.
    # The actual data shape of trg_pos is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_pos": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
-    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
-                           (ModelHyperParams.max_length + 1),
-                           (ModelHyperParams.max_length + 1)), "float32"],
+    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This input is used to remove attention weights on paddings of the source
     # input in the encoder-decoder attention.
     # The actual data shape of trg_src_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
-    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
-                           (ModelHyperParams.max_length + 1),
-                           (ModelHyperParams.max_length + 1)), "float32"],
+    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+                           seq_len), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "trg_data_shape": [(3L, ), "int32"],
     # This shape input is used to reshape before softmax in self attention.
@@ -162,15 +166,18 @@ input_descs = {
     # This input is used in independent decoder program for inference.
     # The actual data shape of enc_output is:
     # [batch_size, max_src_len_in_batch, d_model]
-    "enc_output": [(1, (ModelHyperParams.max_length + 1),
-                    ModelHyperParams.d_model), "float32"],
+    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
     # The actual data shape of label_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "lbl_word": [(batch_size * seq_len, 1L), "int64"],
     # This input is used to mask out the loss of paddding tokens.
     # The actual data shape of label_weight is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+    "lbl_weight": [(batch_size * seq_len, 1L), "float32"],
+    # These inputs are used to change the shape tensor in beam-search decoder.
+    "trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
+    "trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
+    "init_score": [(batch_size, 1L), "float32"],
 }
 
 # Names of word embedding table which might be reused for weight sharing.
@@ -205,3 +212,12 @@ decoder_util_input_fields = (
 label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )
+# In fast decoder, trg_pos (only containing the current time step) is generated
+# by ops and trg_slf_attn_bias is not needed.
+fast_decoder_data_input_fields = (
+    "trg_word",
+    "init_score",
+    "trg_src_attn_bias", )
+fast_decoder_util_input_fields = decoder_util_input_fields + (
+    "trg_slf_attn_pre_softmax_shape_delta",
+    "trg_slf_attn_post_softmax_shape_delta", )
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index c3d1f0af5d319f838e968930ec7d4083baeaab6f..874028081cca218ae16559af9ea9b05d3494c977 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -7,6 +7,7 @@ import paddle.fluid as fluid
 import model
 from model import wrap_encoder as encoder
 from model import wrap_decoder as decoder
+from model import fast_decode as fast_decoder
 from config import *
 from train import pad_batch_data
 import reader
@@ -87,7 +88,8 @@ def translate_batch(exe,
                     output_unk=True):
     """
     Run the encoder program once and run the decoder program multiple times to
-    implement beam search externally.
+    implement beam search externally. This is deprecated since a faster beam
+    search decoder based solely on Fluid operators has been added.
     """
     # Prepare data for encoder and run the encoder.
     enc_in_data = pad_batch_data(
@@ -297,7 +299,32 @@ def translate_batch(exe,
     return seqs, scores[:, :n_best].tolist()
 
 
-def infer(args):
+def post_process_seq(seq,
+                     bos_idx=ModelHyperParams.bos_idx,
+                     eos_idx=ModelHyperParams.eos_idx,
+                     output_bos=InferTaskConfig.output_bos,
+                     output_eos=InferTaskConfig.output_eos):
+    """
+    Post-process the beam-search decoded sequence. Truncate from the first
+    <eos> and remove the <bos> and <eos> tokens currently.
+    """
+    eos_pos = len(seq) - 1
+    for i, idx in enumerate(seq):
+        if idx == eos_idx:
+            eos_pos = i
+            break
+    seq = seq[:eos_pos + 1]
+    return filter(
+        lambda idx: (output_bos or idx != bos_idx) and \
+            (output_eos or idx != eos_idx),
+        seq)
+
+
+def py_infer(test_data, trg_idx2word):
+    """
+    Inference by beam search implemented in Python, while the calculations from
+    symbols to probabilities are executed by Fluid operators.
+    """
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
@@ -341,49 +368,8 @@ def infer(args):
     fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
 
     # This is used here to set dropout to the test mode.
-    encoder_program = fluid.io.get_inference_program(
-        target_vars=[enc_output], main_program=encoder_program)
-    decoder_program = fluid.io.get_inference_program(
-        target_vars=[predict], main_program=decoder_program)
-
-    test_data = reader.DataReader(
-        src_vocab_fpath=args.src_vocab_fpath,
-        trg_vocab_fpath=args.trg_vocab_fpath,
-        fpattern=args.test_file_pattern,
-        batch_size=args.batch_size,
-        use_token_batch=False,
-        pool_size=args.pool_size,
-        sort_type=reader.SortType.NONE,
-        shuffle=False,
-        shuffle_batch=False,
-        start_mark=args.special_token[0],
-        end_mark=args.special_token[1],
-        unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
-        clip_last_batch=False)
-
-    trg_idx2word = test_data.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
-
-    def post_process_seq(seq,
-                         bos_idx=ModelHyperParams.bos_idx,
-                         eos_idx=ModelHyperParams.eos_idx,
-                         output_bos=InferTaskConfig.output_bos,
-                         output_eos=InferTaskConfig.output_eos):
-        """
-        Post-process the beam-search decoded sequence. Truncate from the first
-        <eos> and remove the <bos> and <eos> tokens currently.
-        """
-        eos_pos = len(seq) - 1
-        for i, idx in enumerate(seq):
-            if idx == eos_idx:
-                eos_pos = i
-                break
-        seq = seq[:eos_pos + 1]
-        return filter(
-            lambda idx: (output_bos or idx != bos_idx) and \
-                (output_eos or idx != eos_idx),
-            seq)
+    encoder_program = encoder_program.inference_optimize()
+    decoder_program = decoder_program.inference_optimize()
 
     for batch_id, data in enumerate(test_data.batch_generator()):
         batch_seqs, batch_scores = translate_batch(
@@ -397,7 +383,7 @@ def infer(args):
             (decoder_data_input_fields[-1], ),
             [predict.name],
             InferTaskConfig.beam_size,
-            InferTaskConfig.max_length,
+            InferTaskConfig.max_out_len,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
@@ -416,6 +402,154 @@ def infer(args):
             print(" ".join([trg_idx2word[idx] for idx in seq]))
 
 
+def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
+                        bos_idx, n_head, d_model, place):
+    """
+    Put all padded data needed by beam search decoder into a dict.
+    """
+    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
+        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+    # start tokens
+    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
+    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
+                                [1, 1, 1, 1]).astype("float32")
+
+    # These shape tensors are used in reshape_op.
+    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
+    trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
+    src_slf_attn_pre_softmax_shape = np.array(
+        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
+    src_slf_attn_post_softmax_shape = np.array(
+        [-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
+    trg_slf_attn_pre_softmax_shape = np.array(
+        [-1, 1], dtype="int32")  # only the first time step
+    trg_slf_attn_post_softmax_shape = np.array(
+        [-1, n_head, 1, 1], dtype="int32")  # only the first time step
+    trg_src_attn_pre_softmax_shape = np.array(
+        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
+    trg_src_attn_post_softmax_shape = np.array(
+        [-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
+    # These inputs are used to change the shapes in the loop of while op.
+    attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
+    attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
+
+    def to_lodtensor(data, place, lod=None):
+        data_tensor = fluid.LoDTensor()
+        data_tensor.set(data, place)
+        if lod is not None:
+            data_tensor.set_lod(lod)
+        return data_tensor
+
+    # beamsearch_op must use tensors with lod
+    init_score = to_lodtensor(
+        np.zeros_like(
+            trg_word, dtype="float32"),
+        place, [range(trg_word.shape[0] + 1)] * 2)
+    trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
+
+    data_input_dict = dict(
+        zip(data_input_names, [
+            src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
+            trg_src_attn_bias
+        ]))
+    util_input_dict = dict(
+        zip(util_input_names, [
+            src_data_shape, src_slf_attn_pre_softmax_shape,
+            src_slf_attn_post_softmax_shape, trg_data_shape,
+            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
+            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
+            attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
+        ]))
+
+    input_dict = dict(data_input_dict.items() + util_input_dict.items())
+    return input_dict
+
+
+def fast_infer(test_data, trg_idx2word):
+    """
+    Inference by beam search decoder based solely on Fluid operators.
+    """
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    out_ids, out_scores = fast_decoder(
+        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+        ModelHyperParams.n_head, ModelHyperParams.d_key,
+        ModelHyperParams.d_value, ModelHyperParams.d_model,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+        ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
+        InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
+
+    fluid.io.load_vars(
+        exe,
+        InferTaskConfig.model_path,
+        vars=filter(lambda var: isinstance(var, fluid.framework.Parameter),
+                    fluid.default_main_program().list_vars()))
+
+    # This is used here to set dropout to the test mode.
+    infer_program = fluid.default_main_program().inference_optimize()
+
+    for batch_id, data in enumerate(test_data.batch_generator()):
+        data_input = prepare_batch_input(
+            data, encoder_data_input_fields + fast_decoder_data_input_fields,
+            encoder_util_input_fields + fast_decoder_util_input_fields,
+            ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
+            ModelHyperParams.n_head, ModelHyperParams.d_model, place)
+        seq_ids, seq_scores = exe.run(infer_program,
+                                      feed=data_input,
+                                      fetch_list=[out_ids, out_scores],
+                                      return_numpy=False)
+        # How to parse the results:
+        # Suppose the lod of seq_ids is:
+        # [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
+        # then from lod[0]:
+        # there are 2 source sentences, beam width is 3.
+        # from lod[1]:
+        # the first source sentence has 3 hyps; the lengths are 12, 12, 16
+        # the second source sentence has 3 hyps; the lengths are 14, 13, 15
+        hyps = [[] for i in range(len(data))]
+        scores = [[] for i in range(len(data))]
+        for i in range(len(seq_ids.lod()[0]) - 1):  # for each source sentence
+            start = seq_ids.lod()[0][i]
+            end = seq_ids.lod()[0][i + 1]
+            for j in range(end - start):  # for each candidate
+                sub_start = seq_ids.lod()[1][start + j]
+                sub_end = seq_ids.lod()[1][start + j + 1]
+                hyps[i].append(" ".join([
+                    trg_idx2word[idx]
+                    for idx in post_process_seq(
+                        np.array(seq_ids)[sub_start:sub_end])
+                ]))
+                scores[i].append(np.array(seq_scores)[sub_end - 1])
+                print hyps[i][-1]
+                if len(hyps[i]) >= InferTaskConfig.n_best:
+                    break
+
+
+def infer(args, inferencer=fast_infer):
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    test_data = reader.DataReader(
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        fpattern=args.test_file_pattern,
+        batch_size=args.batch_size,
+        use_token_batch=False,
+        pool_size=args.pool_size,
+        sort_type=reader.SortType.NONE,
+        shuffle=False,
+        shuffle_batch=False,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
+        clip_last_batch=False)
+    trg_idx2word = test_data.load_dict(
+        dict_path=args.trg_vocab_fpath, reverse=True)
+    inferencer(test_data, trg_idx2word)
+
+
 if __name__ == "__main__":
     args = parse_args()
     infer(args)
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 7756d633fb05d27904f84dc9c41e25643c17eb04..46c9f7a9065765b1e5ab5fa4d66042fc3312f75a 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -30,7 +30,8 @@ def multi_head_attention(queries,
                          n_head=1,
                          dropout_rate=0.,
                          pre_softmax_shape=None,
-                         post_softmax_shape=None):
+                         post_softmax_shape=None,
+                         cache=None):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
     computing softmax activiation to mask certain selected positions so that
@@ -116,6 +117,10 @@ def multi_head_attention(queries,
 
     q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
 
+    if cache is not None:  # use cache and concat time steps
+        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
+        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
+
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
     v = __split_heads(v, n_head)
@@ -203,7 +208,7 @@ def prepare_encoder(src_word,
     enc_input = src_word_emb + src_pos_enc
     enc_input = layers.reshape(
         x=enc_input,
-        shape=[-1, src_max_len, src_emb_dim],
+        shape=[batch_size, seq_len, src_emb_dim],
         actual_shape=src_data_shape)
     return layers.dropout(
         enc_input, dropout_prob=dropout_rate,
@@ -285,7 +290,8 @@ def decoder_layer(dec_input,
                   slf_attn_pre_softmax_shape=None,
                   slf_attn_post_softmax_shape=None,
                   src_attn_pre_softmax_shape=None,
-                  src_attn_post_softmax_shape=None):
+                  src_attn_post_softmax_shape=None,
+                  cache=None):
     """ The layer to be stacked in decoder part. The structure of this module
     is similar to that in the encoder part except a multi-head attention is
     added to implement encoder-decoder attention.
@@ -301,7 +307,8 @@ def decoder_layer(dec_input,
         n_head,
         dropout_rate,
         slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape, )
+        slf_attn_post_softmax_shape,
+        cache, )
     slf_attn_output = post_process_layer(
         dec_input,
         slf_attn_output,
@@ -350,7 +357,8 @@ def decoder(dec_input,
             slf_attn_pre_softmax_shape=None,
             slf_attn_post_softmax_shape=None,
             src_attn_pre_softmax_shape=None,
-            src_attn_post_softmax_shape=None):
+            src_attn_post_softmax_shape=None,
+            caches=None):
     """
     The decoder is composed of a stack of identical decoder_layer layers.
     """
@@ -369,7 +377,8 @@ def decoder(dec_input,
             slf_attn_pre_softmax_shape,
             slf_attn_post_softmax_shape,
             src_attn_pre_softmax_shape,
-            src_attn_post_softmax_shape, )
+            src_attn_post_softmax_shape,
+            None if caches is None else caches[i], )
         dec_input = dec_output
     return dec_output
 
@@ -384,6 +393,8 @@ def make_all_inputs(input_fields):
             name=input_field,
             shape=input_descs[input_field][0],
             dtype=input_descs[input_field][1],
+            lod_level=input_descs[input_field][2]
+            if len(input_descs[input_field]) == 3 else 0,
             append_batch_size=False)
         inputs.append(input_var)
     return inputs
@@ -517,7 +528,8 @@ def wrap_decoder(trg_vocab_size,
                  dropout_rate,
                  weight_sharing,
                  dec_inputs=None,
-                 enc_output=None):
+                 enc_output=None,
+                 caches=None):
     """
     The wrapper assembles together all needed layers for the decoder.
     """
@@ -559,7 +571,8 @@ def wrap_decoder(trg_vocab_size,
         slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape,
         src_attn_pre_softmax_shape,
-        src_attn_post_softmax_shape, )
+        src_attn_post_softmax_shape,
+        caches, )
     # Return logits for training and probs for inference.
     if weight_sharing:
         predict = layers.reshape(
@@ -578,3 +591,145 @@ def wrap_decoder(trg_vocab_size,
         shape=[-1, trg_vocab_size],
         act="softmax" if dec_inputs is None else None)
     return predict
+
+
+def fast_decode(
+        src_vocab_size,
+        trg_vocab_size,
+        max_in_len,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        weight_sharing,
+        beam_size,
+        max_out_len,
+        eos_idx, ):
+    """
+    Use beam search to decode. Caches will be used to store states of history
+    steps which can make the decoding faster.
+    """
+    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
+                              d_key, d_value, d_model, d_inner_hid,
+                              dropout_rate, weight_sharing)
+    start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
+        slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
+        src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
+        attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
+        make_all_inputs(fast_decoder_data_input_fields +
+                        fast_decoder_util_input_fields)
+
+    def beam_search():
+        max_len = layers.fill_constant(
+            shape=[1], dtype=start_tokens.dtype, value=max_out_len)
+        step_idx = layers.fill_constant(
+            shape=[1], dtype=start_tokens.dtype, value=0)
+        cond = layers.less_than(x=step_idx, y=max_len)
+        while_op = layers.While(cond)
+        # array states will be stored for each step.
+        ids = layers.array_write(start_tokens, step_idx)
+        scores = layers.array_write(init_scores, step_idx)
+        # cell states will be overwritten at each step.
+        # caches contain states of history steps to reduce redundant
+        # computation in decoder.
+        caches = [{
+            "k": layers.fill_constant_batch_size_like(
+                input=start_tokens,
+                shape=[-1, 0, d_model],
+                dtype=enc_output.dtype,
+                value=0),
+            "v": layers.fill_constant_batch_size_like(
+                input=start_tokens,
+                shape=[-1, 0, d_model],
+                dtype=enc_output.dtype,
+                value=0)
+        } for i in range(n_layer)]
+        with while_op.block():
+            pre_ids = layers.array_read(array=ids, i=step_idx)
+            pre_scores = layers.array_read(array=scores, i=step_idx)
+            # sequence_expand can gather sequences according to lod thus can be
+            # used in beam search to sift states corresponding to selected ids.
+            pre_src_attn_bias = layers.sequence_expand(
+                x=trg_src_attn_bias, y=pre_scores)
+            pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
+            pre_caches = [{
+                "k": layers.sequence_expand(
+                    x=cache["k"], y=pre_scores),
+                "v": layers.sequence_expand(
+                    x=cache["v"], y=pre_scores),
+            } for cache in caches]
+            pre_pos = layers.elementwise_mul(
+                x=layers.fill_constant_batch_size_like(
+                    input=pre_enc_output,  # can't use pre_ids here since it has lod
+                    value=1,
+                    shape=[-1, 1],
+                    dtype=pre_ids.dtype),
+                y=layers.increment(
+                    x=step_idx, value=1.0, in_place=False),
+                axis=0)
+            logits = wrap_decoder(
+                trg_vocab_size,
+                max_in_len,
+                n_layer,
+                n_head,
+                d_key,
+                d_value,
+                d_model,
+                d_inner_hid,
+                dropout_rate,
+                weight_sharing,
+                dec_inputs=(
+                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
+                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
+                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
+                enc_output=pre_enc_output,
+                caches=pre_caches)
+            topk_scores, topk_indices = layers.topk(
+                input=layers.softmax(logits), k=beam_size)
+            accu_scores = layers.elementwise_add(
+                x=layers.log(topk_scores),
+                y=layers.reshape(
+                    pre_scores, shape=[-1]),
+                axis=0)
+            # beam_search op uses lod to distinguish branches.
+            topk_indices = layers.lod_reset(topk_indices, pre_ids)
+            selected_ids, selected_scores = layers.beam_search(
+                pre_ids=pre_ids,
+                pre_scores=pre_scores,
+                ids=topk_indices,
+                scores=accu_scores,
+                beam_size=beam_size,
+                end_id=eos_idx)
+            layers.increment(x=step_idx, value=1.0, in_place=True)
+            # update states
+            layers.array_write(selected_ids, i=step_idx, array=ids)
+            layers.array_write(selected_scores, i=step_idx, array=scores)
+            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
+            layers.assign(pre_enc_output, enc_output)
+            for i in range(n_layer):
+                layers.assign(pre_caches[i]["k"], caches[i]["k"])
+                layers.assign(pre_caches[i]["v"], caches[i]["v"])
+            layers.assign(
+                layers.elementwise_add(
+                    x=slf_attn_pre_softmax_shape,
+                    y=attn_pre_softmax_shape_delta),
+                slf_attn_pre_softmax_shape)
+            layers.assign(
+                layers.elementwise_add(
+                    x=slf_attn_post_softmax_shape,
+                    y=attn_post_softmax_shape_delta),
+                slf_attn_post_softmax_shape)
+
+            length_cond = layers.less_than(x=step_idx, y=max_len)
+            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
+            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
+
+        finished_ids, finished_scores = layers.beam_search_decode(
+            ids, scores, beam_size=beam_size, end_id=eos_idx)
+        return finished_ids, finished_scores
+
+    finished_ids, finished_scores = beam_search()
+    return finished_ids, finished_scores
diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 900ca9d0702b79f363b5e507e7ec8050764efce9..27bd82b13a0480e80bdfcdc72eaa670854f4cd3a 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -198,7 +198,8 @@ class DataReader(object):
 
         for line in f_obj:
             fields = line.strip().split(self._delimiter)
-            if len(fields) != 2 or (self._only_src and len(fields) != 1):
+            if (not self._only_src and len(fields) != 2) or (self._only_src and
+                                                             len(fields) != 1):
                 continue
 
             sample_words = []
@@ -275,7 +276,7 @@ class DataReader(object):
 
         for sample_idx in self._sample_idxs:
             if self._only_src:
-                yield (self._src_seq_ids[sample_idx])
+                yield (self._src_seq_ids[sample_idx], )
             else:
                 yield (self._src_seq_ids[sample_idx],
                        self._trg_seq_ids[sample_idx][:-1],
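Note: the config.py hunk gives some input_descs entries an optional third element, and the
make_all_inputs hunk in model.py reads it as the lod_level passed to fluid.layers.data. The
sketch below is a minimal, Fluid-free illustration of that lookup; the two entries and their
shapes are simplified stand-ins, not the real config values.

    # A minimal sketch of the lookup make_all_inputs performs (plain Python, no Fluid).
    input_descs = {
        "trg_word": [(-1 * 256, 1), "int64", 2],  # third element: lod_level for the fast decoder
        "trg_pos": [(-1 * 256, 1), "int64"],      # no third element -> lod_level defaults to 0
    }
    for name, desc in input_descs.items():
        shape, dtype = desc[0], desc[1]
        lod_level = desc[2] if len(desc) == 3 else 0
        print((name, shape, dtype, lod_level))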
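Note: the two-level LoD layout described in the comments inside fast_infer can be walked with
plain Python. The sketch below reuses the hypothetical example LoD from that comment (it is not
real decoder output), with flat_ids standing in for np.array(seq_ids); lod[0] delimits the
source sentences and lod[1] delimits each candidate hypothesis.

    # A minimal, Fluid-free sketch of parsing beam_search_decode results.
    lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
    flat_ids = list(range(lod[1][-1]))             # stand-in for the flattened token ids
    for i in range(len(lod[0]) - 1):               # each source sentence
        for j in range(lod[0][i], lod[0][i + 1]):  # each candidate hypothesis of it
            sub_start, sub_end = lod[1][j], lod[1][j + 1]
            # prints lengths 12, 12, 16 for the first sentence and 14, 13, 15 for the second
            print((i, sub_end - sub_start, flat_ids[sub_start:sub_end][:3]))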
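Note: the cache added to multi_head_attention concatenates the keys and values of every decoded
step along the time axis, starting from a [-1, 0, d_model] tensor, so each step only computes its
own projections. The numpy sketch below mirrors that growth with made-up sizes; the real tensors
are Fluid variables, not numpy arrays.

    # A minimal numpy sketch of the k cache growing by one time step per iteration.
    import numpy as np
    cache_k = np.zeros((2, 0, 8))        # [batch, 0 steps so far, d_model]
    for t in range(3):
        new_k = np.full((2, 1, 8), t)    # keys computed at the current step
        cache_k = np.concatenate([cache_k, new_k], axis=1)
    print(cache_k.shape)                 # (2, 3, 8)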