diff --git a/PaddleNLP/paddlenlp/data/sampler.py b/PaddleNLP/paddlenlp/data/sampler.py
index f811608553cc4f0e8aaa8e4c1a7af9d62e0f2a70..b9825e73803bad75f6c9c5c58fed099235a3b4d7 100644
--- a/PaddleNLP/paddlenlp/data/sampler.py
+++ b/PaddleNLP/paddlenlp/data/sampler.py
@@ -170,24 +170,32 @@ class SamplerHelper(object):
         return type(self)(self.data_source, _impl)
 
-    def batch(self,
-              batch_size,
-              drop_last=False,
-              batch_size_fn=None,
-              batch_fn=None):
+    def batch(self, batch_size, drop_last=False, batch_size_fn=None, key=None):
         """
         To produce a BatchSampler.
-        Agrs:
+        Args:
             batch_size (int): Batch size.
-            drop_last (bool): Whether to drop the last mini batch. Default: False.
-            batch_size_fn (callable, optional): Return the size of mini batch so far. Default: None.
-            batch_fn (callable, optional): Transformations to be performed. Default: None.
+            drop_last (bool): Whether to drop the last mini batch. Default:
+                False.
+            batch_size_fn (callable, optional): It accepts four arguments: the
+                index into the data source, the current length of the mini
+                batch, the size of the mini batch so far, and the data source
+                itself, and returns the updated size of the mini batch so far.
+                The returned value can be anything and is passed to `key` as
+                its `size_so_far` argument. If None, the length of the mini
+                batch is used as its size. Default: None.
+            key (callable, optional): It accepts the size of the mini batch so
+                far and the length of the mini batch, and returns the value to
+                be compared with `batch_size`. If None, the size of the mini
+                batch so far is compared with `batch_size` directly. Default:
+                None.
         Returns:
             SamplerHelper
         """
+        _key = lambda size_so_far, minibatch_len: size_so_far
+
+        ori_batch_size_fn = batch_size_fn
         if batch_size_fn is None:
-            ori_batch_size_fn = None
             batch_size_fn = lambda new, count, sofar, data_source: count
+        key = _key if key is None else key
 
         def _impl():
             data_source = self.data_source
@@ -197,20 +205,22 @@ class SamplerHelper(object):
                 size_so_far = batch_size_fn(idx, len(minibatch), size_so_far,
                                             data_source)
-                if size_so_far == batch_size:
+                if key(size_so_far, len(minibatch)) == batch_size:
                     yield minibatch
                     minibatch, size_so_far = [], 0
-                elif size_so_far > batch_size:
+                elif key(size_so_far, len(minibatch)) > batch_size:
+                    if len(minibatch) == 1:
+                        raise ValueError(
+                            "Please increase the value of `batch_size`, or "
+                            "limit the maximum length of a single sample.")
                     yield minibatch[:-1]
                     minibatch, size_so_far = minibatch[-1:], batch_size_fn(
                         idx, 1, 0, data_source)
             if minibatch and not drop_last:
                 yield minibatch
 
-        sampler = type(self)(
-            self.data_source,
-            _impl) if batch_fn is None else self.apply(batch_fn)
-        if ori_batch_size_fn is None and batch_fn is None and self.length is not None:
+        sampler = type(self)(self.data_source, _impl)
+        if ori_batch_size_fn is None and self.length is not None:
             sampler.length = (self.length + int(not drop_last) *
                               (batch_size - 1)) // batch_size
         else:
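The reworked loop is easiest to see on toy data. Below is a minimal, self-contained sketch of the semantics `_impl` implements after this hunk; the toy lengths and the particular `batch_size_fn`/`key` (token bucketing by longest sample) are illustrative assumptions, not PaddleNLP code:

```python
# Toy data source: each "sample" is just its token length (illustrative).
lengths = [3, 9, 4, 7, 2, 8]

# Track the longest sample in the current mini batch.
batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(
    size_so_far, data_source[idx])

# Compare the padded size (longest * count) against the budget.
key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len

def batch_indices(data_source, batch_size):
    minibatch, size_so_far = [], 0
    for idx in range(len(data_source)):
        minibatch.append(idx)
        size_so_far = batch_size_fn(idx, len(minibatch), size_so_far,
                                    data_source)
        if key(size_so_far, len(minibatch)) == batch_size:
            yield minibatch
            minibatch, size_so_far = [], 0
        elif key(size_so_far, len(minibatch)) > batch_size:
            if len(minibatch) == 1:
                # Mirrors the new ValueError: a single sample is over budget.
                raise ValueError("`batch_size` too small for one sample")
            yield minibatch[:-1]
            minibatch, size_so_far = minibatch[-1:], batch_size_fn(
                idx, 1, 0, data_source)
    if minibatch:  # drop_last=False
        yield minibatch

print(list(batch_indices(lengths, batch_size=16)))
# -> [[0], [1], [2, 3], [4, 5]]: no batch exceeds 16 padded tokens
```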
diff --git a/PaddleNLP/paddlenlp/datasets/translation.py b/PaddleNLP/paddlenlp/datasets/translation.py
index 84ca80b914ff7c4dc9749c65071eee2b98424a3d..cebf0913ce564c645a307d961d78d9cdccdabce8 100644
--- a/PaddleNLP/paddlenlp/datasets/translation.py
+++ b/PaddleNLP/paddlenlp/datasets/translation.py
@@ -1,5 +1,7 @@
 import os
 import io
+import collections
+import warnings
 from functools import partial
 
 import numpy as np
@@ -8,65 +9,12 @@ import paddle
 from paddle.utils.download import get_path_from_url
 from paddlenlp.data import Vocab, Pad
 from paddlenlp.data.sampler import SamplerHelper
-
-DATA_HOME = "/root/.paddlenlp/datasets"
+from paddlenlp.utils.env import DATA_HOME
+from paddle.dataset.common import md5file
 
 __all__ = ['TranslationDataset', 'IWSLT15']
 
 
-def read_raw_files(corpus_path):
-    """Read raw files, return raw data"""
-    data = []
-    (f_mode, f_encoding, endl) = ("r", "utf-8", "\n")
-    with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus:
-        for line in f_corpus.readlines():
-            data.append(line.strip())
-    return data
-
-
-def get_raw_data(data_dir, train_filenames, valid_filenames, test_filenames,
-                 data_select):
-    data_dict = {}
-    file_select = {
-        'train': train_filenames,
-        'dev': valid_filenames,
-        'test': test_filenames
-    }
-    for mode in data_select:
-        src_filename, tgt_filename = file_select[mode]
-        src_path = os.path.join(data_dir, src_filename)
-        tgt_path = os.path.join(data_dir, tgt_filename)
-        src_data = read_raw_files(src_path)
-        tgt_data = read_raw_files(tgt_path)
-
-        data_dict[mode] = [(src_data[i], tgt_data[i])
-                           for i in range(len(src_data))]
-    return data_dict
-
-
-def setup_datasets(train_filenames,
-                   valid_filenames,
-                   test_filenames,
-                   data_select,
-                   root=None):
-    # Input check
-    target_select = ('train', 'dev', 'test')
-    if isinstance(data_select, str):
-        data_select = (data_select, )
-    if not set(data_select).issubset(set(target_select)):
-        raise TypeError(
-            'A subset of data selection {} is supported but {} is passed in'.
-            format(target_select, data_select))
-
-    raw_data = get_raw_data(root, train_filenames, valid_filenames,
-                            test_filenames, data_select)
-
-    datasets = []
-    for mode in data_select:
-        datasets.append(TranslationDataset(raw_data[mode]))
-    return tuple(datasets)
-
-
 def vocab_func(vocab, unk_token):
     def func(tok_iter):
         return [
@@ -103,13 +51,12 @@ class TranslationDataset(paddle.io.Dataset):
         data(list): Raw data. It is a list of tuple or list, each sample of
             data contains two elements, source and target.
     """
+    META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file',
+                                                     'src_md5', 'tgt_md5'))
+    SPLITS = {}
     URL = None
-    train_filenames = (None, None)
-    valid_filenames = (None, None)
-    test_filenames = (None, None)
-    src_vocab_filename = None
-    tgt_vocab_filename = None
-    dataset_dirname = None
+    MD5 = None
+    VOCAB_INFO = None
 
     def __init__(self, data):
         self.data = data
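With this scheme, a concrete dataset only fills in class-level metadata. The sketch below shows the intended shape with a hypothetical dataset; the class name, URL, file names, and checksum strings are placeholders (IWSLT15 further down in this patch is the real instance):

```python
import collections
import os

# Same namedtuple shape as TranslationDataset.META_INFO.
META_INFO = collections.namedtuple(
    'META_INFO', ('src_file', 'tgt_file', 'src_md5', 'tgt_md5'))

class ToyTranslation:  # stands in for a TranslationDataset subclass
    URL = "https://example.com/toy.en-de.tar.gz"  # placeholder URL
    MD5 = "0123456789abcdef0123456789abcdef"      # placeholder archive md5
    SPLITS = {
        'train': META_INFO(
            os.path.join("toy.en-de", "train.en"),
            os.path.join("toy.en-de", "train.de"),
            "<train.en md5>", "<train.de md5>"),  # placeholder hashes
    }
    # (src vocab file, tgt vocab file, src vocab md5, tgt vocab md5)
    VOCAB_INFO = (os.path.join("toy.en-de", "vocab.en"),
                  os.path.join("toy.en-de", "vocab.de"),
                  "<vocab.en md5>", "<vocab.de md5>")  # placeholder hashes
```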
""" + META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file', + 'src_md5', 'tgt_md5')) + SPLITS = {} URL = None - train_filenames = (None, None) - valid_filenames = (None, None) - test_filenames = (None, None) - src_vocab_filename = None - tgt_vocab_filename = None - dataset_dirname = None + MD5 = None + VOCAB_INFO = None def __init__(self, data): self.data = data @@ -121,7 +68,7 @@ class TranslationDataset(paddle.io.Dataset): return len(self.data) @classmethod - def get_data(cls, root=None): + def get_data(cls, mode="train", root=None): """ Download dataset if any data file doesn't exist. Args: @@ -136,32 +83,37 @@ class TranslationDataset(paddle.io.Dataset): from paddlenlp.datasets import IWSLT15 data_path = IWSLT15.get_data() """ - if root is None: - root = os.path.join(DATA_HOME, 'machine_translation') - data_dir = os.path.join(root, cls.dataset_dirname) - if not os.path.exists(root): - os.makedirs(root) - print("IWSLT will be downloaded at ", root) - get_path_from_url(cls.URL, root) - print("Downloaded success......") - else: - filename_list = [ - cls.train_filenames[0], cls.train_filenames[1], - cls.valid_filenames[0], cls.valid_filenames[0], - cls.src_vocab_filename, cls.tgt_vocab_filename - ] - for filename in filename_list: - file_path = os.path.join(data_dir, filename) - if not os.path.exists(file_path): - print( - "The dataset is incomplete and will be re-downloaded.") - get_path_from_url(cls.URL, root) - print("Downloaded success......") - break - return data_dir + default_root = os.path.join(DATA_HOME, 'machine_translation') + src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[ + mode] + + filename_list = [ + src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1] + ] + fullname_list = [] + for filename in filename_list: + fullname = os.path.join(default_root, + filename) if root is None else os.path.join( + os.path.expanduser(root), filename) + fullname_list.append(fullname) + + data_hash_list = [ + src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3] + ] + for i, fullname in enumerate(fullname_list): + if not os.path.exists(fullname) or ( + data_hash_list[i] and + not md5file(fullname) == data_hash_list[i]): + if root is not None: # not specified, and no need to warn + warnings.warn( + 'md5 check failed for {}, download {} data to {}'. + format(filename, self.__class__.__name__, default_root)) + path = get_path_from_url(cls.URL, default_root, cls.MD5) + break + return root if root is not None else default_root @classmethod - def get_vocab(cls, root=None): + def build_vocab(cls, root=None): """ Load vocab from vocab files. It vocab files don't exist, the will be downloaded. @@ -176,22 +128,42 @@ class TranslationDataset(paddle.io.Dataset): Examples: .. 
@@ -176,22 +128,42 @@ class TranslationDataset(paddle.io.Dataset):
         Examples:
             .. code-block:: python
 
                 from paddlenlp.datasets import IWSLT15
-                (src_vocab, tgt_vocab) = IWSLT15.get_vocab()
+                (src_vocab, tgt_vocab) = IWSLT15.build_vocab()
         """
-        data_path = cls.get_data(root)
-
+        root = cls.get_data(root=root)
         # Get vocab_func
-        src_file_path = os.path.join(data_path, cls.src_vocab_filename)
-        tgt_file_path = os.path.join(data_path, cls.tgt_vocab_filename)
+        src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
+        src_file_path = os.path.join(root, src_vocab_filename)
+        tgt_file_path = os.path.join(root, tgt_vocab_filename)
 
-        src_vocab = Vocab.load_vocabulary(src_file_path, cls.unk_token,
-                                          cls.bos_token, cls.eos_token)
+        src_vocab = Vocab.load_vocabulary(src_file_path, cls.UNK_TOKEN,
+                                          cls.BOS_TOKEN, cls.EOS_TOKEN)
 
-        tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.unk_token,
-                                          cls.bos_token, cls.eos_token)
+        tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.UNK_TOKEN,
+                                          cls.BOS_TOKEN, cls.EOS_TOKEN)
         return (src_vocab, tgt_vocab)
 
+    def read_raw_data(self, data_dir, mode):
+        """Read the raw files of the given split into (src, tgt) pairs."""
+        src_filename, tgt_filename, _, _ = self.SPLITS[mode]
+
+        def read_raw_files(corpus_path):
+            """Read raw files, return raw data."""
+            data = []
+            (f_mode, f_encoding, endl) = ("r", "utf-8", "\n")
+            with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus:
+                for line in f_corpus.readlines():
+                    data.append(line.strip())
+            return data
+
+        src_path = os.path.join(data_dir, src_filename)
+        tgt_path = os.path.join(data_dir, tgt_filename)
+        src_data = read_raw_files(src_path)
+        tgt_data = read_raw_files(tgt_path)
+
+        data = [(src_data[i], tgt_data[i]) for i in range(len(src_data))]
+        return data
+
     @classmethod
     def get_default_transform_func(cls, root=None):
         """Get default transform function, which transforms raw data to id.
@@ -210,11 +182,11 @@ class TranslationDataset(paddle.io.Dataset):
         src_text_vocab_transform = sequential_transforms(src_tokenizer)
         tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)
 
-        (src_vocab, tgt_vocab) = cls.get_vocab(root)
+        (src_vocab, tgt_vocab) = cls.build_vocab(root)
         src_text_transform = sequential_transforms(
-            src_text_vocab_transform, vocab_func(src_vocab, cls.unk_token))
+            src_text_vocab_transform, vocab_func(src_vocab, cls.UNK_TOKEN))
         tgt_text_transform = sequential_transforms(
-            tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.unk_token))
+            tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.UNK_TOKEN))
         return (src_text_transform, tgt_text_transform)
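The default transform is plain function composition: tokenize, then map tokens to ids. A toy run under an assumed implementation of `sequential_transforms` (not shown in this patch) and a dict-backed stand-in for the vocab:

```python
def sequential_transforms(*transforms):
    # Compose transforms left to right (assumed implementation).
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

def vocab_func(vocab, unk_token):
    # Mirrors the helper in translation.py, with a plain dict as the vocab.
    def func(tok_iter):
        return [vocab[tok] if tok in vocab else vocab[unk_token]
                for tok in tok_iter]
    return func

toy_vocab = {'<unk>': 0, 'hello': 1, 'world': 2}
transform = sequential_transforms(str.split, vocab_func(toy_vocab, '<unk>'))
print(transform("hello brave world"))  # -> [1, 0, 2]
```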
(os.path.join("iwslt15.en-vi", "vocab.en"), os.path.join( + "iwslt15.en-vi", "vocab.vi"), "98b5011e1f579936277a273fd7f4e9b4", + "e8b05f8c26008a798073c619236712b4") + UNK_TOKEN = '' + BOS_TOKEN = '' + EOS_TOKEN = '' + MD5 = 'aca22dc3f90962e42916dbb36d8f3e8e' def __init__(self, mode='train', root=None, transform_func=None): - # Input check - segment_select = ('train', 'dev', 'test') - if mode not in segment_select: + data_select = ('train', 'dev', 'test') + if mode not in data_select: raise TypeError( '`train`, `dev` or `test` is supported but `{}` is passed in'. format(mode)) if transform_func is not None: if len(transform_func) != 2: raise ValueError("`transform_func` must have length of two for" - "source and target") + "source and target.") # Download data - data_path = IWSLT15.get_data(root) - dataset = setup_datasets(self.train_filenames, self.valid_filenames, - self.test_filenames, [mode], data_path)[0] - self.data = dataset.data + root = IWSLT15.get_data(root=root) + self.data = self.read_raw_data(root, mode) + if transform_func is not None: self.data = [(transform_func[0](data[0]), transform_func[1](data[1])) for data in self.data] @@ -274,18 +259,25 @@ def prepare_train_input(insts, pad_id): [inst[1] for inst in insts]) return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis] +batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(size_so_far, len(data_source[idx][0])) + +batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len if __name__ == '__main__': - batch_size = 32 + batch_size = 4096 #32 pad_id = 2 transform_func = IWSLT15.get_default_transform_func() train_dataset = IWSLT15(transform_func=transform_func) key = (lambda x, data_source: len(data_source[x][0])) + train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort( key=key, buffer_size=batch_size * 20).batch( - batch_size=batch_size, drop_last=True).shard() + batch_size=batch_size, + drop_last=True, + batch_size_fn=batch_size_fn, + key=batch_key).shard() train_loader = paddle.io.DataLoader( train_dataset, @@ -293,6 +285,7 @@ if __name__ == '__main__': collate_fn=partial( prepare_train_input, pad_id=pad_id)) - for data in train_loader: - print(data) - break + for i, data in enumerate(train_loader): + print(data[1]) + print(paddle.max(data[1]) * len(data[1])) + print(len(data[1]))