diff --git a/seq2seq/predict.py b/seq2seq/predict.py
index fef74cfb22fcfeca26d04d17595df706394432bf..c51eed2d9e0b596de8e07765af634b18ed7f9ee8 100644
--- a/seq2seq/predict.py
+++ b/seq2seq/predict.py
@@ -28,7 +28,7 @@ from paddle.fluid.io import DataLoader
 from model import Input, set_device
 from args import parse_args
 from seq2seq_base import BaseInferModel
-from seq2seq_attn import AttentionInferModel
+from seq2seq_attn import AttentionInferModel, AttentionGreedyInferModel
 from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_infer_input
 
@@ -87,7 +87,8 @@ def do_predict(args):
         num_workers=0,
         return_list=True)
 
-    model_maker = AttentionInferModel if args.attention else BaseInferModel
+    # model_maker = AttentionInferModel if args.attention else BaseInferModel
+    model_maker = AttentionGreedyInferModel if args.attention else BaseInferModel
     model = model_maker(
         args.src_vocab_size,
         args.tar_vocab_size,
@@ -111,6 +112,8 @@ def do_predict(args):
     with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
         for data in data_loader():
             finished_seq = model.test(inputs=flatten(data))[0]
+            finished_seq = finished_seq[:, :, np.newaxis] if len(
+                finished_seq.shape) == 2 else finished_seq
             finished_seq = np.transpose(finished_seq, [0, 2, 1])
             for ins in finished_seq:
                 for beam_idx, beam in enumerate(ins):
diff --git a/seq2seq/reader.py b/seq2seq/reader.py
index ebdbb47266e2c43b6e1ec862951f0f83bfe5cab0..a6fa73faf24496823ac3dd4db5befad6de032c5b 100644
--- a/seq2seq/reader.py
+++ b/seq2seq/reader.py
@@ -44,7 +44,8 @@ def create_data_loader(args, device, for_train=True):
             end_mark="</s>",
             unk_mark="<unk>",
             max_length=args.max_len if i == 0 else None,
-            truncate=True)
+            truncate=True,
+            trg_add_bos_eos=True)
         (args.src_vocab_size, args.tar_vocab_size, bos_id, eos_id,
          unk_id) = dataset.get_vocab_summary()
         batch_sampler = Seq2SeqBatchSampler(
@@ -53,7 +54,8 @@ def create_data_loader(args, device, for_train=True):
             batch_size=args.batch_size,
             pool_size=args.batch_size * 20,
             sort_type=SortType.POOL,
-            shuffle=False if args.enable_ce else True)
+            shuffle=False if args.enable_ce else True,
+            distribute_mode=True if i == 0 else False)
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
@@ -73,7 +75,7 @@ def prepare_train_input(insts, bos_id, eos_id, pad_id):
     src, src_length = pad_batch_data(
         [inst[0] for inst in insts], pad_id=pad_id)
     trg, trg_length = pad_batch_data(
-        [[bos_id] + inst[1] + [eos_id] for inst in insts], pad_id=pad_id)
+        [inst[1] for inst in insts], pad_id=pad_id)
     trg_length = trg_length - 1
     return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis]
 
@@ -165,9 +167,24 @@ class TokenBatchCreator(object):
 class SampleInfo(object):
     def __init__(self, i, lens):
         self.i = i
-        # to be consistent with origianl reader implementation
-        self.min_len = lens[0]
-        self.max_len = lens[0]
+        self.lens = lens
+
+    def get_ranges(self, min_length=None, max_length=None, truncate=False):
+        ranges = []
+        # source
+        if (min_length is None or self.lens[0] >= min_length) and (
+                max_length is None or self.lens[0] <= max_length or truncate):
+            end = max_length if truncate and max_length else self.lens[0]
+            ranges.append([0, end])
+        # target
+        if len(self.lens) == 2:
+            if (min_length is None or self.lens[1] >= min_length) and (
+                    max_length is None or self.lens[1] <= max_length + 2 or
+                    truncate):
+                end = max_length + 2 if truncate and max_length else self.lens[
+                    1]
+                ranges.append([0, end])
+        return ranges if len(ranges) == len(self.lens) else None
 
 
 class MinMaxFilter(object):
@@ -197,6 +214,7 @@ class Seq2SeqDataset(Dataset):
                  end_mark="</s>",
                  unk_mark="<unk>",
                  trg_fpattern=None,
+                 trg_add_bos_eos=False,
                  byte_data=False,
                  min_length=None,
                  max_length=None,
@@ -220,6 +238,7 @@ class Seq2SeqDataset(Dataset):
         self._min_length = min_length
         self._max_length = max_length
         self._truncate = truncate
+        self._trg_add_bos_eos = trg_add_bos_eos
         self.load_src_trg_ids(fpattern, trg_fpattern)
 
     def load_src_trg_ids(self, fpattern, trg_fpattern=None):
@@ -238,8 +257,8 @@ class Seq2SeqDataset(Dataset):
             end=self._eos_idx,
             unk=self._unk_idx,
             delimiter=self._token_delimiter,
-            add_beg=False,
-            add_end=False)
+            add_beg=True if self._trg_add_bos_eos else False,
+            add_end=True if self._trg_add_bos_eos else False)
 
         converters = ComposedConverter([src_converter, trg_converter])
 
@@ -252,13 +271,12 @@ class Seq2SeqDataset(Dataset):
             fields = converters(line)
             lens = [len(field) for field in fields]
             sample = SampleInfo(i, lens)
-            if (self._min_length is None or
-                    sample.min_len >= self._min_length) and (
-                        self._max_length is None or
-                        sample.max_len <= self._max_length or self._truncate):
-                for field, slot in zip(fields, slots):
-                    slot.append(field[:self._max_length] if self._truncate and
-                                self._max_length is not None else field)
+            field_ranges = sample.get_ranges(self._min_length,
+                                             self._max_length, self._truncate)
+            if field_ranges:
+                for field, field_range, slot in zip(fields, field_ranges,
+                                                    slots):
+                    slot.append(field[field_range[0]:field_range[1]])
             self._sample_infos.append(sample)
 
     def _load_lines(self, fpattern, trg_fpattern=None):
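For reference, the filtering/truncation rule introduced by `SampleInfo.get_ranges` can be sketched standalone: the source field must fit `[min_length, max_length]` unless truncation is allowed, while the target field is granted two extra positions for the bos/eos tokens the reader now adds. This is a plain-Python illustration, not the repo's module; the free function `get_ranges` is hypothetical.

```python
# Hypothetical standalone version of SampleInfo.get_ranges, for illustration.
def get_ranges(lens, min_length=None, max_length=None, truncate=False):
    """Return per-field [start, end) slices, or None if any field fails."""
    ranges = []
    # source: must fit the length bounds unless truncation is allowed
    if (min_length is None or lens[0] >= min_length) and (
            max_length is None or lens[0] <= max_length or truncate):
        end = max_length if truncate and max_length else lens[0]
        ranges.append([0, end])
    # target: allow 2 extra positions for the prepended/appended bos/eos
    if len(lens) == 2:
        if (min_length is None or lens[1] >= min_length) and (
                max_length is None or lens[1] <= max_length + 2 or truncate):
            end = max_length + 2 if truncate and max_length else lens[1]
            ranges.append([0, end])
    return ranges if len(ranges) == len(lens) else None

print(get_ranges([50, 52], max_length=50))                 # [[0, 50], [0, 52]]
print(get_ranges([60, 70], max_length=50))                 # None: too long, no truncate
print(get_ranges([60, 70], max_length=50, truncate=True))  # [[0, 50], [0, 52]]
```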
diff --git a/seq2seq/seq2seq_attn.py b/seq2seq/seq2seq_attn.py
index c5baee74e832816d8a53d7a5a5e854d40b9d47cb..507c72aa5a39df16936d54ab7d7d474f6b611afc 100644
--- a/seq2seq/seq2seq_attn.py
+++ b/seq2seq/seq2seq_attn.py
@@ -152,7 +152,7 @@ class AttentionModel(Model):
         self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size,
                                num_layers, dropout_prob, init_scale)
 
-    def forward(self, src, src_length, trg, trg_length):
+    def forward(self, src, src_length, trg):
         # encoder
         encoder_output, encoder_final_state = self.encoder(src, src_length)
 
@@ -174,11 +174,7 @@ class AttentionModel(Model):
         # decoder with attentioon
         predict = self.decoder(trg, decoder_initial_states, encoder_output,
                                encoder_padding_mask)
-
-        # for target padding mask
-        mask = layers.sequence_mask(
-            trg_length, maxlen=layers.shape(trg)[1], dtype=predict.dtype)
-        return predict, mask
+        return predict
 
 
 class AttentionInferModel(AttentionModel):
@@ -242,3 +238,90 @@ class AttentionInferModel(AttentionModel):
             encoder_output=encoder_output,
             encoder_padding_mask=encoder_padding_mask)
         return rs
+
+
+class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper):
+    def __init__(self, embedding_fn, start_tokens, end_token):
+        if isinstance(start_tokens, int):
+            self.need_convert_start_tokens = True
+            self.start_token_value = start_tokens
+        super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens,
+                                                    end_token)
+
+    def initialize(self, batch_ref=None):
+        if getattr(self, "need_convert_start_tokens", False):
+            assert batch_ref is not None, (
+                "Need to give batch_ref to get batch size "
+                "to initialize the tensor for start tokens.")
+            self.start_tokens = fluid.layers.fill_constant_batch_size_like(
+                input=fluid.layers.utils.flatten(batch_ref)[0],
+                shape=[-1],
+                dtype="int64",
+                value=self.start_token_value,
+                input_dim_idx=0)
+        return super(GreedyEmbeddingHelper, self).initialize()
+
+
+class BasicDecoder(fluid.layers.BasicDecoder):
+    def initialize(self, initial_cell_states):
+        (initial_inputs,
+         initial_finished) = self.helper.initialize(initial_cell_states)
+        return initial_inputs, initial_cell_states, initial_finished
+
+
+class AttentionGreedyInferModel(AttentionModel):
+    def __init__(self,
+                 src_vocab_size,
+                 trg_vocab_size,
+                 embed_dim,
+                 hidden_size,
+                 num_layers,
+                 dropout_prob=0.,
+                 bos_id=0,
+                 eos_id=1,
+                 beam_size=1,
+                 max_out_len=256):
+        args = dict(locals())
+        args.pop("self")
+        args.pop("__class__", None)  # py3
+        args.pop("beam_size", None)
+        self.bos_id = args.pop("bos_id")
+        self.eos_id = args.pop("eos_id")
+        self.max_out_len = args.pop("max_out_len")
+        super(AttentionGreedyInferModel, self).__init__(**args)
+        # dynamic decoder for inference
+        decoder_helper = GreedyEmbeddingHelper(
+            start_tokens=bos_id,
+            end_token=eos_id,
+            embedding_fn=self.decoder.embedder)
+        decoder = BasicDecoder(
+            cell=self.decoder.lstm_attention.cell,
+            helper=decoder_helper,
+            output_fn=self.decoder.output_layer)
+        self.greedy_search_decoder = DynamicDecode(
+            decoder, max_step_num=max_out_len, is_test=True)
+
+    def forward(self, src, src_length):
+        # encoding
+        encoder_output, encoder_final_state = self.encoder(src, src_length)
+
+        # decoder initial states
+        decoder_initial_states = [
+            encoder_final_state,
+            self.decoder.lstm_attention.cell.get_initial_states(
+                batch_ref=encoder_output, shape=[self.hidden_size])
+        ]
+        # attention mask to avoid paying attention on paddings
+        src_mask = layers.sequence_mask(
+            src_length,
+            maxlen=layers.shape(src)[1],
+            dtype=encoder_output.dtype)
+        encoder_padding_mask = (src_mask - 1.0) * 1e9
+        encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
+
+        # dynamic decoding with greedy search
+        rs, _ = self.greedy_search_decoder(
+            inits=decoder_initial_states,
+            encoder_output=encoder_output,
+            encoder_padding_mask=encoder_padding_mask)
+        return rs.sample_ids
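What `DynamicDecode` does with `GreedyEmbeddingHelper` boils down to a simple feedback loop. Below is a framework-free NumPy sketch of that loop; `step_fn` is a stand-in for one decoder step (embedding lookup + attention LSTM + output projection), not a Paddle API.

```python
# Sketch of greedy decoding: feed <bos>, take the argmax token each step,
# feed it back in, and stop on <eos> or after max_out_len steps.
import numpy as np

def greedy_decode(step_fn, init_state, bos_id, eos_id, max_out_len):
    token, state, sample_ids = bos_id, init_state, []
    for _ in range(max_out_len):
        logits, state = step_fn(token, state)  # one decoder step
        token = int(np.argmax(logits))         # greedy choice
        sample_ids.append(token)
        if token == eos_id:                    # <eos> ends the sequence
            break
    return sample_ids

# toy step function over a 5-token vocabulary, just to make this runnable
rng = np.random.RandomState(0)
toy_step = lambda token, state: (rng.rand(5), state)
print(greedy_decode(toy_step, None, bos_id=0, eos_id=1, max_out_len=10))
```

With `beam_size=1`, greedy search produces the same tokens as beam search, which is why `AttentionGreedyInferModel` can drop the `beam_size` argument entirely.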
diff --git a/seq2seq/seq2seq_base.py b/seq2seq/seq2seq_base.py
index b37b871b03f8f75f48ae11c3f397a9e45763fd98..2cfd8eaa71e681b07c36017c8c8c29c2968af872 100644
--- a/seq2seq/seq2seq_base.py
+++ b/seq2seq/seq2seq_base.py
@@ -27,7 +27,10 @@ class CrossEntropyCriterion(Loss):
         super(CrossEntropyCriterion, self).__init__()
 
     def forward(self, outputs, labels):
-        (predict, mask), label = outputs, labels[0]
+        predict, (trg_length, label) = outputs[0], labels
+        # for target padding mask
+        mask = layers.sequence_mask(
+            trg_length, maxlen=layers.shape(predict)[1], dtype=predict.dtype)
 
         cost = layers.softmax_with_cross_entropy(
             logits=predict, label=label, soft_label=False)
@@ -151,17 +154,13 @@ class BaseModel(Model):
         self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size,
                                num_layers, dropout_prob, init_scale)
 
-    def forward(self, src, src_length, trg, trg_length):
+    def forward(self, src, src_length, trg):
         # encoder
         encoder_output, encoder_final_states = self.encoder(src, src_length)
 
         # decoder
         predict = self.decoder(trg, encoder_final_states)
-
-        # for target padding mask
-        mask = layers.sequence_mask(
-            trg_length, maxlen=layers.shape(trg)[1], dtype=predict.dtype)
-        return predict, mask
+        return predict
 
 
 class BaseInferModel(BaseModel):
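The criterion change moves the target padding mask out of the models' `forward` and rebuilds it from `trg_length` inside `CrossEntropyCriterion`, so the models now return logits only. A NumPy sketch of that masking; `masked_loss` is a hypothetical helper for illustration:

```python
# Equivalent of building the mask with layers.sequence_mask(trg_length)
# and zeroing the per-token loss at padded positions.
import numpy as np

def masked_loss(token_loss, trg_length):
    """token_loss: [batch, max_len] per-token cross entropy."""
    max_len = token_loss.shape[1]
    # position j is valid iff j < trg_length[i]
    mask = (np.arange(max_len)[None, :] < trg_length[:, None]).astype("float64")
    # zero out padded positions, then sum per sequence
    return (token_loss * mask).sum(axis=1)

token_loss = np.ones((2, 4))
print(masked_loss(token_loss, np.array([4, 2])))  # [4. 2.]
```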
diff --git a/seq2seq/train.py b/seq2seq/train.py
index 9e809b0000052ef8e482c8a0acf8ca95f955e880..a1cd45477c05eec1e492640c20ddbd36198efa8d 100644
--- a/seq2seq/train.py
+++ b/seq2seq/train.py
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
 from paddle.fluid.io import DataLoader
 
 from model import Input, set_device
+from metrics import Metric
 from callbacks import ProgBarLogger
 from args import parse_args
 from seq2seq_base import BaseModel, CrossEntropyCriterion
@@ -31,6 +32,65 @@ from seq2seq_attn import AttentionModel
 from reader import create_data_loader
 
 
+class TrainCallback(ProgBarLogger):
+    def __init__(self, args, ppl, verbose=2):
+        super(TrainCallback, self).__init__(1, verbose)
+        # control metric
+        self.ppl = ppl
+        self.batch_size = args.batch_size
+
+    def on_train_begin(self, logs=None):
+        super(TrainCallback, self).on_train_begin(logs)
+        self.train_metrics += ["ppl"]  # also print ppl alongside loss
+        self.ppl.reset()
+
+    def on_train_batch_end(self, step, logs=None):
+        batch_loss = logs["loss"][0]
+        self.ppl.total_loss += batch_loss * self.batch_size
+        logs["ppl"] = np.exp(self.ppl.total_loss / self.ppl.word_count)
+        if step > 0 and step % self.ppl.reset_freq == 0:
+            self.ppl.reset()
+        super(TrainCallback, self).on_train_batch_end(step, logs)
+
+    def on_eval_begin(self, logs=None):
+        super(TrainCallback, self).on_eval_begin(logs)
+        self.eval_metrics = ["ppl"]
+        self.ppl.reset()
+
+    def on_eval_batch_end(self, step, logs=None):
+        batch_loss = logs["loss"][0]
+        self.ppl.total_loss += batch_loss * self.batch_size
+        logs["ppl"] = np.exp(self.ppl.total_loss / self.ppl.word_count)
+        super(TrainCallback, self).on_eval_batch_end(step, logs)
+
+
+class PPL(Metric):
+    def __init__(self, reset_freq=100, name=None):
+        super(PPL, self).__init__()
+        self._name = name or "ppl"
+        self.reset_freq = reset_freq
+        self.reset()
+
+    def add_metric_op(self, pred, label):
+        seq_length = label[0]
+        word_num = fluid.layers.reduce_sum(seq_length)
+        return word_num
+
+    def update(self, word_num):
+        self.word_count += word_num
+        return word_num
+
+    def reset(self):
+        self.total_loss = 0
+        self.word_count = 0
+
+    def accumulate(self):
+        return self.word_count
+
+    def name(self):
+        return self._name
+
+
 def do_train(args):
     device = set_device("gpu" if args.use_gpu else "cpu")
     fluid.enable_dygraph(device) if args.eager_run else None
@@ -47,10 +107,13 @@ def do_train(args):
             [None], "int64", name="src_length"),
         Input(
             [None, None], "int64", name="trg_word"),
+    ]
+    labels = [
         Input(
             [None], "int64", name="trg_length"),
+        Input(
+            [None, None, 1], "int64", name="label"),
     ]
-    labels = [Input([None, None, 1], "int64", name="label"), ]
 
     # def dataloader
     train_loader, eval_loader = create_data_loader(args, device)
@@ -63,15 +126,20 @@ def do_train(args):
         learning_rate=args.learning_rate,
         parameter_list=model.parameters())
     optimizer._grad_clip = fluid.clip.GradientClipByGlobalNorm(
        clip_norm=args.max_grad_norm)
+    ppl_metric = PPL()
     model.prepare(
-        optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
+        optimizer,
+        CrossEntropyCriterion(),
+        ppl_metric,
+        inputs=inputs,
+        labels=labels)
     model.fit(train_data=train_loader,
               eval_data=eval_loader,
               epochs=args.max_epoch,
               eval_freq=1,
               save_freq=1,
               save_dir=args.model_path,
-              log_freq=1)
+              callbacks=[TrainCallback(args, ppl_metric)])
 
 if __name__ == "__main__":
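The bookkeeping shared by `PPL` and `TrainCallback` reduces to `ppl = exp(total_loss / word_count)`, where the summed loss is recovered as `batch_loss * batch_size` and word counts come from `trg_length`. A minimal sketch, assuming `batch_loss` is the mean per-sample loss; `RunningPPL` is an illustrative name, not the repo's class:

```python
# Running perplexity with the periodic reset used during training.
import numpy as np

class RunningPPL:
    def __init__(self, reset_freq=100):
        self.reset_freq = reset_freq
        self.reset()

    def reset(self):
        self.total_loss, self.word_count = 0.0, 0

    def update(self, batch_loss, batch_size, word_num, step):
        self.total_loss += batch_loss * batch_size  # summed cross entropy
        self.word_count += word_num                 # tokens seen so far
        ppl = np.exp(self.total_loss / self.word_count)
        if step > 0 and step % self.reset_freq == 0:
            self.reset()                            # forget the old window
        return ppl

ppl = RunningPPL()
print(ppl.update(batch_loss=3.2, batch_size=32, word_num=640, step=1))
```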