提交 09d106e1 编写于 作者: C chengduozh

support multi process reader for bert

上级 ad3547c0
......@@ -18,7 +18,7 @@ import csv
import numpy as np
import tokenization
from batching import prepare_batch_data
import functools
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
......@@ -178,17 +178,38 @@ class DataProcessor(object):
yield batch, total_token_num
def wrapper():
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
batch_data = self.generate_batch_data(
batch_data,
total_token_num,
trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + 1
if trainers_num > 1:
print("start data reader (trainers_num: {}, trainer_id: {})".format(
trainers_num, trainer_id-1))
get_prepared_batch_input = functools.partial(
self.generate_batch_data,
voc_size=-1,
mask_id=-1,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
yield batch_data
train_data, train_token_num, idx = None, None, 1
for batch_data, total_token_num in batch_reader(
instance_reader, batch_size, self.in_tokens):
if trainers_num > 1:
if idx < trainers_num:
if idx == trainer_id:
train_data, train_token_num = batch_data, total_token_num
idx += 1
else:
if idx == trainer_id:
train_data, train_token_num = batch_data, total_token_num
assert train_data is not None, "train data should not be None."
assert train_token_num is not None, "train data should not be None."
batch_data = get_prepared_batch_input(train_data, train_token_num)
yield batch_data
train_data, train_token_num, idx = None, None, 1
else:
batch_data = get_prepared_batch_input(batch_data, total_token_num)
yield batch_data
return wrapper
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册