diff --git a/BERT/reader/cls.py b/BERT/reader/cls.py index 2c9bad9ed5cd4a27415ba1c6ef3faeffc35d0b74..767d817ba31fd7df73466919f848d94338e997ff 100644 --- a/BERT/reader/cls.py +++ b/BERT/reader/cls.py @@ -18,7 +18,7 @@ import csv import numpy as np import tokenization from batching import prepare_batch_data -import functools + class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" @@ -178,38 +178,17 @@ class DataProcessor(object): yield batch, total_token_num def wrapper(): - trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + 1 - if trainers_num > 1: - print("start data reader (trainers_num: {}, trainer_id: {})".format( - trainers_num, trainer_id-1)) - get_prepared_batch_input = functools.partial( - self.generate_batch_data, + for batch_data, total_token_num in batch_reader( + instance_reader, batch_size, self.in_tokens): + batch_data = self.generate_batch_data( + batch_data, + total_token_num, voc_size=-1, mask_id=-1, return_input_mask=True, return_max_len=False, return_num_token=False) - - train_data, train_token_num, idx = None, None, 1 - for batch_data, total_token_num in batch_reader( - instance_reader, batch_size, self.in_tokens): - if trainers_num > 1: - if idx < trainers_num: - if idx == trainer_id: - train_data, train_token_num = batch_data, total_token_num - idx += 1 - else: - if idx == trainer_id: - train_data, train_token_num = batch_data, total_token_num - assert train_data is not None, "train data should not be None." - assert train_token_num is not None, "train data should not be None." - batch_data = get_prepared_batch_input(train_data, train_token_num) - yield batch_data - train_data, train_token_num, idx = None, None, 1 - else: - batch_data = get_prepared_batch_input(batch_data, total_token_num) - yield batch_data + yield batch_data return wrapper diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index f2ca68c4728b9c297c1ba25024dba51ab6f7acc5..792707e66fab360d359a418f2261295b883911c6 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -278,7 +278,11 @@ def main(args): exec_strategy=exec_strategy, build_strategy = build_strategy, main_program=train_program) - + num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) + if num_trainers > 1: + train_data_generator = fluid.contrib.reader.multi_process_reader( + train_data_generator) + train_pyreader.decorate_tensor_provider(train_data_generator) else: train_exe = None