未验证 提交 c210e7ed 编写于 作者: L liu zhengxi 提交者: GitHub

update reader (#5133)

* cherry-pick #5129 
上级 63b738e3
......@@ -66,20 +66,17 @@ def create_data_loader(args):
  min_max_filer, max_len=args.max_length))
  sampler = SamplerHelper(dataset)
- src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
  if args.sort_type == SortType.GLOBAL:
- buffer_size = -1
+ src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
  trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
  # Sort twice
- sampler = sampler.sort(
- key=trg_key, buffer_size=buffer_size).sort(
- key=src_key, buffer_size=buffer_size)
+ sampler = sampler.sort(key=trg_key).sort(key=src_key)
  else:
  if args.shuffle:
  sampler = sampler.shuffle(seed=shuffle_seed)
+ max_key = (lambda x, data_source: max(len(data_source[x][0]), len(data_source[x][1])) + 1)
  if args.sort_type == SortType.POOL:
- buffer_size = args.pool_size
- sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
+ sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)
  batch_sampler = sampler.batch(
  batch_size=args.batch_size,
......@@ -87,12 +84,12 @@ def create_data_loader(args):
  batch_size_fn=_max_token_fn,
  key=_key)
- if m == "train":
- batch_sampler = batch_sampler.shard()
  if args.shuffle_batch:
  batch_sampler.shuffle(seed=shuffle_seed)
+ if m == "train":
+ batch_sampler = batch_sampler.shard()
  data_loader = DataLoader(
  dataset=dataset,
  batch_sampler=batch_sampler,
......
......@@ -66,20 +66,17 @@ def create_data_loader(args):
  min_max_filer, max_len=args.max_length))
  sampler = SamplerHelper(dataset)
- src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
  if args.sort_type == SortType.GLOBAL:
- buffer_size = -1
+ src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
  trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
  # Sort twice
- sampler = sampler.sort(
- key=trg_key, buffer_size=buffer_size).sort(
- key=src_key, buffer_size=buffer_size)
+ sampler = sampler.sort(key=trg_key).sort(key=src_key)
  else:
  if args.shuffle:
  sampler = sampler.shuffle(seed=shuffle_seed)
+ max_key = (lambda x, data_source: max(len(data_source[x][0]), len(data_source[x][1])) + 1)
  if args.sort_type == SortType.POOL:
- buffer_size = args.pool_size
- sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
+ sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)
  batch_sampler = sampler.batch(
  batch_size=args.batch_size,
......@@ -87,12 +84,12 @@ def create_data_loader(args):
  batch_size_fn=_max_token_fn,
  key=_key)
- if m == "train":
- batch_sampler = batch_sampler.shard()
  if args.shuffle_batch:
  batch_sampler.shuffle(seed=shuffle_seed)
+ if m == "train":
+ batch_sampler = batch_sampler.shard()
  data_loader = DataLoader(
  dataset=dataset,
  batch_sampler=batch_sampler,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册