Unverified commit e251df83, authored by Chen Weihang, committed by GitHub

[Dy2static] Fix random bug in bert unittest (#24727)

* Revert "Fix test_bert on GPU (#24692)"

This reverts commit 62222bf4.

* fix random bug in bert unittest, test=develop
Parent: de8b4f42
@@ -49,10 +49,13 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
     max_len = max([len(sent) for sent in batch_tokens])
     mask_label = []
     mask_pos = []
-    np.random.seed(SEED)
-    prob_mask = np.random.rand(total_token_num)
+    # NOTE: numpy random is not thread-safe; with an async DataLoader,
+    # calling np.random.seed() directly is risky. Using a RandomState
+    # instance is safer.
+    self_random = np.random.RandomState(SEED)
+    prob_mask = self_random.rand(total_token_num)
     # Note: the first token is [CLS], so low=1
-    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
+    replace_ids = self_random.randint(1, high=vocab_size, size=total_token_num)
     pre_sent_len = 0
     prob_index = 0
     for sent_index, sent in enumerate(batch_tokens):
@@ -85,7 +88,9 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
         # ensure that at least one word in each sentence is masked
         while not mask_flag:
-            token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
+            token_index = int(
+                self_random.randint(
+                    1, high=len(sent) - 1, size=1))
             if sent[token_index] != SEP and sent[token_index] != CLS:
                 mask_label.append(sent[token_index])
                 sent[token_index] = MASK
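The core of the fix is visible in these two hunks: the global `np.random` functions are replaced with a dedicated `np.random.RandomState`. The global generator is shared process-wide, so reseeding it from an async DataLoader worker thread can interleave draws and break determinism. Below is a minimal sketch (not part of the test) showing that an identically seeded `RandomState` reproduces the global generator's sequence while staying isolated from other callers:

```python
import numpy as np

SEED = 2020  # same constant the test defines

# Reseeding the global generator: any thread that touches np.random
# shares (and can perturb) this one sequence.
np.random.seed(SEED)
global_draws = np.random.rand(3)

# A private RandomState: identically seeded, it yields the same
# sequence, but no other caller can advance or reseed it.
rng = np.random.RandomState(SEED)
private_draws = rng.rand(3)

assert np.allclose(global_draws, private_draws)
```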
@@ -244,13 +249,16 @@ class DataReader(object):
     def build_fake_data(self):
         for _ in range(1000000):
-            random.seed(SEED)
-            sent0_len = random.randint(50, 100)
-            sent1_len = random.randint(50, 100)
+            # NOTE: Python's random module is buggy under Python 2,
+            # so avoid it and use numpy.random instead.
+            self_random = np.random.RandomState(SEED)
+            sent0_len = self_random.randint(50, 100)
+            sent1_len = self_random.randint(50, 100)
             token_ids = [1] \
-                + [random.randint(0, 10000) for i in range(sent0_len-1)] \
-                + [random.randint(0, 10000) for i in range(sent1_len-1)] \
+                + [self_random.randint(0, 10000) for i in range(sent0_len-1)] \
+                + [self_random.randint(0, 10000) for i in range(sent1_len-1)] \
                 + [2]
             sent_ids = [0 for i in range(sent0_len)
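A side effect worth noting: `build_fake_data` constructs a freshly seeded `RandomState` on every loop iteration, so each yielded sample is identical, which is exactly the determinism the unittest needs. A small sketch of that behavior (`fake_samples` is a hypothetical stand-in, not from the test file):

```python
import numpy as np

SEED = 2020

def fake_samples(n):
    # A fresh, identically seeded generator per iteration makes
    # every yielded sample the same across iterations and runs.
    for _ in range(n):
        rng = np.random.RandomState(SEED)
        sent_len = rng.randint(50, 100)
        yield [rng.randint(0, 10000) for _ in range(sent_len)]

first, second = fake_samples(2)
assert first == second  # deterministic across iterations
```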
@@ -260,48 +268,51 @@ class DataReader(object):
             yield token_ids, sent_ids, pos_ids, label

     def data_generator(self):
-        def reader():
-            for epoch in range(self.epoch):
-                self.current_epoch = epoch + 1
-                sample_generator = self.build_fake_data()
-                for sample in sample_generator:
-                    if sample is None:
-                        continue
-                    yield sample
-        def batch_reader(reader, batch_size, in_tokens):
-            batch, total_token_num, max_len = [], 0, 0
-            for parsed_line in reader():
-                token_ids, sent_ids, pos_ids, label = parsed_line
-                max_len = max(max_len, len(token_ids))
-                if in_tokens:
-                    to_append = (len(batch) + 1) * max_len <= batch_size
-                else:
-                    to_append = len(batch) < batch_size
-                if to_append:
-                    batch.append(parsed_line)
-                    total_token_num += len(token_ids)
-                else:
-                    yield batch, total_token_num
-                    batch, total_token_num, max_len = [parsed_line], len(
-                        token_ids), len(token_ids)
-            if len(batch) > 0:
-                yield batch, total_token_num
-        for batch_data, total_token_num in batch_reader(reader, self.batch_size,
-                                                        self.in_tokens):
-            yield prepare_batch_data(
-                batch_data,
-                total_token_num,
-                voc_size=self.voc_size,
-                pad_id=self.pad_id,
-                cls_id=self.cls_id,
-                sep_id=self.sep_id,
-                mask_id=self.mask_id,
-                return_input_mask=True,
-                return_max_len=False,
-                return_num_token=False)
+        def wrapper():
+            def reader():
+                for epoch in range(self.epoch):
+                    self.current_epoch = epoch + 1
+                    sample_generator = self.build_fake_data()
+                    for sample in sample_generator:
+                        if sample is None:
+                            continue
+                        yield sample
+            def batch_reader(reader, batch_size, in_tokens):
+                batch, total_token_num, max_len = [], 0, 0
+                for parsed_line in reader():
+                    token_ids, sent_ids, pos_ids, label = parsed_line
+                    max_len = max(max_len, len(token_ids))
+                    if in_tokens:
+                        to_append = (len(batch) + 1) * max_len <= batch_size
+                    else:
+                        to_append = len(batch) < batch_size
+                    if to_append:
+                        batch.append(parsed_line)
+                        total_token_num += len(token_ids)
+                    else:
+                        yield batch, total_token_num
+                        batch, total_token_num, max_len = [parsed_line], len(
+                            token_ids), len(token_ids)
+                if len(batch) > 0:
+                    yield batch, total_token_num
+            for batch_data, total_token_num in batch_reader(
+                    reader, self.batch_size, self.in_tokens):
+                yield prepare_batch_data(
+                    batch_data,
+                    total_token_num,
+                    voc_size=self.voc_size,
+                    pad_id=self.pad_id,
+                    cls_id=self.cls_id,
+                    sep_id=self.sep_id,
+                    mask_id=self.mask_id,
+                    return_input_mask=True,
+                    return_max_len=False,
+                    return_num_token=False)
+        return wrapper
class ModelHyperParams(object):
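The `data_generator` change above is structural rather than random-related: instead of being a generator itself, the method now returns a `wrapper` function, since `DataLoader.set_batch_generator` (used in the training hunk below) expects a callable it can invoke to obtain a fresh generator for each pass. A minimal sketch of this generator-factory pattern, including the token-count batching that `batch_reader` performs (`make_batch_generator` is illustrative, not from the test):

```python
def make_batch_generator(samples, batch_size, in_tokens=False):
    # Returning the inner function (not calling it) lets the consumer
    # invoke it to get a fresh generator for every epoch/pass.
    def wrapper():
        batch, max_len = [], 0
        for sample in samples:
            max_len = max(max_len, len(sample))
            if in_tokens:
                # batch_size caps padded tokens: batch count * longest sample
                to_append = (len(batch) + 1) * max_len <= batch_size
            else:
                # batch_size caps the number of samples
                to_append = len(batch) < batch_size
            if to_append:
                batch.append(sample)
            else:
                yield batch
                batch, max_len = [sample], len(sample)
        if batch:
            yield batch
    return wrapper

gen = make_batch_generator([[1, 2], [3], [4, 5, 6]], batch_size=2)
for batch in gen():  # note the call: gen() builds a new generator
    print(batch)     # [[1, 2], [3]] then [[4, 5, 6]]
```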
@@ -17,17 +17,14 @@ import unittest
 import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.base import to_variable
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator

 from bert_dygraph_model import PretrainModelLayer
 from bert_utils import get_bert_config, get_feed_data_reader

 program_translator = ProgramTranslator()
 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()

 SEED = 2020
 STEP_NUM = 10
 PRINT_STEP = 2
@@ -38,16 +35,19 @@ def train(bert_config, data_reader):
     fluid.default_main_program().random_seed = SEED
     fluid.default_startup_program().random_seed = SEED
+    data_loader = fluid.io.DataLoader.from_generator(
+        capacity=50, iterable=True)
+    data_loader.set_batch_generator(
+        data_reader.data_generator(), places=place)
     bert = PretrainModelLayer(
         config=bert_config, weight_sharing=False, use_fp16=False)
     optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
     step_idx = 0
     speed_list = []
-    for input_data in data_reader.data_generator():
-        input_data = [to_variable(ele) for ele in input_data]
+    for input_data in data_loader():
         src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data
         next_sent_acc, mask_lm_loss, total_loss = bert(
             src_ids=src_ids,
             position_ids=pos_ids,
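The training-loop change pairs with the `data_generator` rewrite: a `fluid.io.DataLoader` built via `from_generator` consumes the callable and yields framework variables directly, which is why the per-element `to_variable` conversion disappears. A minimal standalone sketch of the pattern under dygraph (the toy `batch_generator` is assumed for illustration; the real test passes `data_reader.data_generator()`):

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()

with fluid.dygraph.guard(place):

    def batch_generator():
        # Each yield is one batch: a list of numpy arrays, which the
        # loader converts to variables on `place`.
        rng = np.random.RandomState(2020)
        for _ in range(3):
            yield [rng.rand(4, 8).astype('float32')]

    data_loader = fluid.io.DataLoader.from_generator(capacity=50, iterable=True)
    data_loader.set_batch_generator(batch_generator, places=place)

    for data in data_loader():
        batch = data[0]  # already a variable; no to_variable() needed
        print(batch.shape)
```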