Unverified commit d4e1f1a4, authored by liu zhengxi, committed by GitHub

fix reader (#5154)

Parent cbbd8144
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import sys
import os
import io

@@ -21,6 +20,7 @@ from functools import partial
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset
import paddle.distributed as dist
from paddlenlp.data import Pad
from paddlenlp.datasets import WMT14ende
from paddlenlp.data.sampler import SamplerHelper

@@ -47,52 +47,23 @@ def create_data_loader(args):
             mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
 
-    if args.shuffle or args.shuffle_batch:
-        if args.shuffle_seed == "None" or args.shuffle_seed is None:
-            shuffle_seed = 0
-        else:
-            shuffle_seed = args.shuffle_seed
-
-    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
-                      data_source):
-        return max(tokens_sofar,
-                   len(data_source[current_idx][0]) + 1,
-                   len(data_source[current_idx][1]) + 1)
-
-    def _key(size_so_far, minibatch_len):
-        return size_so_far * minibatch_len
-
     data_loaders = [(None)] * 2
     for i, dataset in enumerate(datasets):
         m = dataset.mode
         dataset = dataset.filter(
             partial(
                 min_max_filer, max_len=args.max_length))
-        sampler = SamplerHelper(dataset)
-
-        if args.sort_type == SortType.GLOBAL:
-            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
-            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
-            # Sort twice
-            sampler = sampler.sort(key=trg_key).sort(key=src_key)
-        else:
-            if args.shuffle:
-                sampler = sampler.shuffle(seed=shuffle_seed)
-            max_key = (lambda x, data_source: max(
-                len(data_source[x][0]), len(data_source[x][1])) + 1)
-            if args.sort_type == SortType.POOL:
-                sampler = sampler.sort(
-                    key=max_key, buffer_size=args.pool_size)
-
-        batch_sampler = sampler.batch(
+        batch_sampler = TransformerBatchSampler(
+            dataset=dataset,
             batch_size=args.batch_size,
-            drop_last=False,
-            batch_size_fn=_max_token_fn,
-            key=_key)
-
-        if args.shuffle_batch:
-            batch_sampler.shuffle(seed=shuffle_seed)
-
-        if m == "train":
-            batch_sampler = batch_sampler.shard()
+            pool_size=args.pool_size,
+            sort_type=args.sort_type,
+            shuffle=args.shuffle,
+            shuffle_batch=args.shuffle_batch,
+            use_token_batch=True,
+            max_length=args.max_length,
+            distribute_mode=True if i == 0 else False,
+            world_size=dist.get_world_size(),
+            rank=dist.get_rank())
 
         data_loader = DataLoader(
             dataset=dataset,

@@ -167,3 +138,154 @@ class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"


class SentenceBatchCreator(object):
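    # Collects samples and emits a batch once a fixed number of sentences
    # has accumulated.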
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size

    def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp


class TokenBatchCreator(object):
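    # Collects samples and emits the accumulated batch whenever adding one
    # more sample would push (longest sequence length) * (number of samples)
    # past the batch_size token budget.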
def __init__(self, batch_size):
self._batch = []
self.max_len = -1
self._batch_size = batch_size

    def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self._batch) + 1) > self._batch_size:
result = self._batch
self._batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self._batch.append(info)

    @property
def batch(self):
return self._batch


class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
# Take bos and eos into account
self.min_len = min(lens[0] + 1, lens[1] + 1)
self.max_len = max(lens[0] + 1, lens[1] + 1)
self.src_len = lens[0] + 1
self.trg_len = lens[1] + 1


class TransformerBatchSampler(BatchSampler):
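    # Yields lists of sample indices, batching by token count or by sentence
    # count, with optional global/pool sorting, shuffling and multi-device
    # sharding.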
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
distribute_mode=True,
seed=0,
world_size=1,
rank=0):
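        # Stash every constructor argument on self with a leading underscore,
        # e.g. batch_size becomes self._batch_size.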
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)

        # for multi-device settings
self._distribute_mode = distribute_mode
self._nranks = world_size
self._local_rank = rank
self._sample_infos = []
for i, data in enumerate(self._dataset):
lens = [len(data[0]), len(data[1])]
self._sample_infos.append(SampleInfo(i, lens))

    def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._sample_infos, key=lambda x: x.trg_len)
infos = sorted(infos, key=lambda x: x.src_len)
else:
if self._shuffle:
infos = self._sample_infos
self._random.shuffle(infos)
else:
infos = self._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# To avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)

        batches = []
        if self._use_token_batch:
            batch_creator = TokenBatchCreator(self._batch_size)
        else:
            batch_creator = SentenceBatchCreator(self._batch_size *
                                                 self._nranks)
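        # Token-budget batches are built per device; sentence-count batches
        # are built nranks times larger and split across devices below.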
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        if not self._use_token_batch:
            # When batching by sentence count, to ensure that neighboring
            # batches fed to parallel devices have similar lengths (and thus
            # similar computational cost) after shuffling, the nranks
            # sub-batches are shuffled as a whole and split apart here.
            batches = [[
                batch[self._batch_size * i:self._batch_size * (i + 1)]
                for i in range(self._nranks)
            ] for batch in batches]
            batches = list(itertools.chain.from_iterable(batches))

        self.batch_number = (len(batches) + self._nranks - 1) // self._nranks

        # for multi-device
for batch_id, batch in enumerate(batches):
if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices

        if self._distribute_mode and len(batches) % self._nranks != 0:
if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
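                # (keeps every rank yielding the same number of batches)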
yield batch_indices

    def __len__(self):
        if hasattr(self, "batch_number"):
return self.batch_number
if not self._use_token_batch:
batch_number = (
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
        else:
            # The batch count is unknown beforehand when batching by token
            # count; the true value is set as self.batch_number in __iter__.
            batch_number = sys.maxsize
return batch_number
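
A minimal sketch of driving the new sampler directly, outside create_data_loader, assuming a toy list-style dataset of (src_ids, trg_ids) pairs; the toy data, budget, and variable names below are illustrative and not part of the commit:

    # Each element mimics a tokenized (source, target) pair.
    toy_dataset = [([1, 2, 3], [4, 5]), ([6], [7, 8, 9, 10]), ([11, 12], [13])]
    sampler = TransformerBatchSampler(
        dataset=toy_dataset,
        batch_size=12,  # acts as a token budget because use_token_batch=True
        use_token_batch=True,
        sort_type=SortType.NONE,
        world_size=1,
        rank=0)
    for batch_indices in sampler:
        print(batch_indices)  # prints [0, 1], then [2] with this budget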