Unverified commit ca2ecccf authored by LiuChiachi, committed by GitHub

Update SamplerHelper batch, Update TranslationDataset (#4991)

* update SamplerHelper.batch

* update batch_size, delete a useless log

* update SamplerHelper.batch

* update SamplerHelper.batch, change cmp_fn arg to key

* Update TranslationDataset and IWSLT

* fix doc error
Parent 8675c9fa
@@ -170,24 +170,32 @@ class SamplerHelper(object):
return type(self)(self.data_source, _impl)
def batch(self,
batch_size,
drop_last=False,
batch_size_fn=None,
batch_fn=None):
def batch(self, batch_size, drop_last=False, batch_size_fn=None, key=None):
"""
Produces a BatchSampler.
Agrs:
Args:
batch_size (int): Batch size.
drop_last (bool): Whether to drop the last mini batch. Default: False.
batch_size_fn (callable, optional): Return the size of mini batch so far. Default: None.
batch_fn (callable, optional): Transformations to be performed. Default: None.
drop_last (bool): Whether to drop the last mini batch. Default:
False.
batch_size_fn (callable, optional): It accepts four arguments: the
index of a sample in the data source, the current length of the
minibatch, the size of the minibatch so far, and the data source,
and it returns the updated size of the minibatch so far. The
returned value can be anything and is passed as the `size_so_far`
argument of `key`. If None, the length of the mini batch is
returned. Default: None.
key (callable, optional): It accepts the size of the minibatch so
far and the length of the minibatch, and returns the value to be
compared with `batch_size`. If None, only the size of the mini
batch so far is compared with `batch_size`. Default: None.
Returns:
SamplerHelper
"""
_key = lambda size_so_far, minibatch_len: size_so_far
ori_batch_size_fn = batch_size_fn
if batch_size_fn is None:
ori_batch_size_fn = None
batch_size_fn = lambda new, count, sofar, data_source: count
key = _key if key is None else key
def _impl():
data_source = self.data_source
@@ -197,20 +205,22 @@ class SamplerHelper(object):
size_so_far = batch_size_fn(idx,
len(minibatch), size_so_far,
data_source)
if size_so_far == batch_size:
if key(size_so_far, len(minibatch)) == batch_size:
yield minibatch
minibatch, size_so_far = [], 0
elif size_so_far > batch_size:
elif key(size_so_far, len(minibatch)) > batch_size:
if len(minibatch) == 1:
raise ValueError(
"Please increase the value of `batch_size`, or limit the max length of batch."
)
yield minibatch[:-1]
minibatch, size_so_far = minibatch[-1:], batch_size_fn(
idx, 1, 0, data_source)
if minibatch and not drop_last:
yield minibatch
sampler = type(self)(
self.data_source,
_impl) if batch_fn is None else self.apply(batch_fn)
if ori_batch_size_fn is None and batch_fn is None and self.length is not None:
sampler = type(self)(self.data_source, _impl)
if ori_batch_size_fn is None and self.length is not None:
sampler.length = (self.length + int(not drop_last) *
(batch_size - 1)) // batch_size
else:
...
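A quick sketch of how the updated `batch` API can be driven by these two
callables to build token-based buckets. The toy data and the batch size of
32 tokens below are illustrative only, not taken from the commit:

    from paddlenlp.data.sampler import SamplerHelper

    # Hypothetical toy dataset: (source_ids, target_ids) pairs of varying length.
    toy_data = [([0] * n, [0] * n) for n in (5, 7, 9, 11, 30)]

    # Size of a minibatch = length of its longest source sequence so far.
    batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(
        size_so_far, len(data_source[idx][0]))

    # Value compared with batch_size = padded token count of the minibatch.
    batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len

    sampler = SamplerHelper(toy_data).batch(
        batch_size=32,  # here: at most 32 padded source tokens per batch
        drop_last=False,
        batch_size_fn=batch_size_fn,
        key=batch_key)

    for minibatch in sampler:
        print(minibatch)  # e.g. [0, 1, 2], then [3], then [4]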
import os
import io
import collections
import warnings
from functools import partial
import numpy as np
@@ -8,65 +9,12 @@ import paddle
from paddle.utils.download import get_path_from_url
from paddlenlp.data import Vocab, Pad
from paddlenlp.data.sampler import SamplerHelper
DATA_HOME = "/root/.paddlenlp/datasets"
from paddlenlp.utils.env import DATA_HOME
from paddle.dataset.common import md5file
__all__ = ['TranslationDataset', 'IWSLT15']
def read_raw_files(corpus_path):
"""Read raw files, return raw data"""
data = []
(f_mode, f_encoding, endl) = ("r", "utf-8", "\n")
with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus:
for line in f_corpus.readlines():
data.append(line.strip())
return data
def get_raw_data(data_dir, train_filenames, valid_filenames, test_filenames,
data_select):
data_dict = {}
file_select = {
'train': train_filenames,
'dev': valid_filenames,
'test': test_filenames
}
for mode in data_select:
src_filename, tgt_filename = file_select[mode]
src_path = os.path.join(data_dir, src_filename)
tgt_path = os.path.join(data_dir, tgt_filename)
src_data = read_raw_files(src_path)
tgt_data = read_raw_files(tgt_path)
data_dict[mode] = [(src_data[i], tgt_data[i])
for i in range(len(src_data))]
return data_dict
def setup_datasets(train_filenames,
valid_filenames,
test_filenames,
data_select,
root=None):
# Input check
target_select = ('train', 'dev', 'test')
if isinstance(data_select, str):
data_select = (data_select, )
if not set(data_select).issubset(set(target_select)):
raise TypeError(
'A subset of data selection {} is supported but {} is passed in'.
format(target_select, data_select))
raw_data = get_raw_data(root, train_filenames, valid_filenames,
test_filenames, data_select)
datasets = []
for mode in data_select:
datasets.append(TranslationDataset(raw_data[mode]))
return tuple(datasets)
def vocab_func(vocab, unk_token):
def func(tok_iter):
return [
@@ -103,13 +51,12 @@ class TranslationDataset(paddle.io.Dataset):
data (list): Raw data. It is a list of tuples or lists; each sample
contains two elements, source and target.
"""
META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file',
'src_md5', 'tgt_md5'))
SPLITS = {}
URL = None
train_filenames = (None, None)
valid_filenames = (None, None)
test_filenames = (None, None)
src_vocab_filename = None
tgt_vocab_filename = None
dataset_dirname = None
MD5 = None
VOCAB_INFO = None
def __init__(self, data):
self.data = data
@@ -121,7 +68,7 @@ class TranslationDataset(paddle.io.Dataset):
return len(self.data)
@classmethod
def get_data(cls, root=None):
def get_data(cls, mode="train", root=None):
"""
Download dataset if any data file doesn't exist.
Args:
@@ -136,32 +83,37 @@ class TranslationDataset(paddle.io.Dataset):
from paddlenlp.datasets import IWSLT15
data_path = IWSLT15.get_data()
"""
if root is None:
root = os.path.join(DATA_HOME, 'machine_translation')
data_dir = os.path.join(root, cls.dataset_dirname)
if not os.path.exists(root):
os.makedirs(root)
print("IWSLT will be downloaded at ", root)
get_path_from_url(cls.URL, root)
print("Downloaded success......")
else:
filename_list = [
cls.train_filenames[0], cls.train_filenames[1],
cls.valid_filenames[0], cls.valid_filenames[0],
cls.src_vocab_filename, cls.tgt_vocab_filename
]
for filename in filename_list:
file_path = os.path.join(data_dir, filename)
if not os.path.exists(file_path):
print(
"The dataset is incomplete and will be re-downloaded.")
get_path_from_url(cls.URL, root)
print("Downloaded success......")
break
return data_dir
default_root = os.path.join(DATA_HOME, 'machine_translation')
src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
mode]
filename_list = [
src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1]
]
fullname_list = []
for filename in filename_list:
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
fullname_list.append(fullname)
data_hash_list = [
src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
]
for i, fullname in enumerate(fullname_list):
if not os.path.exists(fullname) or (
data_hash_list[i] and
not md5file(fullname) == data_hash_list[i]):
if root is not None: # warn only when a user-specified root fails the check
warnings.warn(
'md5 check failed for {}, download {} data to {}'.format(
filename_list[i], cls.__name__, default_root))
path = get_path_from_url(cls.URL, default_root, cls.MD5)
break
return root if root is not None else default_root
@classmethod
def get_vocab(cls, root=None):
def build_vocab(cls, root=None):
"""
Load vocab from vocab files. If the vocab files don't exist, they
will be downloaded.
@@ -176,22 +128,42 @@ class TranslationDataset(paddle.io.Dataset):
Examples:
.. code-block:: python
from paddlenlp.datasets import IWSLT15
(src_vocab, tgt_vocab) = IWSLT15.get_vocab()
(src_vocab, tgt_vocab) = IWSLT15.build_vocab()
"""
data_path = cls.get_data(root)
root = cls.get_data(root=root)
# Get vocab_func
src_file_path = os.path.join(data_path, cls.src_vocab_filename)
tgt_file_path = os.path.join(data_path, cls.tgt_vocab_filename)
src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
src_file_path = os.path.join(root, src_vocab_filename)
tgt_file_path = os.path.join(root, tgt_vocab_filename)
src_vocab = Vocab.load_vocabulary(src_file_path, cls.unk_token,
cls.bos_token, cls.eos_token)
src_vocab = Vocab.load_vocabulary(src_file_path, cls.UNK_TOKEN,
cls.BOS_TOKEN, cls.EOS_TOKEN)
tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.unk_token,
cls.bos_token, cls.eos_token)
tgt_vocab = Vocab.load_vocabulary(tgt_file_path, cls.UNK_TOKEN,
cls.BOS_TOKEN, cls.EOS_TOKEN)
return (src_vocab, tgt_vocab)
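For context, the renamed `build_vocab` can be used directly. A minimal
sketch, assuming the `to_indices` lookup method of `paddlenlp.data.Vocab`;
the example sentence is made up:

    from paddlenlp.datasets import IWSLT15

    # Downloads the dataset if necessary, then loads both vocab files.
    src_vocab, tgt_vocab = IWSLT15.build_vocab()

    # Map a tokenized English sentence to ids; OOV tokens map to <unk>.
    ids = src_vocab.to_indices("machine translation is fun".split())
    print(ids)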
def read_raw_data(self, data_dir, mode):
src_filename, tgt_filename, _, _ = self.SPLITS[mode]
def read_raw_files(corpus_path):
"""Read raw files, return raw data"""
data = []
(f_mode, f_encoding, endl) = ("r", "utf-8", "\n")
with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus:
for line in f_corpus.readlines():
data.append(line.strip())
return data
src_path = os.path.join(data_dir, src_filename)
tgt_path = os.path.join(data_dir, tgt_filename)
src_data = read_raw_files(src_path)
tgt_data = read_raw_files(tgt_path)
data = [(src_data[i], tgt_data[i]) for i in range(len(src_data))]
return data
@classmethod
def get_default_transform_func(cls, root=None):
"""Get default transform function, which transforms raw data to id.
@@ -210,11 +182,11 @@ class TranslationDataset(paddle.io.Dataset):
src_text_vocab_transform = sequential_transforms(src_tokenizer)
tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)
(src_vocab, tgt_vocab) = cls.get_vocab(root)
(src_vocab, tgt_vocab) = cls.build_vocab(root)
src_text_transform = sequential_transforms(
src_text_vocab_transform, vocab_func(src_vocab, cls.unk_token))
src_text_vocab_transform, vocab_func(src_vocab, cls.UNK_TOKEN))
tgt_text_transform = sequential_transforms(
tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.unk_token))
tgt_text_vocab_transform, vocab_func(tgt_vocab, cls.UNK_TOKEN))
return (src_text_transform, tgt_text_transform)
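A hedged sketch of applying the default transforms returned above; the input
sentence is illustrative and the printed output depends on the vocab files:

    from paddlenlp.datasets import IWSLT15

    src_transform, tgt_transform = IWSLT15.get_default_transform_func()

    # Each transform maps a raw sentence string to a list of token ids.
    print(src_transform("hello world"))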
@@ -235,32 +207,45 @@ class IWSLT15(TranslationDataset):
"""
URL = "https://paddlenlp.bj.bcebos.com/datasets/iwslt15.en-vi.tar.gz"
train_filenames = ("train.en", "train.vi")
valid_filenames = ("tst2012.en", "tst2012.vi")
test_filenames = ("tst2013.en", "tst2013.vi")
src_vocab_filename = "vocab.en"
tgt_vocab_filename = "vocab.vi"
unk_token = '<unk>'
bos_token = '<s>'
eos_token = '</s>'
dataset_dirname = "iwslt15.en-vi"
SPLITS = {
'train': TranslationDataset.META_INFO(
os.path.join("iwslt15.en-vi", "train.en"),
os.path.join("iwslt15.en-vi", "train.vi"),
"5b6300f46160ab5a7a995546d2eeb9e6",
"858e884484885af5775068140ae85dab"),
'dev': TranslationDataset.META_INFO(
os.path.join("iwslt15.en-vi", "tst2012.en"),
os.path.join("iwslt15.en-vi", "tst2012.vi"),
"c14a0955ed8b8d6929fdabf4606e3875",
"dddf990faa149e980b11a36fca4a8898"),
'test': TranslationDataset.META_INFO(
os.path.join("iwslt15.en-vi", "tst2013.en"),
os.path.join("iwslt15.en-vi", "tst2013.vi"),
"c41c43cb6d3b122c093ee89608ba62bd",
"a3185b00264620297901b647a4cacf38")
}
VOCAB_INFO = (os.path.join("iwslt15.en-vi", "vocab.en"), os.path.join(
"iwslt15.en-vi", "vocab.vi"), "98b5011e1f579936277a273fd7f4e9b4",
"e8b05f8c26008a798073c619236712b4")
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
MD5 = 'aca22dc3f90962e42916dbb36d8f3e8e'
def __init__(self, mode='train', root=None, transform_func=None):
# Input check
segment_select = ('train', 'dev', 'test')
if mode not in segment_select:
data_select = ('train', 'dev', 'test')
if mode not in data_select:
raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode))
if transform_func is not None:
if len(transform_func) != 2:
raise ValueError("`transform_func` must have length of two for"
"source and target")
"source and target.")
# Download data
data_path = IWSLT15.get_data(root)
dataset = setup_datasets(self.train_filenames, self.valid_filenames,
self.test_filenames, [mode], data_path)[0]
self.data = dataset.data
root = IWSLT15.get_data(root=root)
self.data = self.read_raw_data(root, mode)
if transform_func is not None:
self.data = [(transform_func[0](data[0]),
transform_func[1](data[1])) for data in self.data]
@@ -274,18 +259,25 @@ def prepare_train_input(insts, pad_id):
[inst[1] for inst in insts])
return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis]
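# Running "size" of a minibatch: the length of its longest source sentence.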
batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(size_so_far, len(data_source[idx][0]))
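# Effective batch cost: longest sentence times number of sentences, i.e. the
# token count after padding; this is what gets compared against batch_size.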
batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len
if __name__ == '__main__':
batch_size = 32
batch_size = 4096 # measured in padded tokens under batch_key (was 32 sentences)
pad_id = 2
transform_func = IWSLT15.get_default_transform_func()
train_dataset = IWSLT15(transform_func=transform_func)
key = (lambda x, data_source: len(data_source[x][0]))
train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort(
key=key, buffer_size=batch_size * 20).batch(
batch_size=batch_size, drop_last=True).shard()
batch_size=batch_size,
drop_last=True,
batch_size_fn=batch_size_fn,
key=batch_key).shard()
train_loader = paddle.io.DataLoader(
train_dataset,
@@ -293,6 +285,7 @@ if __name__ == '__main__':
collate_fn=partial(
prepare_train_input, pad_id=pad_id))
for data in train_loader:
print(data)
break
for i, data in enumerate(train_loader):
print(data[1])
print(paddle.max(data[1]) * len(data[1]))
print(len(data[1]))
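With these settings `batch_size` is measured in padded tokens rather than
sentences, which the demo's final loop makes visible. A worked example of
the arithmetic (sentence lengths are made up):

    # Suppose the sorted bucket yields 200 sentences with max source length 20:
    #   batch_key -> 20 * 200 = 4000 <= 4096, so the sampler keeps filling.
    # Adding one sentence of length 21:
    #   batch_key -> 21 * 201 = 4221 >  4096, so the first 200 indices are
    #   yielded and a new minibatch starts from the 201st sentence.
    #
    # The loop above prints paddle.max(data[1]) * len(data[1]), the padded
    # source token count of each batch, which stays at or below 4096.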