Unverified Commit 03f4684d authored by liym27, committed by GitHub

[Dy2Stat] Add test for bert pretraining. (#24350)

* [Dy2Stat] Add test for bert pretraining. 

* Construct fake data. 

* Synchronize the random seed of the program.
Parent 70bc4889
@@ -175,6 +175,11 @@ class ConcreteProgram(object):
        static_func = convert_function_with_cache(dygaph_function)
        main_program, start_up = framework.Program(), framework.Program()
+        # Synchronize the random seed of the program
+        main_program.random_seed = framework.default_main_program().random_seed
+        start_up.random_seed = framework.default_startup_program().random_seed
        with framework.program_guard(main_program, start_up):
            # 1. Adds `fluid.data` layers for input if needed
            inputs = func_spec.to_static_inputs(main_program)
......
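The seed synchronization above is what the new test relies on: the test sets seeds on the default dygraph programs, and the program built for a converted function now inherits them, so random ops (dropout, parameter initialization) behave identically in both modes. A minimal sketch of the pattern, illustrative only and not part of this diff:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Seeds set here are now copied into any program that ConcreteProgram
    # builds for a @declarative function, so a dygraph run and its converted
    # static run see the same randomness.
    fluid.default_main_program().random_seed = 2020
    fluid.default_startup_program().random_seed = 2020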
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding, Layer, Linear
from paddle.fluid.dygraph.jit import declarative
from transformer_dygraph_model import MultiHeadAttention, PrePostProcessLayer
class PositionwiseFeedForwardLayer(Layer):
def __init__(self,
hidden_act,
d_inner_hid,
d_model,
dropout_rate,
param_initializer=None,
name=""):
super(PositionwiseFeedForwardLayer, self).__init__()
self._i2h = Linear(
input_dim=d_model,
output_dim=d_inner_hid,
param_attr=fluid.ParamAttr(
name=name + '_fc_0.w_0', initializer=param_initializer),
bias_attr=name + '_fc_0.b_0',
act=hidden_act)
self._h2o = Linear(
input_dim=d_inner_hid,
output_dim=d_model,
param_attr=fluid.ParamAttr(
name=name + '_fc_1.w_0', initializer=param_initializer),
bias_attr=name + '_fc_1.b_0')
self._dropout_rate = dropout_rate
def forward(self, x):
hidden = self._i2h(x)
if self._dropout_rate:
hidden = fluid.layers.dropout(
hidden, dropout_prob=self._dropout_rate, is_test=False)
out = self._h2o(hidden)
return out
class EncoderSubLayer(Layer):
def __init__(self,
hidden_act,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=""):
super(EncoderSubLayer, self).__init__()
self.name = name
self._preprocess_cmd = preprocess_cmd
self._postprocess_cmd = postprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
self._preprocess_layer = PrePostProcessLayer(
self._preprocess_cmd, d_model, prepostprocess_dropout)
self._multihead_attention_layer = MultiHeadAttention(
d_key, d_value, d_model, n_head, attention_dropout,
param_initializer)
self._postprocess_layer = PrePostProcessLayer(
self._postprocess_cmd, d_model, self._prepostprocess_dropout)
self._preprocess_layer2 = PrePostProcessLayer(
self._preprocess_cmd, d_model, self._prepostprocess_dropout)
self._positionwise_feed_forward = PositionwiseFeedForwardLayer(
hidden_act,
d_inner_hid,
d_model,
relu_dropout,
param_initializer,
name=name + "_ffn")
self._postprocess_layer2 = PrePostProcessLayer(
self._postprocess_cmd, d_model, self._prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
pre_process_multihead = self._preprocess_layer(enc_input)
attn_output = self._multihead_attention_layer(pre_process_multihead,
None, None, attn_bias)
attn_output = self._postprocess_layer(attn_output, enc_input)
pre_process2_output = self._preprocess_layer2(attn_output)
ffd_output = self._positionwise_feed_forward(pre_process2_output)
return self._postprocess_layer2(ffd_output, attn_output)
class EncoderLayer(Layer):
def __init__(self,
hidden_act,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=""):
super(EncoderLayer, self).__init__()
self._preprocess_cmd = preprocess_cmd
self._encoder_sublayers = list()
self._prepostprocess_dropout = prepostprocess_dropout
self._n_layer = n_layer
self._hidden_act = hidden_act
self._preprocess_layer = PrePostProcessLayer(
self._preprocess_cmd, 3, self._prepostprocess_dropout)
for i in range(n_layer):
self._encoder_sublayers.append(
self.add_sublayer(
'esl_%d' % i,
EncoderSubLayer(
hidden_act,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
param_initializer,
name=name + '_layer_' + str(i))))
def forward(self, enc_input, attn_bias):
for i in range(self._n_layer):
enc_output = self._encoder_sublayers[i](enc_input, attn_bias)
enc_input = enc_output
return self._preprocess_layer(enc_output)
class BertModelLayer(Layer):
def __init__(self, config, return_pooled_out=True, use_fp16=False):
super(BertModelLayer, self).__init__()
self._emb_size = config['hidden_size']
self._n_layer = config['num_hidden_layers']
self._n_head = config['num_attention_heads']
self._voc_size = config['vocab_size']
self._max_position_seq_len = config['max_position_embeddings']
self._sent_types = config['type_vocab_size']
self._hidden_act = config['hidden_act']
self._prepostprocess_dropout = config['hidden_dropout_prob']
self._attention_dropout = config['attention_probs_dropout_prob']
self.return_pooled_out = return_pooled_out
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
self._sent_emb_name = "sent_embedding"
self._dtype = "float16" if use_fp16 else "float32"
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range'])
self._src_emb = Embedding(
size=[self._voc_size, self._emb_size],
param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
dtype=self._dtype)
self._pos_emb = Embedding(
size=[self._max_position_seq_len, self._emb_size],
param_attr=fluid.ParamAttr(
name=self._pos_emb_name, initializer=self._param_initializer),
dtype=self._dtype)
self._sent_emb = Embedding(
size=[self._sent_types, self._emb_size],
param_attr=fluid.ParamAttr(
name=self._sent_emb_name, initializer=self._param_initializer),
dtype=self._dtype)
self.pooled_fc = Linear(
input_dim=self._emb_size,
output_dim=self._emb_size,
param_attr=fluid.ParamAttr(
name="pooled_fc.w_0", initializer=self._param_initializer),
bias_attr="pooled_fc.b_0",
act="tanh")
self.pre_process_layer = PrePostProcessLayer(
"nd", self._emb_size, self._prepostprocess_dropout)
self._encoder = EncoderLayer(
hidden_act=self._hidden_act,
n_layer=self._n_layer,
n_head=self._n_head,
d_key=self._emb_size // self._n_head,
d_value=self._emb_size // self._n_head,
d_model=self._emb_size,
d_inner_hid=self._emb_size * 4,
prepostprocess_dropout=self._prepostprocess_dropout,
attention_dropout=self._attention_dropout,
relu_dropout=0,
preprocess_cmd="",
postprocess_cmd="dan",
param_initializer=self._param_initializer)
def forward(self, src_ids, position_ids, sentence_ids, input_mask):
src_emb = self._src_emb(src_ids)
pos_emb = self._pos_emb(position_ids)
sent_emb = self._sent_emb(sentence_ids)
emb_out = src_emb + pos_emb
emb_out = emb_out + sent_emb
emb_out = self.pre_process_layer(emb_out)
self_attn_mask = fluid.layers.matmul(
x=input_mask, y=input_mask, transpose_y=True)
self_attn_mask = fluid.layers.scale(
x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = fluid.layers.stack(
x=[self_attn_mask] * self._n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
enc_output = self._encoder(emb_out, n_head_self_attn_mask)
if not self.return_pooled_out:
return enc_output
next_sent_feat = fluid.layers.slice(
input=enc_output, axes=[1], starts=[0], ends=[1])
next_sent_feat = self.pooled_fc(next_sent_feat)
next_sent_feat = fluid.layers.reshape(
next_sent_feat, shape=[-1, self._emb_size])
return enc_output, next_sent_feat
class PretrainModelLayer(Layer):
def __init__(self,
config,
return_pooled_out=True,
weight_sharing=False,
use_fp16=False):
super(PretrainModelLayer, self).__init__()
self.config = config
self._voc_size = config['vocab_size']
self._emb_size = config['hidden_size']
self._hidden_act = config['hidden_act']
self._prepostprocess_dropout = config['hidden_dropout_prob']
self._word_emb_name = "word_embedding"
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range'])
self._weight_sharing = weight_sharing
self.use_fp16 = use_fp16
self._dtype = "float16" if use_fp16 else "float32"
self.bert_layer = BertModelLayer(
config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
self.pre_process_layer = PrePostProcessLayer(
"n", self._emb_size, self._prepostprocess_dropout)
self.pooled_fc = Linear(
input_dim=self._emb_size,
output_dim=self._emb_size,
param_attr=fluid.ParamAttr(
name="mask_lm_trans_fc.w_0",
initializer=self._param_initializer),
bias_attr="mask_lm_trans_fc.b_0",
act="tanh")
self.mask_lm_out_bias_attr = fluid.ParamAttr(
name="mask_lm_out_fc.b_0",
initializer=fluid.initializer.Constant(value=0.0))
if not self._weight_sharing:
self.out_fc = Linear(
input_dim=self._emb_size,
output_dim=self._voc_size,
param_attr=fluid.ParamAttr(
name="mask_lm_out_fc.w_0",
initializer=self._param_initializer),
bias_attr=self.mask_lm_out_bias_attr)
else:
self.fc_create_params = self.create_parameter(
shape=[self._voc_size],
dtype=self._dtype,
attr=self.mask_lm_out_bias_attr,
is_bias=True)
self.next_sent_fc = Linear(
input_dim=self._emb_size,
output_dim=2,
param_attr=fluid.ParamAttr(
name="next_sent_fc.w_0", initializer=self._param_initializer),
bias_attr="next_sent_fc.b_0")
@declarative
def forward(self, src_ids, position_ids, sentence_ids, input_mask,
mask_label, mask_pos, labels):
mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
sentence_ids, input_mask)
reshaped_emb_out = fluid.layers.reshape(
x=enc_output, shape=[-1, self._emb_size])
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
mask_trans_feat = self.pooled_fc(mask_feat)
mask_trans_feat = self.pre_process_layer(mask_trans_feat)
if self._weight_sharing:
fc_out = fluid.layers.matmul(
x=mask_trans_feat,
y=self.bert_layer._src_emb._w,
transpose_y=True)
fc_out += self.fc_create_params
else:
fc_out = self.out_fc(mask_trans_feat)
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=mask_label)
mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
next_sent_fc_out = self.next_sent_fc(next_sent_feat)
next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(
logits=next_sent_fc_out, label=labels, return_softmax=True)
next_sent_acc = fluid.layers.accuracy(
input=next_sent_softmax, label=labels)
mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
loss = mean_next_sent_loss + mean_mask_lm_loss
return next_sent_acc, mean_mask_lm_loss, loss
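For reference, the attention-mask arithmetic in BertModelLayer.forward can be checked in plain numpy; the toy batch below is an illustration, not part of the PR:

import numpy as np

# input_mask marks real tokens with 1.0 and padding with 0.0, shape [batch, seq_len, 1].
input_mask = np.array([[[1.0], [1.0], [0.0]]], dtype="float32")
# matmul(mask, mask^T) -> [batch, seq_len, seq_len]: 1 only where both positions are real.
pairwise = np.matmul(input_mask, input_mask.transpose(0, 2, 1))
# scale(x, scale=10000.0, bias=-1.0, bias_after_scale=False) computes (x - 1) * 10000,
# giving 0 for real-real pairs and -10000 wherever padding is involved, which
# suppresses those positions after the softmax inside MultiHeadAttention.
attn_bias = (pairwise - 1.0) * 10000.0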
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import numpy as np
import random
SEED = 2020
def get_bert_config():
bert_config = {
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 2,
"initializer_range": 0.02,
"intermediate_size": 72,
"max_position_embeddings": 512,
"num_attention_heads": 2,
"num_hidden_layers": 2,
"pooler_fc_size": 2,
"pooler_num_attention_heads": 2,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 8,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": 21128
}
return bert_config
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
"""
    Add masks to batch_tokens and return out, mask_label, mask_pos;
    Note: mask_pos corresponds to batch_tokens after padding.
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
np.random.seed(SEED)
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
        # ensure that at least one word in each sentence is masked
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def pad_batch_data(insts,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
    # Any token included in the dict can be used for padding, since the padding
    # loss will be masked out by weights and has no effect on parameter gradients.
inst_data = np.array([
list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, max_len])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_data(insts,
total_token_num,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
    1. Generate the data Tensor.
    2. Generate the position Tensor.
    3. Generate the self-attention input mask, shape [batch_size, max_len, 1].
"""
batch_src_ids = [inst[0] for inst in insts]
batch_sent_ids = [inst[1] for inst in insts]
batch_pos_ids = [inst[2] for inst in insts]
labels_list = []
for i in range(3, len(insts[0]), 1):
labels = [inst[i] for inst in insts]
labels = np.array(labels).astype("int64").reshape([-1, 1])
labels_list.append(labels)
# First step: do mask without padding
if mask_id >= 0:
out, mask_label, mask_pos = mask(
batch_src_ids,
total_token_num,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
else:
out = batch_src_ids
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out, pad_idx=pad_id, return_input_mask=True)
pos_id = pad_batch_data(
batch_pos_ids,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
sent_id = pad_batch_data(
batch_sent_ids,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
if mask_id >= 0:
return_list = [
src_id, pos_id, sent_id, self_input_mask, mask_label, mask_pos
] + labels_list
else:
return_list = [src_id, pos_id, sent_id, self_input_mask] + labels_list
res = return_list if len(return_list) > 1 else return_list[0]
return res
class DataReader(object):
def __init__(self,
batch_size=4096,
in_tokens=True,
max_seq_len=512,
shuffle_files=False,
epoch=100,
voc_size=0,
is_test=False,
generate_neg_sample=False):
self.batch_size = batch_size
self.in_tokens = in_tokens
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
self.current_file_index = 0
self.total_file = 0
self.current_file = None
self.voc_size = voc_size
self.max_seq_len = max_seq_len
self.pad_id = 0
self.cls_id = 101
self.sep_id = 102
self.mask_id = 103
self.is_test = is_test
self.generate_neg_sample = generate_neg_sample
if self.in_tokens:
assert self.batch_size >= self.max_seq_len, "The number of " \
"tokens in batch should not be smaller than max seq length."
if self.is_test:
self.epoch = 1
self.shuffle_files = False
def build_fake_data(self):
for _ in range(1000000):
random.seed(SEED)
sent0_len = random.randint(50, 100)
sent1_len = random.randint(50, 100)
token_ids = [1] \
+ [random.randint(0, 10000) for i in range(sent0_len-1)] \
+ [random.randint(0, 10000) for i in range(sent1_len-1)] \
+ [2]
sent_ids = [0 for i in range(sent0_len)
] + [1 for i in range(sent1_len)]
pos_ids = [i for i in range(sent0_len + sent1_len)]
label = 1
yield token_ids, sent_ids, pos_ids, label
def data_generator(self):
def wrapper():
def reader():
for epoch in range(self.epoch):
self.current_epoch = epoch + 1
sample_generator = self.build_fake_data()
for sample in sample_generator:
if sample is None:
continue
yield sample
def batch_reader(reader, batch_size, in_tokens):
batch, total_token_num, max_len = [], 0, 0
for parsed_line in reader():
token_ids, sent_ids, pos_ids, label = parsed_line
max_len = max(max_len, len(token_ids))
if in_tokens:
to_append = (len(batch) + 1) * max_len <= batch_size
else:
to_append = len(batch) < batch_size
if to_append:
batch.append(parsed_line)
total_token_num += len(token_ids)
else:
yield batch, total_token_num
batch, total_token_num, max_len = [parsed_line], len(
token_ids), len(token_ids)
if len(batch) > 0:
yield batch, total_token_num
for batch_data, total_token_num in batch_reader(
reader, self.batch_size, self.in_tokens):
yield prepare_batch_data(
batch_data,
total_token_num,
voc_size=self.voc_size,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=self.mask_id,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
class ModelHyperParams(object):
generate_neg_sample = False
epoch = 100
max_seq_len = 512
batch_size = 8192
in_tokens = True
def get_feed_data_reader(bert_config):
args = ModelHyperParams()
data_reader = DataReader(
batch_size=args.batch_size,
in_tokens=args.in_tokens,
voc_size=bert_config['vocab_size'],
epoch=args.epoch,
max_seq_len=args.max_seq_len,
generate_neg_sample=args.generate_neg_sample)
return data_reader
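A quick way to inspect what the fake-data pipeline yields (a sketch, assuming this file is saved as bert_utils.py, which is how the test below imports it):

from bert_utils import get_bert_config, get_feed_data_reader

reader = get_feed_data_reader(get_bert_config())
# Each batch is a list of numpy arrays in the order consumed by PretrainModelLayer.forward.
src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = next(
    reader.data_generator()())
# src_ids / pos_ids / sent_ids: int64 [batch, max_len]; input_mask: float32
# [batch, max_len, 1]; mask_label / mask_pos: int64 [num_masked, 1]; labels: int64 [batch, 1]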
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader
program_translator = ProgramTranslator()
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
SEED = 2020
STEP_NUM = 10
PRINT_STEP = 2
def train(bert_config, data_reader):
with fluid.dygraph.guard(place):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
data_loader = fluid.io.DataLoader.from_generator(
capacity=50, iterable=True)
data_loader.set_batch_generator(
data_reader.data_generator(), places=place)
bert = PretrainModelLayer(
config=bert_config, weight_sharing=False, use_fp16=False)
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
step_idx = 0
speed_list = []
for input_data in data_loader():
src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = input_data
next_sent_acc, mask_lm_loss, total_loss = bert(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
input_mask=input_mask,
mask_label=mask_label,
mask_pos=mask_pos,
labels=labels)
total_loss.backward()
optimizer.minimize(total_loss)
bert.clear_gradients()
acc = np.mean(np.array(next_sent_acc.numpy()))
loss = np.mean(np.array(total_loss.numpy()))
ppl = np.mean(np.exp(np.array(mask_lm_loss.numpy())))
if step_idx % PRINT_STEP == 0:
if step_idx == 0:
print("Step: %d, loss: %f, ppl: %f, next_sent_acc: %f" %
(step_idx, loss, ppl, acc))
avg_batch_time = time.time()
else:
speed = PRINT_STEP / (time.time() - avg_batch_time)
speed_list.append(speed)
print(
"Step: %d, loss: %f, ppl: %f, next_sent_acc: %f, speed: %.3f steps/s"
% (step_idx, loss, ppl, acc, speed))
avg_batch_time = time.time()
step_idx += 1
if step_idx == STEP_NUM:
break
return loss, ppl
def train_dygraph(bert_config, data_reader):
program_translator.enable(False)
return train(bert_config, data_reader)
def train_static(bert_config, data_reader):
program_translator.enable(True)
return train(bert_config, data_reader)
class TestBert(unittest.TestCase):
def setUp(self):
self.bert_config = get_bert_config()
self.data_reader = get_feed_data_reader(self.bert_config)
def test_train(self):
static_loss, static_ppl = train_static(self.bert_config,
self.data_reader)
dygraph_loss, dygraph_ppl = train_dygraph(self.bert_config,
self.data_reader)
        self.assertTrue(
            np.allclose(static_loss, dygraph_loss),
            msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
                                                             dygraph_loss))
self.assertTrue(
np.allclose(static_ppl, dygraph_ppl),
msg="static_ppl: {} \n dygraph_ppl: {}".format(static_ppl,
dygraph_ppl))
if __name__ == '__main__':
unittest.main()
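The dygraph/static comparison above hinges on ProgramTranslator toggling how the @declarative forward executes; conceptually (a sketch, not additional test code):

from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator

prog_trans = ProgramTranslator()
prog_trans.enable(False)  # @declarative functions run eagerly, as plain dygraph code
prog_trans.enable(True)   # @declarative functions are transcribed into a static Program and run there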
@@ -76,7 +76,13 @@ class PrePostProcessLayer(Layer):
class MultiHeadAttention(Layer):
-    def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+    def __init__(self,
+                 d_key,
+                 d_value,
+                 d_model,
+                 n_head=1,
+                 dropout_rate=0.,
+                 param_initializer=None):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_key = d_key
@@ -84,13 +90,25 @@ class MultiHeadAttention(Layer):
        self.d_model = d_model
        self.dropout_rate = dropout_rate
        self.q_fc = Linear(
-            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+            input_dim=d_model,
+            output_dim=d_key * n_head,
+            bias_attr=False,
+            param_attr=fluid.ParamAttr(initializer=param_initializer))
        self.k_fc = Linear(
-            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+            input_dim=d_model,
+            output_dim=d_key * n_head,
+            bias_attr=False,
+            param_attr=fluid.ParamAttr(initializer=param_initializer))
        self.v_fc = Linear(
-            input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
+            input_dim=d_model,
+            output_dim=d_value * n_head,
+            bias_attr=False,
+            param_attr=fluid.ParamAttr(initializer=param_initializer))
        self.proj_fc = Linear(
-            input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
+            input_dim=d_value * n_head,
+            output_dim=d_model,
+            bias_attr=False,
+            param_attr=fluid.ParamAttr(initializer=param_initializer))
    def forward(self, queries, keys, values, attn_bias, cache=None):
        # compute q ,k ,v
......
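With param_initializer threaded through, the BERT layers can push their TruncatedNormal initializer down into the attention projections. A minimal construction sketch using the tiny test config's dimensions (hidden_size=2, num_attention_heads=2; illustrative only):

import paddle.fluid as fluid
from transformer_dygraph_model import MultiHeadAttention

with fluid.dygraph.guard():
    init = fluid.initializer.TruncatedNormal(scale=0.02)
    # d_key = d_value = hidden_size // n_head = 1
    attn = MultiHeadAttention(
        d_key=1, d_value=1, d_model=2, n_head=2, dropout_rate=0.1,
        param_initializer=init)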