Unverified commit ca2ecccf authored by LiuChiachi, committed by GitHub

Update SamplerHelper batch, Update TranslationDataset (#4991)

* update SamplerHelper.batch

* update batch_size, delete a useless log

* update SamplerHelper.batch

* update SamplerHelper.batch, change cmp_fn arg to key

* Update TranslationDataset and IWSLT

* fix doc error
Parent 8675c9fa
@@ -170,24 +170,32 @@ class SamplerHelper(object):
        return type(self)(self.data_source, _impl)

-    def batch(self,
-              batch_size,
-              drop_last=False,
-              batch_size_fn=None,
-              batch_fn=None):
+    def batch(self, batch_size, drop_last=False, batch_size_fn=None, key=None):
        """
        To produce a BatchSampler.
-        Agrs:
+        Args:
            batch_size (int): Batch size.
-            drop_last (bool): Whether to drop the last mini batch. Default: False.
-            batch_size_fn (callable, optional): Return the size of mini batch so far. Default: None.
-            batch_fn (callable, optional): Transformations to be performed. Default: None.
+            drop_last (bool): Whether to drop the last mini batch. Default:
+                False.
+            batch_size_fn (callable, optional): It accepts four arguments: the
+                index into the data source, the length of the minibatch, the
+                size of the minibatch so far, and the data source, and it
+                returns the size of the mini batch so far. Actually, the
+                returned value can be anything and would be used as the
+                `size_so_far` argument of `key`. If None, it would return the
+                length of the mini batch. Default: None.
+            key (callable, optional): It accepts the size of the minibatch so
+                far and the length of the minibatch, and returns the value to
+                be compared with `batch_size`. If None, only the size of the
+                mini batch so far would be compared with `batch_size`.
+                Default: None.
        Returns:
            SamplerHelper
        """
+        _key = lambda size_so_far, minibatch_len: size_so_far
+
+        ori_batch_size_fn = batch_size_fn
        if batch_size_fn is None:
-            ori_batch_size_fn = None
            batch_size_fn = lambda new, count, sofar, data_source: count
+        key = _key if key is None else key

        def _impl():
            data_source = self.data_source
@@ -197,20 +205,22 @@ class SamplerHelper(object):
                size_so_far = batch_size_fn(idx,
                                            len(minibatch), size_so_far,
                                            data_source)
-                if size_so_far == batch_size:
+                if key(size_so_far, len(minibatch)) == batch_size:
                    yield minibatch
                    minibatch, size_so_far = [], 0
-                elif size_so_far > batch_size:
+                elif key(size_so_far, len(minibatch)) > batch_size:
+                    if len(minibatch) == 1:
+                        raise ValueError(
+                            "Please increase the value of `batch_size`, or limit the max length of batch."
+                        )
                    yield minibatch[:-1]
                    minibatch, size_so_far = minibatch[-1:], batch_size_fn(
                        idx, 1, 0, data_source)
            if minibatch and not drop_last:
                yield minibatch

-        sampler = type(self)(
-            self.data_source,
-            _impl) if batch_fn is None else self.apply(batch_fn)
-        if ori_batch_size_fn is None and batch_fn is None and self.length is not None:
+        sampler = type(self)(self.data_source, _impl)
+        if ori_batch_size_fn is None and self.length is not None:
            sampler.length = (self.length + int(not drop_last) *
                              (batch_size - 1)) // batch_size
        else:
...
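Taken together, the new `batch_size_fn` and `key` hooks let `batch()` budget by token count instead of sentence count. A minimal sketch follows; the toy data and the cap of 16 tokens are hypothetical, and it assumes SamplerHelper iterates the indices of the wrapped list in order, as in the training example at the end of this diff:

    from paddlenlp.data.sampler import SamplerHelper

    # Toy dataset of (src_ids, tgt_ids) pairs with source lengths 3, 5, 7, 9, 11.
    toy_data = [([1] * n, [1] * n) for n in (3, 5, 7, 9, 11)]

    # size_so_far tracks the longest source sentence in the current minibatch...
    batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: \
        max(size_so_far, len(data_source[idx][0]))
    # ...and key converts it to "padded source tokens" for comparison with batch_size.
    batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len

    batch_sampler = SamplerHelper(toy_data).batch(
        batch_size=16,  # a budget of padded source tokens, not of sentences
        drop_last=False,
        batch_size_fn=batch_size_fn,
        key=batch_key)

    for minibatch in batch_sampler:
        print(minibatch)  # e.g. [0, 1], [2], [3], [4]: max_len * n_sentences <= 16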
import os
import io
+import collections
+import warnings
from functools import partial

import numpy as np
@@ -8,65 +9,12 @@
import paddle

from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad
from paddlenlp.data.sampler import SamplerHelper
-DATA_HOME = "/root/.paddlenlp/datasets"
+from paddlenlp.utils.env import DATA_HOME
+from paddle.dataset.common import md5file

__all__ = ['TranslationDataset', 'IWSLT15']

-def read_raw_files(corpus_path):
-    """Read raw files, return raw data"""
-    data = []
-    (f_mode, f_encoding, endl) = ("r", "utf-8", "\n")
-    with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus:
-        for line in f_corpus.readlines():
-            data.append(line.strip())
-    return data
-
-
-def get_raw_data(data_dir, train_filenames, valid_filenames, test_filenames,
-                 data_select):
-    data_dict = {}
-    file_select = {
-        'train': train_filenames,
-        'dev': valid_filenames,
-        'test': test_filenames
-    }
-    for mode in data_select:
-        src_filename, tgt_filename = file_select[mode]
-        src_path = os.path.join(data_dir, src_filename)
-        tgt_path = os.path.join(data_dir, tgt_filename)
-        src_data = read_raw_files(src_path)
-        tgt_data = read_raw_files(tgt_path)
-        data_dict[mode] = [(src_data[i], tgt_data[i])
-                           for i in range(len(src_data))]
-    return data_dict
-
-
-def setup_datasets(train_filenames,
-                   valid_filenames,
-                   test_filenames,
-                   data_select,
-                   root=None):
-    # Input check
-    target_select = ('train', 'dev', 'test')
-    if isinstance(data_select, str):
-        data_select = (data_select, )
-    if not set(data_select).issubset(set(target_select)):
-        raise TypeError(
-            'A subset of data selection {} is supported but {} is passed in'.
-            format(target_select, data_select))
-    raw_data = get_raw_data(root, train_filenames, valid_filenames,
-                            test_filenames, data_select)
-    datasets = []
-    for mode in data_select:
-        datasets.append(TranslationDataset(raw_data[mode]))
-    return tuple(datasets)
def vocab_func(vocab, unk_token):
    def func(tok_iter):
        return [
@@ -103,13 +51,12 @@ class TranslationDataset(paddle.io.Dataset):
        data(list): Raw data. It is a list of tuple or list, each sample of
            data contains two elements, source and target.
    """
+    META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file',
+                                                     'src_md5', 'tgt_md5'))
+    SPLITS = {}
    URL = None
-    train_filenames = (None, None)
-    valid_filenames = (None, None)
-    test_filenames = (None, None)
-    src_vocab_filename = None
-    tgt_vocab_filename = None
-    dataset_dirname = None
+    MD5 = None
+    VOCAB_INFO = None

    def __init__(self, data):
        self.data = data
@@ -121,7 +68,7 @@ class TranslationDataset(paddle.io.Dataset):
        return len(self.data)

    @classmethod
-    def get_data(cls, root=None):
+    def get_data(cls, mode="train", root=None):
        """
        Download dataset if any data file doesn't exist.
        Args:
@@ -136,32 +83,37 @@ class TranslationDataset(paddle.io.Dataset):
                from paddlenlp.datasets import IWSLT15
                data_path = IWSLT15.get_data()
        """
-        if root is None:
-            root = os.path.join(DATA_HOME, 'machine_translation')
-        data_dir = os.path.join(root, cls.dataset_dirname)
-        if not os.path.exists(root):
-            os.makedirs(root)
-            print("IWSLT will be downloaded at ", root)
-            get_path_from_url(cls.URL, root)
-            print("Downloaded success......")
-        else:
-            filename_list = [
-                cls.train_filenames[0], cls.train_filenames[1],
-                cls.valid_filenames[0], cls.valid_filenames[0],
-                cls.src_vocab_filename, cls.tgt_vocab_filename
-            ]
-            for filename in filename_list:
-                file_path = os.path.join(data_dir, filename)
-                if not os.path.exists(file_path):
-                    print(
-                        "The dataset is incomplete and will be re-downloaded.")
-                    get_path_from_url(cls.URL, root)
-                    print("Downloaded success......")
-                    break
-        return data_dir
+        default_root = os.path.join(DATA_HOME, 'machine_translation')
+        src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
+            mode]
+
+        filename_list = [
+            src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1]
+        ]
+        fullname_list = []
+        for filename in filename_list:
+            fullname = os.path.join(default_root,
+                                    filename) if root is None else os.path.join(
+                                        os.path.expanduser(root), filename)
+            fullname_list.append(fullname)
+
+        data_hash_list = [
+            src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
+        ]
+        for i, fullname in enumerate(fullname_list):
+            if not os.path.exists(fullname) or (
+                    data_hash_list[i] and
+                    not md5file(fullname) == data_hash_list[i]):
+                if root is not None:  # root was specified but data is missing or corrupted, so warn
+                    warnings.warn(
+                        'md5 check failed for {}, download {} data to {}'.
+                        format(filename, cls.__name__, default_root))
+                path = get_path_from_url(cls.URL, default_root, cls.MD5)
+                break
+        return root if root is not None else default_root
    @classmethod
-    def get_vocab(cls, root=None):
+    def build_vocab(cls, root=None):
        """
        Load vocab from vocab files. If vocab files don't exist, they will
        be downloaded.
@@ -176,22 +128,42 @@ class TranslationDataset(paddle.io.Dataset):
        Examples:
            .. code-block:: python
                from paddlenlp.datasets import IWSLT15
-                (src_vocab, tgt_vocab) = IWSLT15.get_vocab()
+                (src_vocab, tgt_vocab) = IWSLT15.build_vocab()
        """
-        data_path = cls.get_data(root)
+        root = cls.get_data(root=root)
        # Get vocab_func
-        src_file_path = os.path.join(data_path, cls.src_vocab_filename)
-        tgt_file_path = os.path.join(data_path, cls.tgt_vocab_filename)
-        src_vocab = Vocab.load_vocabulary(src_file_path, cls.unk_token,
-                                          cls.bos_token, cls.eos_token)
-        tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.unk_token,
-                                          cls.bos_token, cls.eos_token)
+        src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
+        src_file_path = os.path.join(root, src_vocab_filename)
+        tgt_file_path = os.path.join(root, tgt_vocab_filename)
+        src_vocab = Vocab.load_vocabulary(src_file_path, cls.UNK_TOKEN,
+                                          cls.BOS_TOKEN, cls.EOS_TOKEN)
+        tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.UNK_TOKEN,
+                                          cls.BOS_TOKEN, cls.EOS_TOKEN)
        return (src_vocab, tgt_vocab)
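A short usage sketch of the renamed helper; the token list and printed ids are illustrative, and it assumes the IWSLT15 vocab files download and load successfully:

    src_vocab, tgt_vocab = IWSLT15.build_vocab()
    # vocab_func (defined earlier in this module) wraps a Vocab so a list of
    # tokens maps to a list of ids; unknown tokens are assumed to fall back to
    # the id of UNK_TOKEN.
    to_src_ids = vocab_func(src_vocab, IWSLT15.UNK_TOKEN)
    print(to_src_ids("hello world !".split()))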
+    def read_raw_data(self, data_dir, mode):
+        src_filename, tgt_filename, _, _ = self.SPLITS[mode]
+
+        def read_raw_files(corpus_path):
+            """Read raw files, return raw data"""
+            data = []
+            (f_mode, f_encoding, endl) = ("r", "utf-8", "\n")
+            with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus:
+                for line in f_corpus.readlines():
+                    data.append(line.strip())
+            return data
+
+        src_path = os.path.join(data_dir, src_filename)
+        tgt_path = os.path.join(data_dir, tgt_filename)
+        src_data = read_raw_files(src_path)
+        tgt_data = read_raw_files(tgt_path)
+
+        data = [(src_data[i], tgt_data[i]) for i in range(len(src_data))]
+        return data
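A sketch of how the SPLITS metadata and read_raw_data fit together; the paths and the printed pair are illustrative, and get_data is assumed to have downloaded and verified the archive:

    root = IWSLT15.get_data()            # default: DATA_HOME/machine_translation
    meta = IWSLT15.SPLITS['train']       # META_INFO(src_file, tgt_file, src_md5, tgt_md5)
    print(meta.src_file, meta.tgt_file)  # iwslt15.en-vi/train.en, iwslt15.en-vi/train.vi

    train_pairs = IWSLT15(mode='train').data   # raw (source, target) sentence strings
    print(train_pairs[0])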
    @classmethod
    def get_default_transform_func(cls, root=None):
        """Get default transform function, which transforms raw data to id.
@@ -210,11 +182,11 @@ class TranslationDataset(paddle.io.Dataset):
        src_text_vocab_transform = sequential_transforms(src_tokenizer)
        tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)

-        (src_vocab, tgt_vocab) = cls.get_vocab(root)
+        (src_vocab, tgt_vocab) = cls.build_vocab(root)
        src_text_transform = sequential_transforms(
-            src_text_vocab_transform, vocab_func(src_vocab, cls.unk_token))
+            src_text_vocab_transform, vocab_func(src_vocab, cls.UNK_TOKEN))
        tgt_text_transform = sequential_transforms(
-            tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.unk_token))
+            tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.UNK_TOKEN))
        return (src_text_transform, tgt_text_transform)
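A sketch of the default transform in use; it assumes indexing an IWSLT15 instance returns the transformed pair, as the DataLoader usage at the end of this diff relies on:

    src_transform, tgt_transform = IWSLT15.get_default_transform_func()
    train_dataset = IWSLT15(mode='train',
                            transform_func=(src_transform, tgt_transform))
    src_ids, tgt_ids = train_dataset[0]  # two lists of token ids instead of raw strings
    print(len(src_ids), len(tgt_ids))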
@@ -235,32 +207,45 @@ class IWSLT15(TranslationDataset):
    """
    URL = "https://paddlenlp.bj.bcebos.com/datasets/iwslt15.en-vi.tar.gz"
-    train_filenames = ("train.en", "train.vi")
-    valid_filenames = ("tst2012.en", "tst2012.vi")
-    test_filenames = ("tst2013.en", "tst2013.vi")
-    src_vocab_filename = "vocab.en"
-    tgt_vocab_filename = "vocab.vi"
-    unk_token = '<unk>'
-    bos_token = '<s>'
-    eos_token = '</s>'
-    dataset_dirname = "iwslt15.en-vi"
+    SPLITS = {
+        'train': TranslationDataset.META_INFO(
+            os.path.join("iwslt15.en-vi", "train.en"),
+            os.path.join("iwslt15.en-vi", "train.vi"),
+            "5b6300f46160ab5a7a995546d2eeb9e6",
+            "858e884484885af5775068140ae85dab"),
+        'dev': TranslationDataset.META_INFO(
+            os.path.join("iwslt15.en-vi", "tst2012.en"),
+            os.path.join("iwslt15.en-vi", "tst2012.vi"),
+            "c14a0955ed8b8d6929fdabf4606e3875",
+            "dddf990faa149e980b11a36fca4a8898"),
+        'test': TranslationDataset.META_INFO(
+            os.path.join("iwslt15.en-vi", "tst2013.en"),
+            os.path.join("iwslt15.en-vi", "tst2013.vi"),
+            "c41c43cb6d3b122c093ee89608ba62bd",
+            "a3185b00264620297901b647a4cacf38")
+    }
+    VOCAB_INFO = (os.path.join("iwslt15.en-vi", "vocab.en"), os.path.join(
+        "iwslt15.en-vi", "vocab.vi"), "98b5011e1f579936277a273fd7f4e9b4",
+        "e8b05f8c26008a798073c619236712b4")
+    UNK_TOKEN = '<unk>'
+    BOS_TOKEN = '<s>'
+    EOS_TOKEN = '</s>'
+    MD5 = 'aca22dc3f90962e42916dbb36d8f3e8e'
    def __init__(self, mode='train', root=None, transform_func=None):
-        # Input check
-        segment_select = ('train', 'dev', 'test')
-        if mode not in segment_select:
+        data_select = ('train', 'dev', 'test')
+        if mode not in data_select:
            raise TypeError(
                '`train`, `dev` or `test` is supported but `{}` is passed in'.
                format(mode))
        if transform_func is not None:
            if len(transform_func) != 2:
                raise ValueError("`transform_func` must have length of two for"
-                                 "source and target")
+                                 " source and target.")
        # Download data
-        data_path = IWSLT15.get_data(root)
-        dataset = setup_datasets(self.train_filenames, self.valid_filenames,
-                                 self.test_filenames, [mode], data_path)[0]
-        self.data = dataset.data
+        root = IWSLT15.get_data(root=root)
+        self.data = self.read_raw_data(root, mode)
        if transform_func is not None:
            self.data = [(transform_func[0](data[0]),
                          transform_func[1](data[1])) for data in self.data]
@@ -274,18 +259,25 @@ def prepare_train_input(insts, pad_id):
                          [inst[1] for inst in insts])
    return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis]

+batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(size_so_far, len(data_source[idx][0]))
+batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len

if __name__ == '__main__':
-    batch_size = 32
+    batch_size = 4096  # 32
    pad_id = 2

    transform_func = IWSLT15.get_default_transform_func()
    train_dataset = IWSLT15(transform_func=transform_func)

    key = (lambda x, data_source: len(data_source[x][0]))

    train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(
-            batch_size=batch_size, drop_last=True).shard()
+            batch_size=batch_size,
+            drop_last=True,
+            batch_size_fn=batch_size_fn,
+            key=batch_key).shard()

    train_loader = paddle.io.DataLoader(
        train_dataset,
@@ -293,6 +285,7 @@ if __name__ == '__main__':
        collate_fn=partial(
            prepare_train_input, pad_id=pad_id))

-    for data in train_loader:
-        print(data)
-        break
+    for i, data in enumerate(train_loader):
+        print(data[1])
+        print(paddle.max(data[1]) * len(data[1]))
+        print(len(data[1]))
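The prints at the end are a sanity check on the token-based batching: data[1] is src_length, so paddle.max(data[1]) * len(data[1]) approximates the padded source token count of the batch, which the sampler keeps at or below batch_size = 4096, while len(data[1]) is the number of sentences. A hypothetical worked case:

    batch_size = 4096                        # budget of padded source tokens
    max_src_len, n_sentences = 25, 163       # hypothetical batch statistics
    print(max_src_len * n_sentences)         # 4075 <= 4096, so the batch is allowed
    print(max_src_len * (n_sentences + 1))   # 4100 > 4096, so the sampler closes the batch first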