Unverified · Commit c210e7ed, authored by liu zhengxi, committed by GitHub

update reader (#5133)

* cherry-pick #5129 
Parent: 63b738e3
@@ -66,20 +66,17 @@ def create_data_loader(args):
             min_max_filer, max_len=args.max_length))
     sampler = SamplerHelper(dataset)
 
-    src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
     if args.sort_type == SortType.GLOBAL:
-        buffer_size = -1
+        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
         trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
         # Sort twice
-        sampler = sampler.sort(
-            key=trg_key, buffer_size=buffer_size).sort(
-                key=src_key, buffer_size=buffer_size)
+        sampler = sampler.sort(key=trg_key).sort(key=src_key)
     else:
         if args.shuffle:
             sampler = sampler.shuffle(seed=shuffle_seed)
+        max_key = (lambda x, data_source: max(len(data_source[x][0]), len(data_source[x][1])) + 1)
         if args.sort_type == SortType.POOL:
-            buffer_size = args.pool_size
-            sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
+            sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)
 
     batch_sampler = sampler.batch(
         batch_size=args.batch_size,
@@ -87,12 +84,12 @@ def create_data_loader(args):
         batch_size_fn=_max_token_fn,
         key=_key)
 
-    if m == "train":
-        batch_sampler = batch_sampler.shard()
-
     if args.shuffle_batch:
         batch_sampler.shuffle(seed=shuffle_seed)
 
+    if m == "train":
+        batch_sampler = batch_sampler.shard()
+
     data_loader = DataLoader(
         dataset=dataset,
         batch_sampler=batch_sampler,
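For reference, a minimal runnable sketch (plain Python, no PaddleNLP dependency; the toy pairs data and the pool_sorted_indices helper are illustrative only, not part of this patch) of the pool sort the new max_key introduces: after an optional shuffle, each pool of pool_size examples is ordered by the longer of the source and target lengths plus one, so examples that land in the same batch pad to similar lengths.

# Illustration of sorting each shuffled pool by max(src_len, trg_len) + 1,
# the same keying idea used by sampler.sort(key=max_key, buffer_size=args.pool_size).
import random

# (source tokens, target tokens) pairs; only the lengths matter here.
pairs = [(["a"] * s, ["b"] * t) for s, t in
         [(3, 9), (7, 2), (5, 5), (12, 4), (2, 2), (8, 11), (4, 6), (10, 3)]]

max_key = lambda i, data_source: max(len(data_source[i][0]),
                                     len(data_source[i][1])) + 1

def pool_sorted_indices(data_source, pool_size, seed=0):
    indices = list(range(len(data_source)))
    random.Random(seed).shuffle(indices)          # shuffle first
    for start in range(0, len(indices), pool_size):
        pool = indices[start:start + pool_size]   # then sort within each pool
        pool.sort(key=lambda i: max_key(i, data_source))
        yield from pool

print(list(pool_sorted_indices(pairs, pool_size=4)))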