diff --git a/transformer/reader.py b/transformer/reader.py index f182ba2a9ce0bbff01d85efff61052ac78ca09c9..b31d5df0e086d1b21a65518e46c1cf853a44e374 100644 --- a/transformer/reader.py +++ b/transformer/reader.py @@ -16,18 +16,68 @@ import glob import six import os import tarfile +import itertools import numpy as np import paddle.fluid as fluid -from paddle.fluid.io import BatchSampler, DataLoader +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.fluid.io import BatchSampler, DataLoader, Dataset -class TokenBatchSampler(BatchSampler): - def __init__(self): - pass +def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): + """ + Put all padded data needed by training into a list. + """ + src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) + src_word = src_word.reshape(-1, src_max_len) + src_pos = src_pos.reshape(-1, src_max_len) + trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( + [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) + trg_word = trg_word.reshape(-1, trg_max_len) + trg_pos = trg_pos.reshape(-1, trg_max_len) - def __iter(self): - pass + trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], + [1, 1, trg_max_len, 1]).astype("float32") + + lbl_word, lbl_weight, num_token = pad_batch_data( + [inst[2] for inst in insts], + trg_pad_idx, + n_head, + is_target=False, + is_label=True, + return_attn_bias=False, + return_max_len=False, + return_num_token=True) + lbl_word = lbl_word.reshape(-1, 1) + lbl_weight = lbl_weight.reshape(-1, 1) + + data_inputs = [ + src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, + trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight + ] + + return data_inputs + + +def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head): + """ + Put all padded data needed by beam search decoder into a list. + """ + src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) + # start tokens + trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64") + trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], + [1, 1, 1, 1]).astype("float32") + trg_word = trg_word.reshape(-1, 1) + src_word = src_word.reshape(-1, src_max_len) + src_pos = src_pos.reshape(-1, src_max_len) + + data_inputs = [ + src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias + ] + return data_inputs def pad_batch_data(insts, @@ -88,60 +138,206 @@ def pad_batch_data(insts, return return_list if len(return_list) > 1 else return_list[0] -def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): - """ - Put all padded data needed by training into a list. 
-    """
-    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
-    src_word = src_word.reshape(-1, src_max_len)
-    src_pos = src_pos.reshape(-1, src_max_len)
-    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
-        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
-    trg_word = trg_word.reshape(-1, trg_max_len)
-    trg_pos = trg_pos.reshape(-1, trg_max_len)
+class Seq2SeqDataset(Dataset):
+    def __init__(self,
+                 src_vocab_fpath,
+                 trg_vocab_fpath,
+                 fpattern,
+                 tar_fname=None,
+                 field_delimiter="\t",
+                 token_delimiter=" ",
+                 start_mark="<s>",
+                 end_mark="<e>",
+                 unk_mark="<unk>",
+                 only_src=False):
+        # convert str to bytes, and use byte data
+        field_delimiter = field_delimiter.encode("utf8")
+        token_delimiter = token_delimiter.encode("utf8")
+        start_mark = start_mark.encode("utf8")
+        end_mark = end_mark.encode("utf8")
+        unk_mark = unk_mark.encode("utf8")
+        self._src_vocab = self.load_dict(src_vocab_fpath)
+        self._trg_vocab = self.load_dict(trg_vocab_fpath)
+        self._bos_idx = self._src_vocab[start_mark]
+        self._eos_idx = self._src_vocab[end_mark]
+        self._unk_idx = self._src_vocab[unk_mark]
+        self._only_src = only_src
+        self._field_delimiter = field_delimiter
+        self._token_delimiter = token_delimiter
+        self.load_src_trg_ids(fpattern, tar_fname)
 
-    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
-                                [1, 1, trg_max_len, 1]).astype("float32")
+    def load_src_trg_ids(self, fpattern, tar_fname):
+        converters = [
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._bos_idx,
+                end=self._eos_idx,
+                unk=self._unk_idx,
+                delimiter=self._token_delimiter,
+                add_beg=False)
+        ]
+        if not self._only_src:
+            converters.append(
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._bos_idx,
+                    end=self._eos_idx,
+                    unk=self._unk_idx,
+                    delimiter=self._token_delimiter,
+                    add_beg=True))
 
-    lbl_word, lbl_weight, num_token = pad_batch_data(
-        [inst[2] for inst in insts],
-        trg_pad_idx,
-        n_head,
-        is_target=False,
-        is_label=True,
-        return_attn_bias=False,
-        return_max_len=False,
-        return_num_token=True)
-    lbl_word = lbl_word.reshape(-1, 1)
-    lbl_weight = lbl_weight.reshape(-1, 1)
+        converters = ComposedConverter(converters)
 
-    data_inputs = [
-        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
-        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
-    ]
+        self._src_seq_ids = []
+        self._trg_seq_ids = None if self._only_src else []
+        self._sample_infos = []
 
-    return data_inputs
+        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
+            src_trg_ids = converters(line)
+            self._src_seq_ids.append(src_trg_ids[0])
+            lens = [len(src_trg_ids[0])]
+            if not self._only_src:
+                self._trg_seq_ids.append(src_trg_ids[1])
+                lens.append(len(src_trg_ids[1]))
+            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
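Reviewer note on `load_src_trg_ids` above: `Converter`, `ComposedConverter` and `SampleInfo` are the helpers already defined further down in reader.py; the method only wires them together, storing per-sample id sequences plus their lengths for the batch sampler. A rough standalone sketch of what that conversion is assumed to do (the `convert_line` helper and the toy vocab below are illustrative, not part of the patch):

```python
# Illustrative sketch only: mirrors what Converter/SampleInfo are assumed to do.
import collections

SampleInfo = collections.namedtuple("SampleInfo", ["i", "max_len", "min_len"])


def convert_line(tokens, vocab, bos_idx, eos_idx, unk_idx,
                 delimiter=b" ", add_beg=False):
    # map a byte string of tokens to ids, appending <e> and optionally
    # prepending <s>; unknown tokens fall back to <unk>
    ids = [vocab.get(tok, unk_idx) for tok in tokens.split(delimiter)]
    return ([bos_idx] if add_beg else []) + ids + [eos_idx]


vocab = {b"<s>": 0, b"<e>": 1, b"<unk>": 2, b"hello": 3, b"world": 4}
src_ids = convert_line(b"hello world", vocab, 0, 1, 2)
trg_ids = convert_line(b"bonjour monde", vocab, 0, 1, 2, add_beg=True)
info = SampleInfo(0, max(len(src_ids), len(trg_ids)),
                  min(len(src_ids), len(trg_ids)))
# src_ids == [3, 4, 1], trg_ids == [0, 2, 2, 1], info == SampleInfo(0, 4, 3)
```

Everything stays as bytes because `_load_lines` below yields the raw byte fields split on the field delimiter.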
+
+    def _load_lines(self, fpattern, tar_fname):
+        fpaths = glob.glob(fpattern)
+        assert len(fpaths) > 0, "no matching file to the provided data path"
 
-def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
-    """
-    Put all padded data needed by beam search decoder into a list.
-    """
-    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
-    # start tokens
-    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
-    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
-                                [1, 1, 1, 1]).astype("float32")
-    trg_word = trg_word.reshape(-1, 1)
-    src_word = src_word.reshape(-1, src_max_len)
-    src_pos = src_pos.reshape(-1, src_max_len)
+        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
+            if tar_fname is None:
+                raise Exception("If a tar file is provided, set tar_fname.")
 
-    data_inputs = [
-        src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
-    ]
-    return data_inputs
+            f = tarfile.open(fpaths[0], "r")
+            for line in f.extractfile(tar_fname):
+                fields = line.strip(b"\n").split(self._field_delimiter)
+                if (not self._only_src and len(fields) == 2) or (
+                        self._only_src and len(fields) == 1):
+                    yield fields
+        else:
+            for fpath in fpaths:
+                if not os.path.isfile(fpath):
+                    raise IOError("Invalid file: %s" % fpath)
+
+                with open(fpath, "rb") as f:
+                    for line in f:
+                        fields = line.strip(b"\n").split(self._field_delimiter)
+                        if (not self._only_src and len(fields) == 2) or (
+                                self._only_src and len(fields) == 1):
+                            yield fields
+
+    @staticmethod
+    def load_dict(dict_path, reverse=False):
+        word_dict = {}
+        with open(dict_path, "rb") as fdict:
+            for idx, line in enumerate(fdict):
+                if reverse:
+                    word_dict[idx] = line.strip(b"\n")
+                else:
+                    word_dict[line.strip(b"\n")] = idx
+        return word_dict
+
+    def get_vocab_summary(self):
+        return len(self._src_vocab), len(
+            self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
+
+    def __getitem__(self, idx):
+        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+                ) if not self._only_src else self._src_seq_ids[idx]
+
+    def __len__(self):
+        return len(self._sample_infos)
+
+
+class Seq2SeqBatchSampler(BatchSampler):
+    def __init__(self,
+                 dataset,
+                 batch_size,
+                 pool_size,
+                 sort_type=SortType.GLOBAL,
+                 min_length=0,
+                 max_length=100,
+                 shuffle=True,
+                 shuffle_batch=False,
+                 use_token_batch=False,
+                 clip_last_batch=False,
+                 seed=0):
+        for arg, value in locals().items():
+            if arg != "self":
+                setattr(self, "_" + arg, value)
+        self._random = np.random
+        self._random.seed(seed)
+        # for multi-device
+        self._nranks = ParallelEnv().nranks
+        self._local_rank = ParallelEnv().local_rank
+        self._device_id = ParallelEnv().dev_id
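Reviewer note on the sampler: `__iter__` below builds batches either by sentence count (`SentenceBatchCreator`) or by a token budget (`TokenBatchCreator`), after an optional pool-level sort so that sequences of similar length get padded together. A minimal sketch of the token-budget idea, assuming each element carries a `max_len` like `SampleInfo` (this is not the patch's `TokenBatchCreator`, whose exact accounting may differ):

```python
# Minimal sketch of token-budget batching; assumes each info has a max_len
# attribute as in SampleInfo. Not the patch's actual TokenBatchCreator.
def token_batches(sample_infos, max_tokens):
    batch, batch_max = [], 0
    for info in sample_infos:
        new_max = max(batch_max, info.max_len)
        # cost of the batch if every sample is padded to the longest one
        if batch and new_max * (len(batch) + 1) > max_tokens:
            yield batch
            batch, new_max = [], info.max_len
        batch.append(info)
        batch_max = new_max
    if batch:
        yield batch
```

Sorting each pool of `pool_size` samples by length before batching keeps the padding overhead of such batches small, which is what `SortType.POOL` is for.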
+
+    def __iter__(self):
+        # global sort or global shuffle
+        if self._sort_type == SortType.GLOBAL:
+            infos = sorted(
+                self._dataset._sample_infos, key=lambda x: x.max_len)
+        else:
+            if self._shuffle:
+                infos = self._dataset._sample_infos
+                self._random.shuffle(infos)
+            else:
+                infos = self._dataset._sample_infos
+
+            if self._sort_type == SortType.POOL:
+                reverse = True
+                for i in range(0, len(infos), self._pool_size):
+                    # to avoid placing short next to long sentences
+                    reverse = not reverse
+                    infos[i:i + self._pool_size] = sorted(
+                        infos[i:i + self._pool_size],
+                        key=lambda x: x.max_len,
+                        reverse=reverse)
+
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(
+            self._batch_size * self._nranks)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
+
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append(batch)
+
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append(batch_creator.batch)
+
+        if self._shuffle_batch:
+            self._random.shuffle(batches)
+
+        if not self._use_token_batch:
+            # when batching by sentence count, neighboring batches are fed to
+            # different devices and run in parallel, so they should keep
+            # similar lengths (thus similar computational cost) after shuffle;
+            # we therefore shuffle them as a whole and split them here
+            batches = [[
+                batch[self._batch_size * i:self._batch_size * (i + 1)]
+                for i in range(self._nranks)
+            ] for batch in batches]
+            batches = list(itertools.chain.from_iterable(batches))
+
+        # for multi-device
+        for batch_id, batch in enumerate(batches):
+            if batch_id % self._nranks == self._local_rank:
+                batch_indices = [info.i for info in batch]
+                yield batch_indices
+        if self._local_rank > len(batches) % self._nranks:
+            yield batch_indices
+
+    def __len__(self):
+        pass
+
+    @property
+    def dev_id(self):
+        return self._device_id
 
 
 class SortType(object):
diff --git a/transformer/train.py b/transformer/train.py
index e27a013ad08aa5c3ae28f854040695c1181ef13d..794edcfde85ab62bca2ed211173ba8aa17181dc4 100644
--- a/transformer/train.py
+++ b/transformer/train.py
@@ -24,6 +24,7 @@ import numpy as np
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import to_variable
+from paddle.fluid.io import DataLoader
 
 from utils.configure import PDConfig
 from utils.check import check_gpu, check_version
@@ -38,6 +39,7 @@ from callbacks import ProgBarLogger
 class LoggerCallback(ProgBarLogger):
     def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
         super(LoggerCallback, self).__init__(log_freq, verbose)
+        # TODO: wrap these overridden methods to simplify them
         self.loss_normalizer = loss_normalizer
 
     def on_train_begin(self, logs=None):
@@ -60,148 +62,111 @@ class LoggerCallback(ProgBarLogger):
 
 
 def do_train(args):
-    init_context('dynamic' if FLAGS.dynamic else 'static')
-    trainer_count = 1  #get_nranks()
-
-    @contextlib.contextmanager
-    def null_guard():
-        yield
-
-    guard = fluid.dygraph.guard() if args.eager_run else null_guard()
-
-    # define the data generator
-    processor = reader.DataProcessor(
+    # init_context('dynamic' if FLAGS.dynamic else 'static')
+
+    # set seed for CE
+    random_seed = eval(str(args.random_seed))
+    if random_seed is not None:
+        fluid.default_main_program().random_seed = random_seed
+        fluid.default_startup_program().random_seed = random_seed
+
+    # define model
+    inputs = [
+        Input(
+            [None, None], "int64", name="src_word"),
+        Input(
+            [None, None], "int64", name="src_pos"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="src_slf_attn_bias"),
+        Input(
+            [None, None], "int64", name="trg_word"),
+        Input(
+            [None, None], "int64", name="trg_pos"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="trg_slf_attn_bias"),
+        Input(
+            [None, args.n_head, None, None],
+            "float32",
+            name="trg_src_attn_bias")
+    ]
+    labels = [
+        Input(
+            [None, 1], "int64", name="label"),
+        Input(
+            [None, 1], "float32", name="weight"),
+    ]
+
+    dataset = reader.Seq2SeqDataset(
         fpattern=args.training_file,
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
         token_delimiter=args.token_delimiter,
+        start_mark=args.special_token[0],
end_mark=args.special_token[1], + unk_mark=args.special_token[2]) + args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ + args.unk_idx = dataset.get_vocab_summary() + batch_sampler = reader.Seq2SeqBatchSampler( + dataset=dataset, use_token_batch=args.use_token_batch, batch_size=args.batch_size, - device_count=trainer_count, pool_size=args.pool_size, sort_type=args.sort_type, shuffle=args.shuffle, shuffle_batch=args.shuffle_batch, - start_mark=args.special_token[0], - end_mark=args.special_token[1], - unk_mark=args.special_token[2], - max_length=args.max_length, - n_head=args.n_head) - batch_generator = processor.data_generator(phase="train") - if trainer_count > 1: # for multi-process gpu training - batch_generator = fluid.contrib.reader.distributed_batch_reader( - batch_generator) - if args.validation_file: - val_processor = reader.DataProcessor( - fpattern=args.validation_file, - src_vocab_fpath=args.src_vocab_fpath, - trg_vocab_fpath=args.trg_vocab_fpath, - token_delimiter=args.token_delimiter, - use_token_batch=args.use_token_batch, - batch_size=args.batch_size, - device_count=trainer_count, - pool_size=args.pool_size, - sort_type=args.sort_type, - shuffle=False, - shuffle_batch=False, - start_mark=args.special_token[0], - end_mark=args.special_token[1], - unk_mark=args.special_token[2], - max_length=args.max_length, - n_head=args.n_head) - val_batch_generator = val_processor.data_generator(phase="train") - args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ - args.unk_idx = processor.get_vocab_summary() - - with guard: - # set seed for CE - random_seed = eval(str(args.random_seed)) - if random_seed is not None: - fluid.default_main_program().random_seed = random_seed - fluid.default_startup_program().random_seed = random_seed - - # define data loader - train_loader = batch_generator - if args.validation_file: - val_loader = val_batch_generator - - # define model - inputs = [ - Input( - [None, None], "int64", name="src_word"), - Input( - [None, None], "int64", name="src_pos"), - Input( - [None, args.n_head, None, None], - "float32", - name="src_slf_attn_bias"), - Input( - [None, None], "int64", name="trg_word"), - Input( - [None, None], "int64", name="trg_pos"), - Input( - [None, args.n_head, None, None], - "float32", - name="trg_slf_attn_bias"), - Input( - [None, args.n_head, None, None], - "float32", - name="trg_src_attn_bias"), - ] - labels = [ - Input( - [None, 1], "int64", name="label"), - Input( - [None, 1], "float32", name="weight"), - ] - - transformer = Transformer( - args.src_vocab_size, args.trg_vocab_size, args.max_length + 1, - args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model, - args.d_inner_hid, args.prepostprocess_dropout, - args.attention_dropout, args.relu_dropout, args.preprocess_cmd, - args.postprocess_cmd, args.weight_sharing, args.bos_idx, - args.eos_idx) - - transformer.prepare( - fluid.optimizer.Adam( - learning_rate=fluid.layers.noam_decay( - args.d_model, args.warmup_steps), # args.learning_rate), - beta1=args.beta1, - beta2=args.beta2, - epsilon=float(args.eps), - parameter_list=transformer.parameters()), - CrossEntropyCriterion(args.label_smooth_eps), - inputs=inputs, - labels=labels) - - ## init from some checkpoint, to resume the previous training - if args.init_from_checkpoint: - transformer.load( - os.path.join(args.init_from_checkpoint, "transformer")) - ## init from some pretrain models, to better solve the current task - if args.init_from_pretrain_model: - transformer.load( - 
os.path.join(args.init_from_pretrain_model, "transformer"), - reset_optimizer=True) - - # the best cross-entropy value with label smoothing - loss_normalizer = -( - (1. - args.label_smooth_eps) * np.log( - (1. - args.label_smooth_eps)) + args.label_smooth_eps * - np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20)) - - transformer.fit(train_loader=train_loader, - eval_loader=val_loader, - epochs=1, - eval_freq=1, - save_freq=1, - verbose=2, - callbacks=[ - LoggerCallback( - log_freq=args.print_step, - loss_normalizer=loss_normalizer) - ]) + max_length=args.max_length) + train_loader = DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + places=None, + feed_list=[x.forward() for x in inputs + labels], + num_workers=0, + return_list=True) + + transformer = Transformer( + args.src_vocab_size, args.trg_vocab_size, args.max_length + 1, + args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model, + args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout, + args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd, + args.weight_sharing, args.bos_idx, args.eos_idx) + + transformer.prepare( + fluid.optimizer.Adam( + learning_rate=fluid.layers.noam_decay( + args.d_model, args.warmup_steps), # args.learning_rate), + beta1=args.beta1, + beta2=args.beta2, + epsilon=float(args.eps), + parameter_list=transformer.parameters()), + CrossEntropyCriterion(args.label_smooth_eps), + inputs=inputs, + labels=labels) + + ## init from some checkpoint, to resume the previous training + if args.init_from_checkpoint: + transformer.load( + os.path.join(args.init_from_checkpoint, "transformer")) + ## init from some pretrain models, to better solve the current task + if args.init_from_pretrain_model: + transformer.load( + os.path.join(args.init_from_pretrain_model, "transformer"), + reset_optimizer=True) + + # the best cross-entropy value with label smoothing + loss_normalizer = -( + (1. - args.label_smooth_eps) * np.log( + (1. - args.label_smooth_eps)) + args.label_smooth_eps * + np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20)) + + transformer.fit(train_loader=train_loader, + eval_loader=None, + epochs=1, + eval_freq=1, + save_freq=1, + verbose=2, + callbacks=[ + LoggerCallback( + log_freq=args.print_step, + loss_normalizer=loss_normalizer) + ]) if __name__ == "__main__":
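Reviewer note on the new `do_train`: `loss_normalizer` is the entropy of the label-smoothed target distribution, i.e. the lowest cross-entropy the model could possibly reach, so subtracting it makes the logged loss easier to interpret. A quick sanity check of that formula with example values (editorial, not part of the patch):

```python
import numpy as np


def smoothed_ce_floor(eps, trg_vocab_size):
    # (1 - eps) probability stays on the gold token, eps is spread uniformly
    # over the remaining trg_vocab_size - 1 tokens
    off = eps / (trg_vocab_size - 1)
    return -((1. - eps) * np.log(1. - eps) + eps * np.log(off + 1e-20))


print(smoothed_ce_floor(0.1, 10000))  # lowest reachable smoothed cross-entropy
```

The `0.1` and `10000` above are placeholders for `args.label_smooth_eps` and `args.trg_vocab_size`; plugging in the real values reproduces `loss_normalizer` as computed in the patch.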