提交 0f65b4dd 编写于 作者: C chengduozh

support multi process for bert

上级 09d106e1
...@@ -18,7 +18,7 @@ import csv ...@@ -18,7 +18,7 @@ import csv
import numpy as np import numpy as np
import tokenization import tokenization
from batching import prepare_batch_data from batching import prepare_batch_data
import functools
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for sequence classification data sets."""
...@@ -178,38 +178,17 @@ class DataProcessor(object): ...@@ -178,38 +178,17 @@ class DataProcessor(object):
yield batch, total_token_num yield batch, total_token_num
def wrapper(): def wrapper():
trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) for batch_data, total_token_num in batch_reader(
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + 1 instance_reader, batch_size, self.in_tokens):
if trainers_num > 1: batch_data = self.generate_batch_data(
print("start data reader (trainers_num: {}, trainer_id: {})".format( batch_data,
trainers_num, trainer_id-1)) total_token_num,
get_prepared_batch_input = functools.partial(
self.generate_batch_data,
voc_size=-1, voc_size=-1,
mask_id=-1, mask_id=-1,
return_input_mask=True, return_input_mask=True,
return_max_len=False, return_max_len=False,
return_num_token=False) return_num_token=False)
yield batch_data
train_data, train_token_num, idx = None, None, 1
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
if trainers_num > 1:
if idx < trainers_num:
if idx == trainer_id:
train_data, train_token_num = batch_data, total_token_num
idx += 1
else:
if idx == trainer_id:
train_data, train_token_num = batch_data, total_token_num
assert train_data is not None, "train data should not be None."
assert train_token_num is not None, "train data should not be None."
batch_data = get_prepared_batch_input(train_data, train_token_num)
yield batch_data
train_data, train_token_num, idx = None, None, 1
else:
batch_data = get_prepared_batch_input(batch_data, total_token_num)
yield batch_data
return wrapper return wrapper
......
...@@ -278,7 +278,11 @@ def main(args): ...@@ -278,7 +278,11 @@ def main(args):
exec_strategy=exec_strategy, exec_strategy=exec_strategy,
build_strategy = build_strategy, build_strategy = build_strategy,
main_program=train_program) main_program=train_program)
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers > 1:
train_data_generator = fluid.contrib.reader.multi_process_reader(
train_data_generator)
train_pyreader.decorate_tensor_provider(train_data_generator) train_pyreader.decorate_tensor_provider(train_data_generator)
else: else:
train_exe = None train_exe = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册