未验证 提交 e251df83 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2static] Fix random bug in bert unittest (#24727)

* Revert "Fix test_bert on GPU (#24692)"

This reverts commit 62222bf4.

* fix random bug in bert unittest, test=develop
上级 de8b4f42
...@@ -49,10 +49,13 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3): ...@@ -49,10 +49,13 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
max_len = max([len(sent) for sent in batch_tokens]) max_len = max([len(sent) for sent in batch_tokens])
mask_label = [] mask_label = []
mask_pos = [] mask_pos = []
np.random.seed(SEED) # NOTE: numpy random is not thread-safe, for async DataLoader,
prob_mask = np.random.rand(total_token_num) # using np.random.seed() directly is risky, using RandomState
# class is a better way
self_random = np.random.RandomState(SEED)
prob_mask = self_random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1] # Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num) replace_ids = self_random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0 pre_sent_len = 0
prob_index = 0 prob_index = 0
for sent_index, sent in enumerate(batch_tokens): for sent_index, sent in enumerate(batch_tokens):
...@@ -85,7 +88,9 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3): ...@@ -85,7 +88,9 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
# ensure at least mask one word in a sentence # ensure at least mask one word in a sentence
while not mask_flag: while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1)) token_index = int(
self_random.randint(
1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS: if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index]) mask_label.append(sent[token_index])
sent[token_index] = MASK sent[token_index] = MASK
...@@ -244,13 +249,16 @@ class DataReader(object): ...@@ -244,13 +249,16 @@ class DataReader(object):
def build_fake_data(self): def build_fake_data(self):
for _ in range(1000000): for _ in range(1000000):
random.seed(SEED) # NOTE: python random has bug in python2,
sent0_len = random.randint(50, 100) # we should avoid using random module,
sent1_len = random.randint(50, 100) # please using numpy.random
self_random = np.random.RandomState(SEED)
sent0_len = self_random.randint(50, 100)
sent1_len = self_random.randint(50, 100)
token_ids = [1] \ token_ids = [1] \
+ [random.randint(0, 10000) for i in range(sent0_len-1)] \ + [self_random.randint(0, 10000) for i in range(sent0_len-1)] \
+ [random.randint(0, 10000) for i in range(sent1_len-1)] \ + [self_random.randint(0, 10000) for i in range(sent1_len-1)] \
+ [2] + [2]
sent_ids = [0 for i in range(sent0_len) sent_ids = [0 for i in range(sent0_len)
...@@ -260,48 +268,51 @@ class DataReader(object): ...@@ -260,48 +268,51 @@ class DataReader(object):
yield token_ids, sent_ids, pos_ids, label yield token_ids, sent_ids, pos_ids, label
def data_generator(self): def data_generator(self):
def reader(): def wrapper():
for epoch in range(self.epoch): def reader():
self.current_epoch = epoch + 1 for epoch in range(self.epoch):
sample_generator = self.build_fake_data() self.current_epoch = epoch + 1
for sample in sample_generator: sample_generator = self.build_fake_data()
if sample is None: for sample in sample_generator:
continue if sample is None:
yield sample continue
yield sample
def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0 def batch_reader(reader, batch_size, in_tokens):
for parsed_line in reader(): batch, total_token_num, max_len = [], 0, 0
token_ids, sent_ids, pos_ids, label = parsed_line for parsed_line in reader():
max_len = max(max_len, len(token_ids)) token_ids, sent_ids, pos_ids, label = parsed_line
if in_tokens: max_len = max(max_len, len(token_ids))
to_append = (len(batch) + 1) * max_len <= batch_size if in_tokens:
else: to_append = (len(batch) + 1) * max_len <= batch_size
to_append = len(batch) < batch_size else:
if to_append: to_append = len(batch) < batch_size
batch.append(parsed_line) if to_append:
total_token_num += len(token_ids) batch.append(parsed_line)
else: total_token_num += len(token_ids)
else:
yield batch, total_token_num
batch, total_token_num, max_len = [parsed_line], len(
token_ids), len(token_ids)
if len(batch) > 0:
yield batch, total_token_num yield batch, total_token_num
batch, total_token_num, max_len = [parsed_line], len(
token_ids), len(token_ids) for batch_data, total_token_num in batch_reader(
reader, self.batch_size, self.in_tokens):
if len(batch) > 0: yield prepare_batch_data(
yield batch, total_token_num batch_data,
total_token_num,
for batch_data, total_token_num in batch_reader(reader, self.batch_size, voc_size=self.voc_size,
self.in_tokens): pad_id=self.pad_id,
yield prepare_batch_data( cls_id=self.cls_id,
batch_data, sep_id=self.sep_id,
total_token_num, mask_id=self.mask_id,
voc_size=self.voc_size, return_input_mask=True,
pad_id=self.pad_id, return_max_len=False,
cls_id=self.cls_id, return_num_token=False)
sep_id=self.sep_id,
mask_id=self.mask_id, return wrapper
return_input_mask=True,
return_max_len=False,
return_num_token=False)
class ModelHyperParams(object): class ModelHyperParams(object):
......
...@@ -17,17 +17,14 @@ import unittest ...@@ -17,17 +17,14 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from bert_dygraph_model import PretrainModelLayer from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader from bert_utils import get_bert_config, get_feed_data_reader
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
) )
SEED = 2020 SEED = 2020
STEP_NUM = 10 STEP_NUM = 10
PRINT_STEP = 2 PRINT_STEP = 2
...@@ -38,16 +35,19 @@ def train(bert_config, data_reader): ...@@ -38,16 +35,19 @@ def train(bert_config, data_reader):
fluid.default_main_program().random_seed = SEED fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED fluid.default_startup_program().random_seed = SEED
data_loader = fluid.io.DataLoader.from_generator(
capacity=50, iterable=True)
data_loader.set_batch_generator(
data_reader.data_generator(), places=place)
bert = PretrainModelLayer( bert = PretrainModelLayer(
config=bert_config, weight_sharing=False, use_fp16=False) config=bert_config, weight_sharing=False, use_fp16=False)
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters()) optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
step_idx = 0 step_idx = 0
speed_list = [] speed_list = []
for input_data in data_reader.data_generator(): for input_data in data_loader():
input_data = [to_variable(ele) for ele in input_data]
src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data
next_sent_acc, mask_lm_loss, total_loss = bert( next_sent_acc, mask_lm_loss, total_loss = bert(
src_ids=src_ids, src_ids=src_ids,
position_ids=pos_ids, position_ids=pos_ids,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册