diff --git a/seq2seq/predict.py b/seq2seq/predict.py index c51eed2d9e0b596de8e07765af634b18ed7f9ee8..c9120bff126cc505b3c0ee3274f65b67e8f78fe6 100644 --- a/seq2seq/predict.py +++ b/seq2seq/predict.py @@ -113,7 +113,7 @@ def do_predict(args): for data in data_loader(): finished_seq = model.test(inputs=flatten(data))[0] finished_seq = finished_seq[:, :, np.newaxis] if len( - finished_seq.shape == 2) else finished_seq + finished_seq.shape) == 2 else finished_seq finished_seq = np.transpose(finished_seq, [0, 2, 1]) for ins in finished_seq: for beam_idx, beam in enumerate(ins): diff --git a/seq2seq/reader.py b/seq2seq/reader.py index a6fa73faf24496823ac3dd4db5befad6de032c5b..26f5d6a4d1b9c3e135abdf0984fc38282d7bcaa3 100644 --- a/seq2seq/reader.py +++ b/seq2seq/reader.py @@ -168,6 +168,7 @@ class SampleInfo(object): def __init__(self, i, lens): self.i = i self.lens = lens + self.max_len = lens[0] def get_ranges(self, min_length=None, max_length=None, truncate=False): ranges = [] diff --git a/seq2seq/seq2seq_attn.py b/seq2seq/seq2seq_attn.py index 507c72aa5a39df16936d54ab7d7d474f6b611afc..136b4741d95af90c564e8bac7ce6723198533a28 100644 --- a/seq2seq/seq2seq_attn.py +++ b/seq2seq/seq2seq_attn.py @@ -247,6 +247,8 @@ class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper): self.start_token_value = start_tokens super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens, end_token) + self.end_token = fluid.layers.create_global_var( + shape=[1], dtype="int64", value=end_token, persistable=True) def initialize(self, batch_ref=None): if getattr(self, "need_convert_start_tokens", False): @@ -319,7 +321,7 @@ class AttentionGreedyInferModel(AttentionModel): encoder_padding_mask = (src_mask - 1.0) * 1e9 encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) - # dynamic decoding with beam search + # dynamic decoding with greedy search rs, _ = self.greedy_search_decoder( inits=decoder_initial_states, encoder_output=encoder_output,