From 09d106e10df81acb2c81653bb246ee5e04a68263 Mon Sep 17 00:00:00 2001
From: chengduozh
Date: Sat, 15 Jun 2019 00:07:52 +0800
Subject: [PATCH] support multi process reader for bert

---
 BERT/reader/cls.py | 35 ++++++++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/BERT/reader/cls.py b/BERT/reader/cls.py
index 767d817..2c9bad9 100644
--- a/BERT/reader/cls.py
+++ b/BERT/reader/cls.py
@@ -18,7 +18,7 @@ import csv
 import numpy as np
 import tokenization
 from batching import prepare_batch_data
-
+import functools
 
 class DataProcessor(object):
     """Base class for data converters for sequence classification data sets."""
@@ -178,17 +178,38 @@ class DataProcessor(object):
                 yield batch, total_token_num
 
         def wrapper():
-            for batch_data, total_token_num in batch_reader(
-                    instance_reader, batch_size, self.in_tokens):
-                batch_data = self.generate_batch_data(
-                    batch_data,
-                    total_token_num,
+            trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
+            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + 1
+            if trainers_num > 1:
+                print("start data reader (trainers_num: {}, trainer_id: {})".format(
+                    trainers_num, trainer_id-1))
+            get_prepared_batch_input = functools.partial(
+                self.generate_batch_data,
                     voc_size=-1,
                     mask_id=-1,
                     return_input_mask=True,
                     return_max_len=False,
                     return_num_token=False)
-                yield batch_data
+
+            train_data, train_token_num, idx = None, None, 1
+            for batch_data, total_token_num in batch_reader(
+                    instance_reader, batch_size, self.in_tokens):
+                if trainers_num > 1:
+                    if idx < trainers_num:
+                        if idx == trainer_id:
+                            train_data, train_token_num = batch_data, total_token_num
+                        idx += 1
+                    else:
+                        if idx == trainer_id:
+                            train_data, train_token_num = batch_data, total_token_num
+                        assert train_data is not None, "train data should not be None."
+                        assert train_token_num is not None, "train data should not be None."
+                        batch_data = get_prepared_batch_input(train_data, train_token_num)
+                        yield batch_data
+                        train_data, train_token_num, idx = None, None, 1
+                else:
+                    batch_data = get_prepared_batch_input(batch_data, total_token_num)
+                    yield batch_data
 
         return wrapper
 
-- 
GitLab
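
For reference, a minimal standalone sketch of the batch-sharding scheme that wrapper() implements in this patch is below. The helper name shard_batches and its signature are illustrative, not part of the BERT reader; the patch keeps a 1-based trainer_id internally, while the sketch uses the 0-based PADDLE_TRAINER_ID directly. Like the patch, it only emits a batch once a full window of trainers_num batches has been read, so a trailing incomplete window at the end of the data is dropped.

    import os

    def shard_batches(batch_reader, trainers_num=None, trainer_id=None):
        """Yield only the batches that belong to this trainer (illustrative helper)."""
        if trainers_num is None:
            trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
        if trainer_id is None:
            trainer_id = int(os.getenv('PADDLE_TRAINER_ID', 0))

        kept = None
        for i, batch in enumerate(batch_reader()):
            pos = i % trainers_num            # position inside the current window of trainers_num batches
            if pos == trainer_id:
                kept = batch                  # this batch belongs to this trainer
            if pos == trainers_num - 1:       # window complete: emit the kept batch and reset
                assert kept is not None, "each trainer keeps exactly one batch per window"
                yield kept
                kept = None

With PADDLE_TRAINERS_NUM=2, trainer 0 yields batches 0, 2, 4, ... and trainer 1 yields batches 1, 3, 5, ..., which is the same round-robin split the patched wrapper() performs before calling generate_batch_data on the kept batch.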