未验证 提交 62222bf4 编写于 作者: L liym27 提交者: GitHub

Fix test_bert on GPU (#24692)

DataLoader makes the data diff even if the data of reader is the same on CUDA place. This PR doesn't use DataLoader to pass the test. we will use DataLoader back after we fix it.
上级 a6ab43aa
......@@ -260,7 +260,6 @@ class DataReader(object):
yield token_ids, sent_ids, pos_ids, label
def data_generator(self):
def wrapper():
def reader():
for epoch in range(self.epoch):
self.current_epoch = epoch + 1
......@@ -290,8 +289,8 @@ class DataReader(object):
if len(batch) > 0:
yield batch, total_token_num
for batch_data, total_token_num in batch_reader(
reader, self.batch_size, self.in_tokens):
for batch_data, total_token_num in batch_reader(reader, self.batch_size,
self.in_tokens):
yield prepare_batch_data(
batch_data,
total_token_num,
......@@ -304,8 +303,6 @@ class DataReader(object):
return_max_len=False,
return_num_token=False)
return wrapper
class ModelHyperParams(object):
generate_neg_sample = False
......
......@@ -17,14 +17,17 @@ import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader
program_translator = ProgramTranslator()
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
SEED = 2020
STEP_NUM = 10
PRINT_STEP = 2
......@@ -35,19 +38,16 @@ def train(bert_config, data_reader):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
data_loader = fluid.io.DataLoader.from_generator(
capacity=50, iterable=True)
data_loader.set_batch_generator(
data_reader.data_generator(), places=place)
bert = PretrainModelLayer(
config=bert_config, weight_sharing=False, use_fp16=False)
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
step_idx = 0
speed_list = []
for input_data in data_loader():
for input_data in data_reader.data_generator():
input_data = [to_variable(ele) for ele in input_data]
src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data
next_sent_acc, mask_lm_loss, total_loss = bert(
src_ids=src_ids,
position_ids=pos_ids,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册