diff --git a/transformer/reader.py b/transformer/reader.py
index b83617d7179b65d9e2bc335eec1f226423a08fdb..347ef4a2c1044fb4dae4f2921718c252871328eb 100644
--- a/transformer/reader.py
+++ b/transformer/reader.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,8 +15,9 @@
 import glob
 import six
 import os
-import tarfile
+import io
 import itertools
+from functools import partial
 
 import numpy as np
 import paddle.fluid as fluid
@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import BatchSampler, DataLoader, Dataset
 
 
-def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
+def create_data_loader(args, device):
+    data_loaders = [None, None]
+    data_files = [args.training_file, args.validation_file
+                  ] if args.validation_file else [args.training_file]
+    for i, data_file in enumerate(data_files):
+        dataset = Seq2SeqDataset(
+            fpattern=data_file,
+            src_vocab_fpath=args.src_vocab_fpath,
+            trg_vocab_fpath=args.trg_vocab_fpath,
+            token_delimiter=args.token_delimiter,
+            start_mark=args.special_token[0],
+            end_mark=args.special_token[1],
+            unk_mark=args.special_token[2],
+            byte_data=True)
+        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+            args.unk_idx = dataset.get_vocab_summary()
+        batch_sampler = Seq2SeqBatchSampler(
+            dataset=dataset,
+            use_token_batch=args.use_token_batch,
+            batch_size=args.batch_size,
+            pool_size=args.pool_size,
+            sort_type=args.sort_type,
+            shuffle=args.shuffle,
+            shuffle_batch=args.shuffle_batch,
+            max_length=args.max_length,
+            distribute_mode=True
+            if i == 0 else False)  # every device evaluates all eval data
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            places=device,
+            collate_fn=partial(
+                prepare_train_input,
+                bos_idx=args.bos_idx,
+                eos_idx=args.eos_idx,
+                src_pad_idx=args.eos_idx,
+                trg_pad_idx=args.eos_idx,
+                n_head=args.n_head),
+            num_workers=0,  # TODO: use multi-process
+            return_list=True)
+        data_loaders[i] = data_loader
+    return data_loaders
+
+
+def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
+                        n_head):
     """
     Put all padded data needed by training into a list.
""" src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( - [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) + [inst[0] + [eos_idx] for inst in insts], + src_pad_idx, + n_head, + is_target=False) src_word = src_word.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len) trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( - [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) + [[bos_idx] + inst[1] for inst in insts], + trg_pad_idx, + n_head, + is_target=True) trg_word = trg_word.reshape(-1, trg_max_len) trg_pos = trg_pos.reshape(-1, trg_max_len) @@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): [1, 1, trg_max_len, 1]).astype("float32") lbl_word, lbl_weight, num_token = pad_batch_data( - [inst[2] for inst in insts], + [inst[1] + [eos_idx] for inst in insts], trg_pad_idx, n_head, is_target=False, @@ -71,9 +123,7 @@ def prepare_infer_input(insts, src_pad_idx, n_head): src_word = src_word.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len) - data_inputs = [ - src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias - ] + data_inputs = [src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias] return data_inputs @@ -142,29 +192,30 @@ class SortType(object): class Converter(object): - def __init__(self, vocab, beg, end, unk, delimiter, add_beg): + def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end): self._vocab = vocab self._beg = beg self._end = end self._unk = unk self._delimiter = delimiter self._add_beg = add_beg + self._add_end = add_end def __call__(self, sentence): return ([self._beg] if self._add_beg else []) + [ self._vocab.get(w, self._unk) for w in sentence.split(self._delimiter) - ] + [self._end] + ] + ([self._end] if self._add_end else []) class ComposedConverter(object): def __init__(self, converters): self._converters = converters - def __call__(self, parallel_sentence): + def __call__(self, fields): return [ - self._converters[i](parallel_sentence[i]) - for i in range(len(self._converters)) + converter(field) + for field, converter in zip(fields, self._converters) ] @@ -201,10 +252,11 @@ class TokenBatchCreator(object): class SampleInfo(object): - def __init__(self, i, max_len, min_len): + def __init__(self, i, lens): self.i = i - self.min_len = min_len - self.max_len = max_len + # take bos and eos into account + self.min_len = min(lens[0] + 1, lens[1] + 2) + self.max_len = max(lens[0] + 1, lens[1] + 2) class MinMaxFilter(object): @@ -229,98 +281,109 @@ class Seq2SeqDataset(Dataset): src_vocab_fpath, trg_vocab_fpath, fpattern, - tar_fname=None, field_delimiter="\t", token_delimiter=" ", start_mark="", end_mark="", unk_mark="", - only_src=False): - # convert str to bytes, and use byte data - field_delimiter = field_delimiter.encode("utf8") - token_delimiter = token_delimiter.encode("utf8") - start_mark = start_mark.encode("utf8") - end_mark = end_mark.encode("utf8") - unk_mark = unk_mark.encode("utf8") - self._src_vocab = self.load_dict(src_vocab_fpath) - self._trg_vocab = self.load_dict(trg_vocab_fpath) + only_src=False, + trg_fpattern=None, + byte_data=False): + if byte_data: + # The WMT16 bpe data used here seems including bytes can not be + # decoded by utf8. 
+            # be decoded by utf8. Thus convert str to bytes and use byte data.
+            field_delimiter = field_delimiter.encode("utf8")
+            token_delimiter = token_delimiter.encode("utf8")
+            start_mark = start_mark.encode("utf8")
+            end_mark = end_mark.encode("utf8")
+            unk_mark = unk_mark.encode("utf8")
+        self._byte_data = byte_data
+        self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
+        self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
         self._bos_idx = self._src_vocab[start_mark]
         self._eos_idx = self._src_vocab[end_mark]
         self._unk_idx = self._src_vocab[unk_mark]
-        self._only_src = only_src
         self._field_delimiter = field_delimiter
         self._token_delimiter = token_delimiter
-        self.load_src_trg_ids(fpattern, tar_fname)
-
-    def load_src_trg_ids(self, fpattern, tar_fname):
-        converters = [
-            Converter(vocab=self._src_vocab,
-                      beg=self._bos_idx,
-                      end=self._eos_idx,
-                      unk=self._unk_idx,
-                      delimiter=self._token_delimiter,
-                      add_beg=False)
-        ]
-        if not self._only_src:
-            converters.append(
-                Converter(vocab=self._trg_vocab,
-                          beg=self._bos_idx,
-                          end=self._eos_idx,
-                          unk=self._unk_idx,
-                          delimiter=self._token_delimiter,
-                          add_beg=True))
-
-        converters = ComposedConverter(converters)
+        self.load_src_trg_ids(fpattern, trg_fpattern)
+
+    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
+        src_converter = Converter(
+            vocab=self._src_vocab,
+            beg=self._bos_idx,
+            end=self._eos_idx,
+            unk=self._unk_idx,
+            delimiter=self._token_delimiter,
+            add_beg=False,
+            add_end=False)
+
+        trg_converter = Converter(
+            vocab=self._trg_vocab,
+            beg=self._bos_idx,
+            end=self._eos_idx,
+            unk=self._unk_idx,
+            delimiter=self._token_delimiter,
+            add_beg=False,
+            add_end=False)
+
+        converters = ComposedConverter([src_converter, trg_converter])
 
         self._src_seq_ids = []
-        self._trg_seq_ids = None if self._only_src else []
+        self._trg_seq_ids = []
         self._sample_infos = []
-        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
-            src_trg_ids = converters(line)
-            self._src_seq_ids.append(src_trg_ids[0])
-            lens = [len(src_trg_ids[0])]
-            if not self._only_src:
-                self._trg_seq_ids.append(src_trg_ids[1])
-                lens.append(len(src_trg_ids[1]))
-            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+        slots = [self._src_seq_ids, self._trg_seq_ids]
+        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
+            lens = []
+            for field, slot in zip(converters(line), slots):
+                slot.append(field)
+                lens.append(len(field))
+            self._sample_infos.append(SampleInfo(i, lens))
 
-    def _load_lines(self, fpattern, tar_fname):
+    def _load_lines(self, fpattern, trg_fpattern=None):
         fpaths = glob.glob(fpattern)
+        fpaths = sorted(fpaths)  # TODO: Add custom sort
         assert len(fpaths) > 0, "no matching file to the provided data path"
 
-        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
-            if tar_fname is None:
-                raise Exception("If tar file provided, please set tar_fname.")
-
-            f = tarfile.open(fpaths[0], "rb")
-            for line in f.extractfile(tar_fname):
-                fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src
-                        and len(fields) == 2) or (self._only_src
-                                                  and len(fields) == 1):
-                    yield fields
-        else:
+        (f_mode, f_encoding,
+         endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
+                                                              "\n")
+        if trg_fpattern is None:
             for fpath in fpaths:
-                if not os.path.isfile(fpath):
-                    raise IOError("Invalid file: %s" % fpath)
-
-                with open(fpath, "rb") as f:
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                     for line in f:
-                        fields = line.strip(b"\n").split(self._field_delimiter)
-                        if (not self._only_src and len(fields) == 2) or (
-                                self._only_src and len(fields) == 1):
+                        fields = line.strip(endl).split(self._field_delimiter)
+                        yield fields
+        else:
+            # separate source and target language data files
+            # assume aligned data can be obtained by sorting the two file sets
+            # TODO: Need a more rigorous check
+            trg_fpaths = glob.glob(trg_fpattern)
+            trg_fpaths = sorted(trg_fpaths)
+            assert len(fpaths) == len(
+                trg_fpaths
+            ), "the number of source language data files must equal \
+that of target language data files"
+
+            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
+                    with io.open(
+                            trg_fpath, f_mode, encoding=f_encoding) as trg_f:
+                        for line in zip(f, trg_f):
+                            fields = [field.strip(endl) for field in line]
                             yield fields
 
     @staticmethod
-    def load_dict(dict_path, reverse=False):
+    def load_dict(dict_path, reverse=False, byte_data=False):
         word_dict = {}
-        with open(dict_path, "rb") as fdict:
+        (f_mode, f_encoding,
+         endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
+        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
             for idx, line in enumerate(fdict):
                 if reverse:
-                    word_dict[idx] = line.strip(b"\n")
+                    word_dict[idx] = line.strip(endl)
                 else:
-                    word_dict[line.strip(b"\n")] = idx
+                    word_dict[line.strip(endl)] = idx
         return word_dict
 
     def get_vocab_summary(self):
@@ -328,9 +391,8 @@ class Seq2SeqDataset(Dataset):
             self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
 
     def __getitem__(self, idx):
-        return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
-                self._trg_seq_ids[idx][1:]
-                ) if not self._only_src else self._src_seq_ids[idx]
+        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+                ) if self._trg_seq_ids else self._src_seq_ids[idx]
 
     def __len__(self):
         return len(self._sample_infos)
@@ -348,6 +410,7 @@ class Seq2SeqBatchSampler(BatchSampler):
                  shuffle_batch=False,
                  use_token_batch=False,
                  clip_last_batch=False,
+                 distribute_mode=True,
                  seed=0):
         for arg, value in locals().items():
             if arg != "self":
@@ -355,6 +418,7 @@ class Seq2SeqBatchSampler(BatchSampler):
         self._random = np.random
         self._random.seed(seed)
         # for multi-devices
+        self._distribute_mode = distribute_mode
         self._nranks = ParallelEnv().nranks
         self._local_rank = ParallelEnv().local_rank
         self._device_id = ParallelEnv().dev_id
@@ -362,8 +426,8 @@
     def __iter__(self):
         # global sort or global shuffle
         if self._sort_type == SortType.GLOBAL:
-            infos = sorted(self._dataset._sample_infos,
-                           key=lambda x: x.max_len)
+            infos = sorted(
+                self._dataset._sample_infos, key=lambda x: x.max_len)
         else:
             if self._shuffle:
                 infos = self._dataset._sample_infos
@@ -383,9 +447,9 @@
 
         batches = []
         batch_creator = TokenBatchCreator(
-            self._batch_size
-        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
-                                                             self._nranks)
+            self.
+            _batch_size) if self._use_token_batch else SentenceBatchCreator(
+                self._batch_size * self._nranks)
 
         batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                      batch_creator)
@@ -413,11 +477,15 @@
         # for multi-device
         for batch_id, batch in enumerate(batches):
-            if batch_id % self._nranks == self._local_rank:
+            if not self._distribute_mode or (
+                    batch_id % self._nranks == self._local_rank):
                 batch_indices = [info.i for info in batch]
                 yield batch_indices
 
-        if self._local_rank > len(batches) % self._nranks:
-            yield batch_indices
+        if self._distribute_mode and len(batches) % self._nranks != 0:
+            if self._local_rank >= len(batches) % self._nranks:
+                # use previous data to pad
+                yield batch_indices
 
     def __len__(self):
-        return 100
+        # TODO(guosheng): fix the uncertain length
+        return 0
diff --git a/transformer/train.py b/transformer/train.py
index b2ec12ee22b1d1950ebc80ecbfdfe21fc5b26959..fe4deeabab2004c976932a605fd1e7022eff12a2 100644
--- a/transformer/train.py
+++ b/transformer/train.py
@@ -17,7 +17,6 @@ import os
 import six
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from functools import partial
 
 import numpy as np
 import paddle
@@ -29,14 +28,18 @@ from utils.check import check_gpu, check_version
 
 from model import Input, set_device
 from callbacks import ProgBarLogger
-from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
+from reader import create_data_loader
 from transformer import Transformer, CrossEntropyCriterion
 
 
 class TrainCallback(ProgBarLogger):
-    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
-        super(TrainCallback, self).__init__(log_freq, verbose)
-        # TODO: wrap these override function to simplify
+    def __init__(self, args, verbose=2):
+        super(TrainCallback, self).__init__(args.print_step, verbose)
+        # the best cross-entropy value with label smoothing
+        loss_normalizer = -(
+            (1. - args.label_smooth_eps) * np.log(
+                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
+            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
         self.loss_normalizer = loss_normalizer
 
     def on_train_begin(self, logs=None):
@@ -100,42 +103,7 @@
     ]
 
     # def dataloader
-    data_loaders = [None, None]
-    data_files = [args.training_file, args.validation_file
-                  ] if args.validation_file else [args.training_file]
-    for i, data_file in enumerate(data_files):
-        dataset = Seq2SeqDataset(
-            fpattern=data_file,
-            src_vocab_fpath=args.src_vocab_fpath,
-            trg_vocab_fpath=args.trg_vocab_fpath,
-            token_delimiter=args.token_delimiter,
-            start_mark=args.special_token[0],
-            end_mark=args.special_token[1],
-            unk_mark=args.special_token[2])
-        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-            args.unk_idx = dataset.get_vocab_summary()
-        batch_sampler = Seq2SeqBatchSampler(
-            dataset=dataset,
-            use_token_batch=args.use_token_batch,
-            batch_size=args.batch_size,
-            pool_size=args.pool_size,
-            sort_type=args.sort_type,
-            shuffle=args.shuffle,
-            shuffle_batch=args.shuffle_batch,
-            max_length=args.max_length)
-        data_loader = DataLoader(
-            dataset=dataset,
-            batch_sampler=batch_sampler,
-            places=device,
-            collate_fn=partial(
-                prepare_train_input,
-                src_pad_idx=args.eos_idx,
-                trg_pad_idx=args.eos_idx,
-                n_head=args.n_head),
-            num_workers=0,  # TODO: use multi-process
-            return_list=True)
-        data_loaders[i] = data_loader
-    train_loader, eval_loader = data_loaders
+    train_loader, eval_loader = create_data_loader(args, device)
 
     # define model
     transformer = Transformer(
@@ -166,12 +134,6 @@
     if args.init_from_pretrain_model:
         transformer.load(args.init_from_pretrain_model, reset_optimizer=True)
 
-    # the best cross-entropy value with label smoothing
-    loss_normalizer = -(
-        (1. - args.label_smooth_eps) * np.log(
-            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
-        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
-
     # model train
     transformer.fit(train_data=train_loader,
                     eval_data=eval_loader,
@@ -180,11 +142,7 @@
                     save_freq=1,
                     save_dir=args.save_model,
                     verbose=2,
-                    callbacks=[
-                        TrainCallback(
-                            log_freq=args.print_step,
-                            loss_normalizer=loss_normalizer)
-                    ])
+                    callbacks=[TrainCallback(args)])
 
 
 if __name__ == "__main__":
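A brief note on the loss_normalizer that this patch moves into TrainCallback.__init__: it is the lowest cross-entropy attainable under label smoothing, i.e. the entropy of the smoothed target distribution (probability 1 - eps on the gold token, with eps spread over the remaining trg_vocab_size - 1 tokens), and the callback stores it as a reference value for the reported loss (per the "# the best cross-entropy value with label smoothing" comment). A minimal standalone sketch of the same computation follows; the eps and vocab_size values are made-up examples, not taken from any config here.

import numpy as np

def best_cross_entropy(eps, vocab_size):
    # Entropy of the smoothed label distribution: 1 - eps on the gold token,
    # eps spread uniformly over the other (vocab_size - 1) tokens; the 1e-20
    # term guards against log(0) when eps is 0.
    return -((1. - eps) * np.log(1. - eps) +
             eps * np.log(eps / (vocab_size - 1) + 1e-20))

# Example (hypothetical values): eps = 0.1, vocab_size = 10000
print(best_cross_entropy(0.1, 10000))  # ~1.25 for these example values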