From 64a7e1a0c8ce2b0e2ee29e96ae4b711cb6ccae6a Mon Sep 17 00:00:00 2001
From: guosheng
Date: Tue, 24 Mar 2020 21:05:53 +0800
Subject: [PATCH] Update transformer

---
 transformer/predict.py |  19 +++---
 transformer/reader.py  |  58 +++++++++++-------
 transformer/train.py   | 134 +++++++++++++++++++----------------------
 3 files changed, 107 insertions(+), 104 deletions(-)

diff --git a/transformer/predict.py b/transformer/predict.py
index 136c8c6..7918065 100644
--- a/transformer/predict.py
+++ b/transformer/predict.py
@@ -88,14 +88,17 @@ def do_predict(args):
     # define model
     inputs = [
         Input(
-            [None, None], "int64", name="src_word"), Input(
-                [None, None], "int64", name="src_pos"), Input(
-                    [None, args.n_head, None, None],
-                    "float32",
-                    name="src_slf_attn_bias"), Input(
-                        [None, args.n_head, None, None],
-                        "float32",
-                        name="trg_src_attn_bias")
+            [None, None], "int64", name="src_word"),
+        Input(
+            [None, None], "int64", name="src_pos"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="src_slf_attn_bias"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="trg_src_attn_bias"),
     ]
     transformer = InferTransformer(
         args.src_vocab_size,
diff --git a/transformer/reader.py b/transformer/reader.py
index 5e1a5dd..f182ba2 100644
--- a/transformer/reader.py
+++ b/transformer/reader.py
@@ -19,6 +19,15 @@ import tarfile
 
 import numpy as np
 import paddle.fluid as fluid
+from paddle.fluid.io import BatchSampler, DataLoader
+
+
+class TokenBatchSampler(BatchSampler):
+    def __init__(self):
+        pass
+
+    def __iter__(self):
+        pass
 
 
 def pad_batch_data(insts,
@@ -54,7 +63,8 @@ def pad_batch_data(insts,
     if is_target:
         # This is used to avoid attention on paddings and subsequent
         # words.
-        slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
+        slf_attn_bias_data = np.ones(
+            (inst_data.shape[0], max_len, max_len))
         slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                      1).reshape([-1, 1, max_len, max_len])
         slf_attn_bias_data = np.tile(slf_attn_bias_data,
@@ -306,6 +316,7 @@ class DataProcessor(object):
     :param seed: The seed for random.
     :type seed: int
     """
+
     def __init__(self,
                  src_vocab_fpath,
                  trg_vocab_fpath,
@@ -360,21 +371,23 @@ class DataProcessor(object):
 
     def load_src_trg_ids(self, fpattern, tar_fname):
         converters = [
-            Converter(vocab=self._src_vocab,
-                      beg=self._bos_idx,
-                      end=self._eos_idx,
-                      unk=self._unk_idx,
-                      delimiter=self._token_delimiter,
-                      add_beg=False)
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._bos_idx,
+                end=self._eos_idx,
+                unk=self._unk_idx,
+                delimiter=self._token_delimiter,
+                add_beg=False)
         ]
         if not self._only_src:
             converters.append(
-                Converter(vocab=self._trg_vocab,
-                          beg=self._bos_idx,
-                          end=self._eos_idx,
-                          unk=self._unk_idx,
-                          delimiter=self._token_delimiter,
-                          add_beg=True))
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._bos_idx,
+                    end=self._eos_idx,
+                    unk=self._unk_idx,
+                    delimiter=self._token_delimiter,
+                    add_beg=True))
 
         converters = ComposedConverter(converters)
 
@@ -402,9 +415,8 @@ class DataProcessor(object):
             f = tarfile.open(fpaths[0], "rb")
             for line in f.extractfile(tar_fname):
                 fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src
-                        and len(fields) == 2) or (self._only_src
-                                                  and len(fields) == 1):
+                if (not self._only_src and len(fields) == 2) or (
+                        self._only_src and len(fields) == 1):
                     yield fields
         else:
             for fpath in fpaths:
@@ -414,9 +426,8 @@ class DataProcessor(object):
                 with open(fpath, "rb") as f:
                     for line in f:
                         fields = line.strip(b"\n").split(self._field_delimiter)
-                        if (not self._only_src
-                                and len(fields) == 2) or (self._only_src
-                                                          and len(fields) == 1):
+                        if (not self._only_src and len(fields) == 2) or (
+                                self._only_src and len(fields) == 1):
                             yield fields
 
     @staticmethod
@@ -477,7 +488,8 @@ class DataProcessor(object):
             if self._only_src:
                 yield [[self._src_seq_ids[idx]] for idx in batch_ids]
             else:
-                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                yield [(self._src_seq_ids[idx],
+                        self._trg_seq_ids[idx][:-1],
                         self._trg_seq_ids[idx][1:]) for idx in batch_ids]
 
         return __impl__
@@ -512,8 +524,8 @@ class DataProcessor(object):
             for item in data_reader():
                 inst_num_per_part = len(item) // count
                 for i in range(count):
-                    yield item[inst_num_per_part * i:inst_num_per_part *
-                               (i + 1)]
+                    yield item[inst_num_per_part * i:inst_num_per_part * (i + 1
+                                                                          )]
 
         return __impl__
 
@@ -535,7 +547,7 @@ class DataProcessor(object):
             for data in data_reader():
                 data_inputs = prepare_train_input(data, src_pad_idx,
                                                   trg_pad_idx, n_head)
-                yield data_inputs
+                yield data_inputs[:-2], data_inputs[-2:]
 
         def __for_predict__():
             for data in data_reader():
diff --git a/transformer/train.py b/transformer/train.py
index 148b26e..e27a013 100644
--- a/transformer/train.py
+++ b/transformer/train.py
@@ -32,9 +32,35 @@ from utils.check import check_gpu, check_version
 import reader
 from transformer import Transformer, CrossEntropyCriterion, NoamDecay
 from model import Input
+from callbacks import ProgBarLogger
+
+
+class LoggerCallback(ProgBarLogger):
+    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
+        super(LoggerCallback, self).__init__(log_freq, verbose)
+        self.loss_normalizer = loss_normalizer
+
+    def on_train_begin(self, logs=None):
+        super(LoggerCallback, self).on_train_begin(logs)
+        self.train_metrics += ["normalized loss", "ppl"]
+
+    def on_train_batch_end(self, step, logs=None):
+        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
+        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
+        super(LoggerCallback, self).on_train_batch_end(step, logs)
+
+    def on_eval_begin(self, logs=None):
+        super(LoggerCallback, self).on_eval_begin(logs)
+        self.eval_metrics += ["normalized loss", "ppl"]
+
+    def on_eval_batch_end(self, step, logs=None):
+        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
+        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
+        super(LoggerCallback, self).on_eval_batch_end(step, logs)
 
 
 def do_train(args):
+    init_context('dynamic' if FLAGS.dynamic else 'static')
     trainer_count = 1  #get_nranks()
 
     @contextlib.contextmanager
@@ -102,24 +128,31 @@ def do_train(args):
     # define model
     inputs = [
         Input(
-            [None, None], "int64", name="src_word"), Input(
-                [None, None], "int64", name="src_pos"), Input(
-                    [None, args.n_head, None, None],
-                    "float32",
-                    name="src_slf_attn_bias"), Input(
-                        [None, None], "int64", name="trg_word"), Input(
-                            [None, None], "int64", name="trg_pos"), Input(
-                                [None, args.n_head, None, None],
-                                "float32",
-                                name="trg_slf_attn_bias"), Input(
-                                    [None, args.n_head, None, None],
-                                    "float32",
-                                    name="trg_src_attn_bias")
+            [None, None], "int64", name="src_word"),
+        Input(
+            [None, None], "int64", name="src_pos"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="src_slf_attn_bias"),
+        Input(
+            [None, None], "int64", name="trg_word"),
+        Input(
+            [None, None], "int64", name="trg_pos"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="trg_slf_attn_bias"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="trg_src_attn_bias"),
     ]
     labels = [
         Input(
-            [None, 1], "int64", name="label"), Input(
-                [None, 1], "float32", name="weight")
+            [None, 1], "int64", name="label"),
+        Input(
+            [None, 1], "float32", name="weight"),
     ]
 
     transformer = Transformer(
@@ -149,7 +182,8 @@ def do_train(args):
     ## init from some pretrain models, to better solve the current task
     if args.init_from_pretrain_model:
         transformer.load(
-            os.path.join(args.init_from_pretrain_model, "transformer"))
+            os.path.join(args.init_from_pretrain_model, "transformer"),
+            reset_optimizer=True)
 
     # the best cross-entropy value with label smoothing
     loss_normalizer = -(
@@ -157,63 +191,17 @@ def do_train(args):
         (1. - args.label_smooth_eps) * np.log(
             (1. - args.label_smooth_eps)) + args.label_smooth_eps *
         np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
-    step_idx = 0
-    # train loop
-    for pass_id in range(args.epoch):
-        pass_start_time = time.time()
-        batch_id = 0
-        for input_data in train_loader():
-            losses = transformer.train(input_data[:-2], input_data[-2:])
-
-            if step_idx % args.print_step == 0:
-                total_avg_cost = np.sum(losses)
-
-                if step_idx == 0:
-                    logging.info(
-                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
-                        "normalized loss: %f, ppl: %f" %
-                        (step_idx, pass_id, batch_id, total_avg_cost,
-                         total_avg_cost - loss_normalizer,
-                         np.exp([min(total_avg_cost, 100)])))
-                    avg_batch_time = time.time()
-                else:
-                    logging.info(
-                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
-                        "normalized loss: %f, ppl: %f, speed: %.2f step/s"
-                        %
-                        (step_idx, pass_id, batch_id, total_avg_cost,
-                         total_avg_cost - loss_normalizer,
-                         np.exp([min(total_avg_cost, 100)]),
-                         args.print_step / (time.time() - avg_batch_time)))
-                    avg_batch_time = time.time()
-
-            if step_idx % args.save_step == 0 and step_idx != 0:
-                # validation: how to accumulate with Model loss
-                if args.validation_file:
-                    total_avg_cost = 0
-                    for idx, input_data in enumerate(val_loader()):
-                        losses = transformer.eval(input_data[:-2],
-                                                  input_data[-2:])
-                        total_avg_cost += np.sum(losses)
-                    total_avg_cost /= idx + 1
-                    logging.info("validation, step_idx: %d, avg loss: %f, "
-                                 "normalized loss: %f, ppl: %f" %
-                                 (step_idx, total_avg_cost,
-                                  total_avg_cost - loss_normalizer,
-                                  np.exp([min(total_avg_cost, 100)])))
-
-                transformer.save(
-                    os.path.join(args.save_model, "step_" + str(step_idx),
-                                 "transformer"))
-
-            batch_id += 1
-            step_idx += 1
-
-        time_consumed = time.time() - pass_start_time
-
-    if args.save_model:
-        transformer.save(
-            os.path.join(args.save_model, "step_final", "transformer"))
+    transformer.fit(train_loader=train_loader,
+                    eval_loader=val_loader,
+                    epochs=1,
+                    eval_freq=1,
+                    save_freq=1,
+                    verbose=2,
+                    callbacks=[
+                        LoggerCallback(
+                            log_freq=args.print_step,
+                            loss_normalizer=loss_normalizer)
+                    ])
 
 
 if __name__ == "__main__":
-- 
GitLab
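Not part of the patch: the "best cross-entropy value with label smoothing" kept
as loss_normalizer in the context lines above is the entropy of the smoothed
target distribution, i.e. the lowest loss the model can reach. A minimal
standalone sketch of that computation follows; the helper name and the example
values for the smoothing epsilon and target vocabulary size are purely
illustrative, not taken from the repository.

    import numpy as np

    def label_smoothing_loss_normalizer(label_smooth_eps, trg_vocab_size):
        # Smoothed target: (1 - eps) on the gold token and eps / (V - 1) on
        # every other token; its entropy lower-bounds the training loss.
        eps = label_smooth_eps
        return -((1. - eps) * np.log(1. - eps) +
                 eps * np.log(eps / (trg_vocab_size - 1) + 1e-20))

    # e.g. eps = 0.1 with a 30k-token target vocabulary gives ~1.36 nats
    print(label_smoothing_loss_normalizer(0.1, 30000))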