提交 d45d17fb 编写于 作者: C chengduozh

Support multi-process training for BERT

上级 7803d896
...@@ -123,7 +123,8 @@ class DataProcessor(object): ...@@ -123,7 +123,8 @@ class DataProcessor(object):
phase='train', phase='train',
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=True): shuffle=True,
shuffle_seed=None):
""" """
Generate data for train, dev or test. Generate data for train, dev or test.
...@@ -149,6 +150,8 @@ class DataProcessor(object): ...@@ -149,6 +150,8 @@ class DataProcessor(object):
def instance_reader(): def instance_reader():
for epoch_index in range(epoch): for epoch_index in range(epoch):
if shuffle: if shuffle:
if shuffle_seed is not None:
np.random.seed(shuffle_seed)
np.random.shuffle(examples) np.random.shuffle(examples)
if phase == 'train': if phase == 'train':
self.current_train_epoch = epoch_index self.current_train_epoch = epoch_index
......
...@@ -159,15 +159,17 @@ def main(args): ...@@ -159,15 +159,17 @@ def main(args):
train_program.random_seed = args.random_seed train_program.random_seed = args.random_seed
if args.do_train: if args.do_train:
# NOTE: do not shuffle dataset when using multi-process training. # NOTE: If num_trainers > 1, the shuffle_seed must be set, because
if num_trainers > 1: # the order of batch data generated by reader
args.shuffle = False # must be the same in the respective processes.
shuffle_seed = 1 if num_trainers > 1 else None
train_data_generator = processor.data_generator( train_data_generator = processor.data_generator(
batch_size=args.batch_size, batch_size=args.batch_size,
phase='train', phase='train',
epoch=args.epoch, epoch=args.epoch,
dev_count=dev_count, dev_count=dev_count,
shuffle=args.shuffle) shuffle=args.shuffle,
shuffle_seed=shuffle_seed)
num_train_examples = processor.get_num_examples(phase='train') num_train_examples = processor.get_num_examples(phase='train')
...@@ -268,6 +270,7 @@ def main(args): ...@@ -268,6 +270,7 @@ def main(args):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
if args.use_cuda and num_trainers > 1: if args.use_cuda and num_trainers > 1:
assert shuffle_seed is not None
dist_utils.prepare_for_multi_process(exe, build_strategy, train_program) dist_utils.prepare_for_multi_process(exe, build_strategy, train_program)
train_data_generator = fluid.contrib.reader.distributed_batch_reader( train_data_generator = fluid.contrib.reader.distributed_batch_reader(
train_data_generator) train_data_generator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册