diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_utils.py
index 75ad7ae077fd8e1f9952561fab3743172ce859fe..6d18a34c8a8f5854b0b5106b1e1e38483f2d0109 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_utils.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/bert_utils.py
@@ -260,51 +260,48 @@ class DataReader(object):
             yield token_ids, sent_ids, pos_ids, label
 
     def data_generator(self):
-        def wrapper():
-            def reader():
-                for epoch in range(self.epoch):
-                    self.current_epoch = epoch + 1
-                    sample_generator = self.build_fake_data()
-                    for sample in sample_generator:
-                        if sample is None:
-                            continue
-                        yield sample
-
-            def batch_reader(reader, batch_size, in_tokens):
-                batch, total_token_num, max_len = [], 0, 0
-                for parsed_line in reader():
-                    token_ids, sent_ids, pos_ids, label = parsed_line
-                    max_len = max(max_len, len(token_ids))
-                    if in_tokens:
-                        to_append = (len(batch) + 1) * max_len <= batch_size
-                    else:
-                        to_append = len(batch) < batch_size
-                    if to_append:
-                        batch.append(parsed_line)
-                        total_token_num += len(token_ids)
-                    else:
-                        yield batch, total_token_num
-                        batch, total_token_num, max_len = [parsed_line], len(
-                            token_ids), len(token_ids)
-
-                if len(batch) > 0:
+        def reader():
+            for epoch in range(self.epoch):
+                self.current_epoch = epoch + 1
+                sample_generator = self.build_fake_data()
+                for sample in sample_generator:
+                    if sample is None:
+                        continue
+                    yield sample
+
+        def batch_reader(reader, batch_size, in_tokens):
+            batch, total_token_num, max_len = [], 0, 0
+            for parsed_line in reader():
+                token_ids, sent_ids, pos_ids, label = parsed_line
+                max_len = max(max_len, len(token_ids))
+                if in_tokens:
+                    to_append = (len(batch) + 1) * max_len <= batch_size
+                else:
+                    to_append = len(batch) < batch_size
+                if to_append:
+                    batch.append(parsed_line)
+                    total_token_num += len(token_ids)
+                else:
                     yield batch, total_token_num
-
-            for batch_data, total_token_num in batch_reader(
-                    reader, self.batch_size, self.in_tokens):
-                yield prepare_batch_data(
-                    batch_data,
-                    total_token_num,
-                    voc_size=self.voc_size,
-                    pad_id=self.pad_id,
-                    cls_id=self.cls_id,
-                    sep_id=self.sep_id,
-                    mask_id=self.mask_id,
-                    return_input_mask=True,
-                    return_max_len=False,
-                    return_num_token=False)
-
-        return wrapper
+                    batch, total_token_num, max_len = [parsed_line], len(
+                        token_ids), len(token_ids)
+
+            if len(batch) > 0:
+                yield batch, total_token_num
+
+        for batch_data, total_token_num in batch_reader(reader, self.batch_size,
+                                                        self.in_tokens):
+            yield prepare_batch_data(
+                batch_data,
+                total_token_num,
+                voc_size=self.voc_size,
+                pad_id=self.pad_id,
+                cls_id=self.cls_id,
+                sep_id=self.sep_id,
+                mask_id=self.mask_id,
+                return_input_mask=True,
+                return_max_len=False,
+                return_num_token=False)
 
 
 class ModelHyperParams(object):
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
index 8c7e0d9075ce9a10f0ce9fa98397b7d016581d97..307b9736dfb4986e571347d26fadfd17532e13c7 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py
@@ -17,14 +17,17 @@ import unittest
 import numpy as np
 
 import paddle.fluid as fluid
+from paddle.fluid.dygraph.base import to_variable
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
 
 from bert_dygraph_model import PretrainModelLayer
 from bert_utils import get_bert_config, get_feed_data_reader
 
 program_translator = ProgramTranslator()
+place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
+)
+SEED = 2020
 
 STEP_NUM = 10
 PRINT_STEP = 2
@@ -35,19 +38,16 @@ def train(bert_config, data_reader):
         fluid.default_main_program().random_seed = SEED
         fluid.default_startup_program().random_seed = SEED
 
-        data_loader = fluid.io.DataLoader.from_generator(
-            capacity=50, iterable=True)
-        data_loader.set_batch_generator(
-            data_reader.data_generator(), places=place)
-
         bert = PretrainModelLayer(
             config=bert_config, weight_sharing=False, use_fp16=False)
 
         optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
         step_idx = 0
        speed_list = []
-        for input_data in data_loader():
+        for input_data in data_reader.data_generator():
+            input_data = [to_variable(ele) for ele in input_data]
             src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data
+
             next_sent_acc, mask_lm_loss, total_loss = bert(
                 src_ids=src_ids,
                 position_ids=pos_ids,
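
Note (not part of the patch): the batching rule kept by batch_reader() above is the only non-obvious piece of the refactor. A minimal, standalone sketch of that rule follows, assuming samples are plain lists of token ids; the name batch_samples and the demo values are illustrative only. When in_tokens is True, batch_size is a token budget and a batch is closed once (len(batch) + 1) * max_len would exceed it; otherwise batch_size is a plain example count.

def batch_samples(samples, batch_size, in_tokens):
    # Sketch of the batch_reader() logic above, on bare token-id lists.
    batch, total_token_num, max_len = [], 0, 0
    for token_ids in samples:
        # max_len tracks the longest sample seen so far, i.e. the padded width.
        max_len = max(max_len, len(token_ids))
        if in_tokens:
            # Token budget: the padded size of the would-be batch must fit.
            to_append = (len(batch) + 1) * max_len <= batch_size
        else:
            # Example-count budget.
            to_append = len(batch) < batch_size
        if to_append:
            batch.append(token_ids)
            total_token_num += len(token_ids)
        else:
            # Emit the full batch; the incoming sample starts the next one.
            yield batch, total_token_num
            batch, total_token_num, max_len = [token_ids], len(token_ids), len(
                token_ids)
    if batch:
        yield batch, total_token_num


if __name__ == '__main__':
    # Five fake samples of lengths 3, 5, 2, 7, 4 under a 10-token budget.
    samples = [[1] * n for n in (3, 5, 2, 7, 4)]
    for batch, num_tokens in batch_samples(samples, batch_size=10, in_tokens=True):
        print(len(batch), num_tokens)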