From d45d17fbf91285c983af53139402f44df9a1e443 Mon Sep 17 00:00:00 2001 From: chengduozh Date: Fri, 21 Jun 2019 11:48:54 +0800 Subject: [PATCH] support multi-process for bert --- BERT/reader/cls.py | 5 ++++- BERT/run_classifier.py | 11 +++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/BERT/reader/cls.py b/BERT/reader/cls.py index 2c75479..7448526 100644 --- a/BERT/reader/cls.py +++ b/BERT/reader/cls.py @@ -123,7 +123,8 @@ class DataProcessor(object): phase='train', epoch=1, dev_count=1, - shuffle=True): + shuffle=True, + shuffle_seed=None): """ Generate data for train, dev or test. @@ -149,6 +150,8 @@ class DataProcessor(object): def instance_reader(): for epoch_index in range(epoch): if shuffle: + if shuffle_seed is not None: + np.random.seed(shuffle_seed) np.random.shuffle(examples) if phase == 'train': self.current_train_epoch = epoch_index diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index 889b25f..4c37b33 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -159,15 +159,17 @@ def main(args): train_program.random_seed = args.random_seed if args.do_train: - # NOTE: do not shuffle dataset when using multi-process training. - if num_trainers > 1: - args.shuffle = False + # NOTE: If num_trainers > 1, the shuffle_seed must be set, because + # the order of batch data generated by reader + # must be the same in the respective processes. + shuffle_seed = 1 if num_trainers > 1 else None train_data_generator = processor.data_generator( batch_size=args.batch_size, phase='train', epoch=args.epoch, dev_count=dev_count, - shuffle=args.shuffle) + shuffle=args.shuffle, + shuffle_seed=shuffle_seed) num_train_examples = processor.get_num_examples(phase='train') @@ -268,6 +270,7 @@ def main(args): build_strategy = fluid.BuildStrategy() if args.use_cuda and num_trainers > 1: + assert shuffle_seed is not None dist_utils.prepare_for_multi_process(exe, build_strategy, train_program) train_data_generator = fluid.contrib.reader.distributed_batch_reader( train_data_generator) -- GitLab