提交 d45d17fb 编写于 作者: C chengduozh

support multi-process for bert

上级 7803d896
......@@ -123,7 +123,8 @@ class DataProcessor(object):
phase='train',
epoch=1,
dev_count=1,
shuffle=True):
shuffle=True,
shuffle_seed=None):
"""
Generate data for train, dev or test.
......@@ -149,6 +150,8 @@ class DataProcessor(object):
def instance_reader():
for epoch_index in range(epoch):
if shuffle:
if shuffle_seed is not None:
np.random.seed(shuffle_seed)
np.random.shuffle(examples)
if phase == 'train':
self.current_train_epoch = epoch_index
......
......@@ -159,15 +159,17 @@ def main(args):
train_program.random_seed = args.random_seed
if args.do_train:
# NOTE: do not shuffle dataset when using multi-process training.
if num_trainers > 1:
args.shuffle = False
# NOTE: If num_trainers > 1, the shuffle_seed must be set, because
# the order of batch data generated by reader
# must be the same in the respective processes.
shuffle_seed = 1 if num_trainers > 1 else None
train_data_generator = processor.data_generator(
batch_size=args.batch_size,
phase='train',
epoch=args.epoch,
dev_count=dev_count,
shuffle=args.shuffle)
shuffle=args.shuffle,
shuffle_seed=shuffle_seed)
num_train_examples = processor.get_num_examples(phase='train')
......@@ -268,6 +270,7 @@ def main(args):
build_strategy = fluid.BuildStrategy()
if args.use_cuda and num_trainers > 1:
assert shuffle_seed is not None
dist_utils.prepare_for_multi_process(exe, build_strategy, train_program)
train_data_generator = fluid.contrib.reader.distributed_batch_reader(
train_data_generator)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册