import os
import io
import collections
import warnings
from functools import partial

import numpy as np
import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.utils.env import DATA_HOME
from paddle.dataset.common import md5file

__all__ = ['TranslationDataset', 'IWSLT15']


def sequential_transforms(*transforms):
    """Compose callables left to right into a single callable.

    The returned function feeds its input through each of `transforms` in
    order and returns the final result.
    """

    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input

    return func


def get_default_tokenizer():
    """Return the default tokenizer, which splits text on whitespace.

    Only a whitespace-split tokenizer is supported.
    """

    def _split_tokenizer(x):
        return x.split()

    return _split_tokenizer


class TranslationDataset(paddle.io.Dataset):
    """
    Base dataset providing raw (source, target) text pairs.

    Args:
        data(list): Raw data. A list of tuples or lists; each sample holds
            two elements, the source text and the target text.
    """
    # Per-split file description: (src_file, tgt_file, src_md5, tgt_md5).
    META_INFO = collections.namedtuple(
        'META_INFO', ('src_file', 'tgt_file', 'src_md5', 'tgt_md5'))
    # The attributes below are filled in by concrete subclasses.
    SPLITS = {}  # mode -> META_INFO
    URL = None  # archive downloaded when files are missing or corrupt
    MD5 = None  # md5 of the archive at URL
    # (src_vocab_file, tgt_vocab_file, src_vocab_md5, tgt_vocab_md5)
    VOCAB_INFO = None
    UNK_TOKEN = None
    BOS_TOKEN = None
    EOS_TOKEN = None
    PAD_TOKEN = None

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

    @classmethod
    def get_data(cls, mode="train", root=None):
        """
        Download the dataset if necessary and read the raw data of a split.

        Args:
            mode(str, optional): Data mode to read. It could be 'train',
                'dev' or 'test'. Default: 'train'.
            root (str, optional): Data directory to save the dataset. If not
                provided, the dataset is saved under
                `DATA_HOME/machine_translation/<ClassName>`. Default: None.

        Returns:
            list: Raw data, a list of (source, target) tuples.

        Examples:
            .. code-block:: python

                from paddlenlp.datasets import IWSLT15
                data = IWSLT15.get_data()
        """
        root = cls._download_data(mode, root)
        data = cls.read_raw_data(mode, root)
        return data

    @classmethod
    def _download_data(cls, mode="train", root=None):
        """Make sure the files of split `mode` and the vocab files exist.

        Files are looked up under `root` (user-expanded) or, when `root` is
        None, under the default data directory. If any file is missing or
        fails its md5 check (an empty md5 skips the check), the archive at
        `cls.URL` is downloaded into the default directory.

        Returns:
            str: Directory that contains the dataset files.
        """
        default_root = os.path.join(DATA_HOME, 'machine_translation',
                                    cls.__name__)
        base_dir = (default_root
                    if root is None else os.path.expanduser(root))

        src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
            mode]
        (src_vocab_filename, tgt_vocab_filename, src_vocab_hash,
         tgt_vocab_hash) = cls.VOCAB_INFO
        files_and_hashes = [
            (src_filename, src_data_hash),
            (tgt_filename, tgt_data_hash),
            (src_vocab_filename, src_vocab_hash),
            (tgt_vocab_filename, tgt_vocab_hash),
        ]

        for filename, data_hash in files_and_hashes:
            fullname = os.path.join(base_dir, filename)
            if os.path.exists(fullname) and (
                    not data_hash or md5file(fullname) == data_hash):
                continue
            # Only warn when the user pointed at a custom root; with the
            # default root a first-time download is expected.
            if root is not None:
                warnings.warn(
                    'md5 check failed for {}, download {} data to {}'.format(
                        fullname, cls.__name__, default_root))
            get_path_from_url(cls.URL, default_root, cls.MD5)
            return default_root
        # Return the expanded path so later os.path.join/open calls work for
        # roots such as '~/data'.
        return base_dir

    @classmethod
    def get_vocab(cls, root=None):
        """
        Load the source and target vocabularies from the vocab files. If the
        vocab files don't exist, they will be downloaded.

        Args:
            root (str, optional): Data directory to save the dataset. If not
                provided, the dataset is saved under
                `DATA_HOME/machine_translation/<ClassName>`. Existing vocab
                files that pass the md5 check are not overwritten.
                Default: None.

        Returns:
            tuple: Source vocab and target vocab.

        Examples:
            .. code-block:: python

                from paddlenlp.datasets import IWSLT15
                (src_vocab, tgt_vocab) = IWSLT15.get_vocab()
        """
        root = cls._download_data(root=root)
        src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
        special_tokens = dict(
            unk_token=cls.UNK_TOKEN,
            pad_token=cls.PAD_TOKEN,
            bos_token=cls.BOS_TOKEN,
            eos_token=cls.EOS_TOKEN)
        src_vocab = Vocab.load_vocabulary(
            os.path.join(root, src_vocab_filename), **special_tokens)
        tgt_vocab = Vocab.load_vocabulary(
            os.path.join(root, tgt_vocab_filename), **special_tokens)
        return (src_vocab, tgt_vocab)

    @classmethod
    def read_raw_data(cls, mode, root):
        """Read the parallel source and target files of one split.

        Args:
            mode(str): Split to read. It could be 'train', 'dev' or 'test'.
            root(str): Data directory of the dataset.

        Returns:
            list: A list of (source_line, target_line) tuples; every line is
            stripped of surrounding whitespace.

        Raises:
            ValueError: If the source and target files do not have the same
                number of lines.
        """
        src_filename, tgt_filename, _, _ = cls.SPLITS[mode]

        def read_raw_files(corpus_path):
            """Read a UTF-8 text file and return its stripped lines."""
            with io.open(corpus_path, "r", encoding="utf-8") as f_corpus:
                return [line.strip() for line in f_corpus]

        src_path = os.path.join(root, src_filename)
        tgt_path = os.path.join(root, tgt_filename)
        src_data = read_raw_files(src_path)
        tgt_data = read_raw_files(tgt_path)
        # A parallel corpus must be line-aligned; report a mismatch instead
        # of failing with IndexError or silently dropping target lines.
        if len(src_data) != len(tgt_data):
            raise ValueError(
                "Source file {} has {} lines but target file {} has {} lines."
                .format(src_path, len(src_data), tgt_path, len(tgt_data)))
        return list(zip(src_data, tgt_data))

    @classmethod
    def get_default_transform_func(cls, root=None):
        """Get the default transform functions, which map raw text to ids.

        Each function tokenizes with the whitespace tokenizer and then maps
        the tokens through the corresponding vocab.

        Args:
            root(str, optional): Data directory of the dataset.

        Returns:
            tuple: Two transform functions, for source and target data.

        Examples:
            .. code-block:: python

                from paddlenlp.datasets import IWSLT15
                transform_func = IWSLT15.get_default_transform_func()
        """
        # Get default tokenizer
        src_tokenizer = get_default_tokenizer()
        tgt_tokenizer = get_default_tokenizer()
        src_text_vocab_transform = sequential_transforms(src_tokenizer)
        tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)

        (src_vocab, tgt_vocab) = cls.get_vocab(root)
        src_text_transform = sequential_transforms(src_text_vocab_transform,
                                                   src_vocab)
        tgt_text_transform = sequential_transforms(tgt_text_vocab_transform,
                                                   tgt_vocab)
        return (src_text_transform, tgt_text_transform)


class IWSLT15(TranslationDataset):
    """
    IWSLT15 English to Vietnamese translation dataset (source files are
    `.en`, target files are `.vi`).

    Args:
        mode(str, optional): It could be 'train', 'dev' or 'test'.
            Default: 'train'.
        root(str, optional): If None, dataset will be downloaded under
            `DATA_HOME/machine_translation/IWSLT15`. Default: None.
        transform_func(callable, optional): If not None, a pair of callables
            that transform source and target raw data to index data.
            Default: None.

    Examples:
        .. code-block:: python

            from paddlenlp.datasets import IWSLT15
            train_dataset = IWSLT15('train')
            valid_dataset = IWSLT15('dev')
    """
    URL = "https://paddlenlp.bj.bcebos.com/datasets/iwslt15.en-vi.tar.gz"
    SPLITS = {
        'train': TranslationDataset.META_INFO(
            os.path.join("iwslt15.en-vi", "train.en"),
            os.path.join("iwslt15.en-vi", "train.vi"),
            "5b6300f46160ab5a7a995546d2eeb9e6",
            "858e884484885af5775068140ae85dab"),
        'dev': TranslationDataset.META_INFO(
            os.path.join("iwslt15.en-vi", "tst2012.en"),
            os.path.join("iwslt15.en-vi", "tst2012.vi"),
            "c14a0955ed8b8d6929fdabf4606e3875",
            "dddf990faa149e980b11a36fca4a8898"),
        'test': TranslationDataset.META_INFO(
            os.path.join("iwslt15.en-vi", "tst2013.en"),
            os.path.join("iwslt15.en-vi", "tst2013.vi"),
            "c41c43cb6d3b122c093ee89608ba62bd",
            "a3185b00264620297901b647a4cacf38")
    }
    VOCAB_INFO = (os.path.join("iwslt15.en-vi", "vocab.en"),
                  os.path.join("iwslt15.en-vi", "vocab.vi"),
                  "98b5011e1f579936277a273fd7f4e9b4",
                  "e8b05f8c26008a798073c619236712b4")
    # NOTE(review): empty-string special tokens cannot match any vocabulary
    # entry; these are assumed to mirror the special tokens in the shipped
    # vocab.en / vocab.vi files — verify against those files.
    UNK_TOKEN = '<unk>'
    BOS_TOKEN = '<s>'
    EOS_TOKEN = '</s>'
    MD5 = 'aca22dc3f90962e42916dbb36d8f3e8e'

    def __init__(self, mode='train', root=None, transform_func=None):
        data_select = ('train', 'dev', 'test')
        if mode not in data_select:
            # TypeError is kept for backward compatibility with callers.
            raise TypeError(
                '`train`, `dev` or `test` is supported but `{}` is passed in'.
                format(mode))
        if transform_func is not None and len(transform_func) != 2:
            raise ValueError("`transform_func` must have length of two for "
                             "source and target.")
        # Download (if needed) and read the requested split.
        self.data = self.get_data(mode=mode, root=root)
        if transform_func is not None:
            src_transform, tgt_transform = transform_func
            self.data = [(src_transform(src), tgt_transform(tgt))
                         for src, tgt in self.data]


# For test, not API
def prepare_train_input(insts, pad_id):
    """Pad a batch of (src_ids, tgt_ids) samples into training inputs.

    Returns:
        tuple: (src, src_length, tgt_input, tgt_label), where `tgt_input`
        drops the last target position and `tgt_label` drops the first one
        and gains a trailing axis.
    """
    src, src_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[0] for inst in insts])
    tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[1] for inst in insts])
    return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis]


def batch_size_fn(idx, minibatch_len, size_so_far, data_source):
    """Return the longest source length seen so far in the minibatch."""
    return max(size_so_far, len(data_source[idx][0]))


def batch_key(size_so_far, minibatch_len):
    """Return longest-source-length times number of samples in the batch."""
    return size_so_far * minibatch_len


if __name__ == '__main__':
    batch_size = 4096  # 32
    pad_id = 2
    transform_func = IWSLT15.get_default_transform_func()
    train_dataset = IWSLT15(transform_func=transform_func)

    key = (lambda x, data_source: len(data_source[x][0]))

    train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(
            batch_size=batch_size,
            drop_last=True,
            batch_size_fn=batch_size_fn,
            key=batch_key).shard()

    train_loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, pad_id=pad_id))

    for i, data in enumerate(train_loader):
        print(data[1])
        print(paddle.max(data[1]) * len(data[1]))
        print(len(data[1]))