diff --git a/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml b/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml index 05a6520110c6fdcde776935ea4bd31d539d644ed..fa321f16055e17ec31458fdd1d3d956b4f97338d 100644 --- a/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml +++ b/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml @@ -27,6 +27,12 @@ pool_size: 200000 sort_type: "global" batch_size: 4096 infer_batch_size: 16 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. +shuffle_seed: 128 # Hyparams for training: # The number of epoches for training diff --git a/PaddleNLP/benchmark/transformer/reader.py b/PaddleNLP/benchmark/transformer/reader.py index 38fcda422fa12504b63808e6791fa236f5ef098e..9e3f86ee0ef91fab22ab2b47bbd60a5563d7b7bb 100644 --- a/PaddleNLP/benchmark/transformer/reader.py +++ b/PaddleNLP/benchmark/transformer/reader.py @@ -43,6 +43,12 @@ def create_data_loader(args): mode=m, transform_func=transform_func) for m in ["train", "dev"] ] + if args.shuffle or args.shuffle_batch: + if args.shuffle_seed == "None" or args.shuffle_seed is None: + shuffle_seed = 0 + else: + shuffle_seed = args.shuffle_seed + def _max_token_fn(current_idx, current_batch_size, tokens_sofar, data_source): return max(tokens_sofar, @@ -69,7 +75,8 @@ def create_data_loader(args): key=trg_key, buffer_size=buffer_size).sort( key=src_key, buffer_size=buffer_size) else: - sampler = sampler.shuffle() + if args.shuffle: + sampler = sampler.shuffle(seed=shuffle_seed) if args.sort_type == SortType.POOL: buffer_size = args.pool_size sampler = sampler.sort(key=src_key, buffer_size=buffer_size) @@ -83,6 +90,9 @@ def create_data_loader(args): if m == "train": batch_sampler = batch_sampler.shard() + if args.shuffle_batch: + batch_sampler = batch_sampler.shuffle(seed=shuffle_seed) + data_loader = 
DataLoader( dataset=dataset, batch_sampler=batch_sampler, diff --git a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml index 7ea9ebbe7141963a2934825e3c9200a1ec40d3f2..57070dc2966cb00c0443a152d4533ce700291017 100644 --- a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml +++ b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml @@ -27,6 +27,12 @@ pool_size: 200000 sort_type: "global" batch_size: 4096 infer_batch_size: 8 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. +shuffle_seed: 128 # Hyparams for training: # The number of epoches for training diff --git a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml index d458f4c7eb15124c3bb82fb8858528563a32eba2..4cd3b1201f5b2a1c16b7e93c83d0a071078b4913 100644 --- a/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml +++ b/PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml @@ -27,6 +27,12 @@ pool_size: 200000 sort_type: "global" batch_size: 4096 infer_batch_size: 8 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. 
+shuffle_seed: 128 # Hyparams for training: # The number of epoches for training diff --git a/PaddleNLP/examples/machine_translation/transformer/reader.py b/PaddleNLP/examples/machine_translation/transformer/reader.py index 38fcda422fa12504b63808e6791fa236f5ef098e..9e3f86ee0ef91fab22ab2b47bbd60a5563d7b7bb 100644 --- a/PaddleNLP/examples/machine_translation/transformer/reader.py +++ b/PaddleNLP/examples/machine_translation/transformer/reader.py @@ -43,6 +43,12 @@ def create_data_loader(args): mode=m, transform_func=transform_func) for m in ["train", "dev"] ] + if args.shuffle or args.shuffle_batch: + if args.shuffle_seed == "None" or args.shuffle_seed is None: + shuffle_seed = 0 + else: + shuffle_seed = args.shuffle_seed + def _max_token_fn(current_idx, current_batch_size, tokens_sofar, data_source): return max(tokens_sofar, @@ -69,7 +75,8 @@ def create_data_loader(args): key=trg_key, buffer_size=buffer_size).sort( key=src_key, buffer_size=buffer_size) else: - sampler = sampler.shuffle() + if args.shuffle: + sampler = sampler.shuffle(seed=shuffle_seed) if args.sort_type == SortType.POOL: buffer_size = args.pool_size sampler = sampler.sort(key=src_key, buffer_size=buffer_size) @@ -83,6 +90,9 @@ def create_data_loader(args): if m == "train": batch_sampler = batch_sampler.shard() + if args.shuffle_batch: + batch_sampler = batch_sampler.shuffle(seed=shuffle_seed) + data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler,