From ea70ced2cfac377dc67ff200e07e23aa455dc420 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Tue, 22 Dec 2020 12:38:14 +0800 Subject: [PATCH] update reader (#5129) * shuffle batch * update reader * update benchmark --- PaddleNLP/benchmark/transformer/reader.py | 17 +++++++---------- .../machine_translation/transformer/reader.py | 17 +++++++---------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/PaddleNLP/benchmark/transformer/reader.py b/PaddleNLP/benchmark/transformer/reader.py index 9e3f86ee..9d295a94 100644 --- a/PaddleNLP/benchmark/transformer/reader.py +++ b/PaddleNLP/benchmark/transformer/reader.py @@ -66,20 +66,17 @@ def create_data_loader(args): min_max_filer, max_len=args.max_length)) sampler = SamplerHelper(dataset) - src_key = (lambda x, data_source: len(data_source[x][0]) + 1) if args.sort_type == SortType.GLOBAL: - buffer_size = -1 + src_key = (lambda x, data_source: len(data_source[x][0]) + 1) trg_key = (lambda x, data_source: len(data_source[x][1]) + 1) # Sort twice - sampler = sampler.sort( - key=trg_key, buffer_size=buffer_size).sort( - key=src_key, buffer_size=buffer_size) + sampler = sampler.sort(key=trg_key).sort(key=src_key) else: if args.shuffle: sampler = sampler.shuffle(seed=shuffle_seed) + max_key = (lambda x, data_source: max(len(data_source[x][0]), len(data_source[x][1])) + 1) if args.sort_type == SortType.POOL: - buffer_size = args.pool_size - sampler = sampler.sort(key=src_key, buffer_size=buffer_size) + sampler = sampler.sort(key=max_key, buffer_size=args.pool_size) batch_sampler = sampler.batch( batch_size=args.batch_size, @@ -87,12 +84,12 @@ def create_data_loader(args): batch_size_fn=_max_token_fn, key=_key) - if m == "train": - batch_sampler = batch_sampler.shard() - if args.shuffle_batch: batch_sampler.shuffle(seed=shuffle_seed) + if m == "train": + batch_sampler = batch_sampler.shard() + data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler, diff --git a/PaddleNLP/examples/machine_translation/transformer/reader.py b/PaddleNLP/examples/machine_translation/transformer/reader.py index 9e3f86ee..9d295a94 100644 --- a/PaddleNLP/examples/machine_translation/transformer/reader.py +++ b/PaddleNLP/examples/machine_translation/transformer/reader.py @@ -66,20 +66,17 @@ def create_data_loader(args): min_max_filer, max_len=args.max_length)) sampler = SamplerHelper(dataset) - src_key = (lambda x, data_source: len(data_source[x][0]) + 1) if args.sort_type == SortType.GLOBAL: - buffer_size = -1 + src_key = (lambda x, data_source: len(data_source[x][0]) + 1) trg_key = (lambda x, data_source: len(data_source[x][1]) + 1) # Sort twice - sampler = sampler.sort( - key=trg_key, buffer_size=buffer_size).sort( - key=src_key, buffer_size=buffer_size) + sampler = sampler.sort(key=trg_key).sort(key=src_key) else: if args.shuffle: sampler = sampler.shuffle(seed=shuffle_seed) + max_key = (lambda x, data_source: max(len(data_source[x][0]), len(data_source[x][1])) + 1) if args.sort_type == SortType.POOL: - buffer_size = args.pool_size - sampler = sampler.sort(key=src_key, buffer_size=buffer_size) + sampler = sampler.sort(key=max_key, buffer_size=args.pool_size) batch_sampler = sampler.batch( batch_size=args.batch_size, @@ -87,12 +84,12 @@ def create_data_loader(args): batch_size_fn=_max_token_fn, key=_key) - if m == "train": - batch_sampler = batch_sampler.shard() - if args.shuffle_batch: batch_sampler.shuffle(seed=shuffle_seed) + if m == "train": + batch_sampler = batch_sampler.shard() + data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler, -- GitLab