Unverified commit f14db82d, authored by Guo Sheng and committed by GitHub

Merge pull request #783 from guoshengCS/fix-transformer-batchsize

Decouple the program desc from batch_size in Transformer.
@@ -25,8 +25,7 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
     use_gpu = False
     # the number of examples in one run for sequence generation.
-    # currently the batch size can only be set to 1.
-    batch_size = 1
+    batch_size = 10
     # the parameters for beam search.
     beam_size = 5
@@ -103,6 +102,7 @@ encoder_input_data_names = (
     "src_word",
     "src_pos",
     "src_slf_attn_bias",
+    "src_data_shape",
     "src_slf_attn_pre_softmax_shape",
     "src_slf_attn_post_softmax_shape", )
@@ -112,6 +112,7 @@ decoder_input_data_names = (
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "trg_data_shape",
     "trg_slf_attn_pre_softmax_shape",
     "trg_slf_attn_post_softmax_shape",
     "trg_src_attn_pre_softmax_shape",
......
@@ -24,6 +24,7 @@ def translate_batch(exe,
                     n_best,
                     batch_size,
                     n_head,
+                    d_model,
                     src_pad_idx,
                     trg_pad_idx,
                     bos_idx,
@@ -43,6 +44,11 @@ def translate_batch(exe,
         return_pos=True,
         return_attn_bias=True,
         return_max_len=False)
+    # Append the data shape input to reshape the output of embedding layer.
+    enc_in_data = enc_in_data + [
+        np.array(
+            [-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
+    ]
     # Append the shape inputs to reshape before and after softmax in encoder
     # self attention.
     enc_in_data = enc_in_data + [
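
For reference, a minimal numpy sketch (not part of this diff, illustrative sizes only) of what the appended data-shape tensor describes: the embedding output that the graph sees as a flattened [batch * max_len, d_model] tensor is restored to [batch, max_len, d_model] at run time, with -1 letting the batch dimension be inferred.

import numpy as np

# Hypothetical sizes for illustration only.
batch, max_len, d_model = 4, 6, 8
flat_emb = np.random.rand(batch * max_len, d_model).astype("float32")

# The shape tensor fed to the model: -1 lets the first dimension be inferred.
data_shape = np.array([-1, max_len, d_model], dtype="int32")

reshaped = flat_emb.reshape(tuple(data_shape))
assert reshaped.shape == (batch, max_len, d_model)
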
@@ -59,9 +65,14 @@ def translate_batch(exe,
     scores = np.zeros((batch_size, beam_size), dtype="float32")
     prev_branchs = [[] for i in range(batch_size)]
     next_ids = [[] for i in range(batch_size)]
-    # Use beam_map to map the instance idx in batch to beam idx, since the
+    # Use beam_inst_map to map beam idx to the instance idx in batch, since the
     # size of the fed batch is changing.
-    beam_map = range(batch_size)
+    beam_inst_map = {
+        beam_idx: inst_idx
+        for inst_idx, beam_idx in enumerate(range(batch_size))
+    }
+    # Use active_beams to record the indices of alive beams.
+    active_beams = range(batch_size)

     def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
         """
@@ -98,8 +109,14 @@ def translate_batch(exe,
             [-1e9]).astype("float32")
         # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(
-            src_slf_attn_bias[:, :, ::src_max_length, :],
-            [beam_size, 1, trg_max_len, 1])
+            src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
+            [1, beam_size, 1, trg_max_len, 1]).reshape([
+                -1, src_slf_attn_bias.shape[1], trg_max_len,
+                src_slf_attn_bias.shape[-1]
+            ])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
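
The new tiling replaces the old [beam_size, 1, trg_max_len, 1] tile so that each instance's beam copies are laid out contiguously in the leading dimension. A small numpy sketch of the same pattern (illustrative sizes only):

import numpy as np

batch, n_head, src_len, beam, trg_len = 2, 2, 4, 3, 5
src_bias = np.random.rand(batch, n_head, src_len, src_len).astype("float32")

# Keep one row per instance (::src_len), add a beam axis, tile across beams
# and target positions, then fold batch and beam into the leading dimension.
bias = np.tile(src_bias[:, :, ::src_len, :][:, np.newaxis],
               [1, beam, 1, trg_len, 1]).reshape(
                   [-1, n_head, trg_len, src_len])
assert bias.shape == (batch * beam, n_head, trg_len, src_len)
# Rows 0 .. beam-1 all come from instance 0, rows beam .. 2*beam-1 from instance 1.
assert np.allclose(bias[0], bias[beam - 1])
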
@@ -112,22 +129,24 @@ def translate_batch(exe,
             [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
         trg_src_attn_post_softmax_shape = np.array(
             trg_src_attn_bias.shape, dtype="int32")
-        enc_output = np.tile(enc_output, [beam_size, 1, 1])
+        enc_output = np.tile(
+            enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
+                [-1, enc_output.shape[-2], enc_output.shape[-1]])
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output

-    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+    def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
         """
         Update the input data of decoder mainly by slicing from the previous
         input data and dropping the finished instance beams.
         """
         trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output = dec_in_data
-        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output = dec_in_data
+        trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
         trg_words = np.array(
             [
                 beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
@@ -138,6 +157,7 @@ def translate_batch(exe,
         trg_pos = np.array(
             [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
+        active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
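
The broadcasting above turns per-instance row indices into the flat row indices of all their beams; a numpy sketch with illustrative values:

import numpy as np

beam_size = 3
active_beams = [0, 2]  # hypothetical: instance rows 0 and 2 are still alive
active_beams_indice = (
    (np.array(active_beams) * beam_size)[:, np.newaxis] +
    np.array(range(beam_size))[np.newaxis, :]).flatten()
# Rows 0-2 belong to instance 0, rows 6-8 to instance 2.
assert active_beams_indice.tolist() == [0, 1, 2, 6, 7, 8]
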
@@ -152,6 +172,10 @@ def translate_batch(exe,
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
             [1, 1, trg_cur_len, 1])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [len(active_beams) * beam_size, trg_cur_len, d_model],
+            dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
@@ -166,9 +190,9 @@ def translate_batch(exe,
             trg_src_attn_bias.shape, dtype="int32")
         enc_output = enc_output[active_beams_indice, :, :]
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output

     dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                    enc_output)
@@ -177,15 +201,18 @@ def translate_batch(exe,
                               feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
         predict_all = np.log(
-            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-                                                                        -1, :])
-        predict_all = (predict_all + scores[beam_map].reshape(
-            [len(beam_map) * beam_size, -1])).reshape(
-                [len(beam_map), beam_size, -1])
+            predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
+            [:, -1, :])
+        predict_all = (predict_all + scores[active_beams].reshape(
+            [len(beam_inst_map) * beam_size, -1])).reshape(
+                [len(beam_inst_map), beam_size, -1])
         if not output_unk:  # To exclude the <unk> token.
             predict_all[:, :, unk_idx] = -1e9
         active_beams = []
-        for inst_idx, beam_idx in enumerate(beam_map):
+        for beam_idx in range(batch_size):
+            if not beam_inst_map.has_key(beam_idx):
+                continue
+            inst_idx = beam_inst_map[beam_idx]
             predict = (predict_all[inst_idx, :, :]
                        if i != 0 else predict_all[inst_idx, 0, :]).flatten()
             top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
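
np.argpartition is used here as an unordered top-k: it returns the indices of the beam_size largest scores without fully sorting. A quick sketch with made-up scores:

import numpy as np

predict = np.array([0.1, 0.7, 0.3, 0.9, 0.2])
beam_size = 2
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
# The two best candidates (order not guaranteed): indices 1 and 3.
assert sorted(top_k_indice.tolist()) == [1, 3]
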
@@ -198,10 +225,14 @@ def translate_batch(exe,
             next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
             if next_ids[beam_idx][-1][0] != eos_idx:
                 active_beams.append(beam_idx)
-        beam_map = active_beams
-        if len(beam_map) == 0:
+        if len(active_beams) == 0:
             break
-        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
+                                         beam_inst_map)
+        beam_inst_map = {
+            beam_idx: inst_idx
+            for inst_idx, beam_idx in enumerate(active_beams)
+        }

     # Decode beams and select n_best sequences for each instance by backtrace.
     seqs = [
@@ -215,10 +246,8 @@ def translate_batch(exe,

 def main():
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # The current program desc is coupled with batch_size and the only
-    # supported batch size is 1 currently.
     encoder_program = fluid.Program()
-    model.batch_size = InferTaskConfig.batch_size
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
             ModelHyperParams.src_vocab_size + 1,
@@ -228,7 +257,6 @@ def main():
             ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
             ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

-    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
         predict = decoder(
@@ -273,6 +301,9 @@ def main():
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    # Append the <pad> token since the dict provided by dataset.wmt16 does
+    # not include it.
+    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"

     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
@@ -306,6 +337,7 @@ def main():
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
+            ModelHyperParams.d_model,
             ModelHyperParams.src_pad_idx,
             ModelHyperParams.trg_pad_idx,
             ModelHyperParams.bos_idx,
......
@@ -7,9 +7,6 @@ import paddle.fluid.layers as layers
 from config import TrainTaskConfig, pos_enc_param_names, \
     encoder_input_data_names, decoder_input_data_names, label_data_names

-# FIXME(guosheng): Remove out the batch_size from the model.
-batch_size = TrainTaskConfig.batch_size

 def position_encoding_init(n_position, d_pos_vec):
     """
@@ -85,9 +82,10 @@ def multi_head_attention(queries,
             return x

         hidden_size = x.shape[-1]
-        # FIXME(guosheng): Decouple the program desc with batch_size.
+        # The value 0 in the shape attr means copying the corresponding
+        # dimension size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
+            x=x, shape=[0, -1, n_head, hidden_size // n_head])

         # permute the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
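
For readers unfamiliar with this reshape convention: per the new comment, a 0 in the target shape copies that dimension from the input, and a single -1 is inferred from the remaining element count. A plain-Python sketch of the rule (the helper name is made up for illustration):

def resolve_shape(input_shape, target_shape):
    """Resolve 0 (copy from input) and -1 (infer) entries in a target shape."""
    out = [input_shape[i] if d == 0 else d for i, d in enumerate(target_shape)]
    total = 1
    for d in input_shape:
        total *= d
    if -1 in out:
        known = 1
        for d in out:
            if d != -1:
                known *= d
        out[out.index(-1)] = total // known
    return out

# [batch, seq_len, hidden] -> [batch, seq_len, n_head, hidden // n_head]:
# batch is kept via 0, seq_len is inferred via -1.
assert resolve_shape([2, 5, 16], [0, -1, 4, 4]) == [2, 5, 4, 4]
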
@@ -103,11 +101,11 @@ def multi_head_attention(queries,
             raise ValueError("Input(x) should be a 4-D Tensor.")

         trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
-        # FIXME(guosheng): Decouple the program desc with batch_size.
+        # The value 0 in the shape attr means copying the corresponding
+        # dimension size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int,
-                      [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))

 def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
     """
@@ -205,6 +203,7 @@ def prepare_encoder(src_word,
                     src_max_len,
                     dropout_rate=0.,
                     pos_pad_idx=0,
+                    src_data_shape=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.

     The output tensor has a shape of:
@@ -224,9 +223,10 @@ def prepare_encoder(src_word,
         param_attr=fluid.ParamAttr(
             name=pos_enc_param_name, trainable=False))
     enc_input = src_word_emb + src_pos_enc
-    # FIXME(guosheng): Decouple the program desc with batch_size.
-    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
+    enc_input = layers.reshape(
+        x=enc_input,
+        shape=[-1, src_max_len, src_emb_dim],
+        actual_shape=src_data_shape)
     return layers.dropout(
         enc_input, dropout_prob=dropout_rate,
         is_test=False) if dropout_rate else enc_input
@@ -401,20 +401,23 @@ def decoder(dec_input,

 def make_inputs(input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
                 is_pos,
                 slf_attn_bias_flag,
                 src_attn_bias_flag,
                 enc_output_flag=False,
+                data_shape_flag=True,
                 slf_attn_shape_flag=True,
                 src_attn_shape_flag=True):
     """
     Define the input data layers for the transformer model.
     """
     input_layers = []
-    # The shapes here act as placeholders.
-    # The shapes set here are to pass the infer-shape in compile time.
+    batch_size = 1  # Only for the infer-shape in compile time.
+    # The shapes here act as placeholders and are set to pass the infer-shape
+    # in compile time.
+    # The actual data shape of word is:
+    # [batch_size * max_len_in_batch, 1]
     word = layers.data(
         name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
@@ -422,6 +425,8 @@ def make_inputs(input_data_names,
         append_batch_size=False)
     input_layers += [word]
     # This is used for position data or label weight.
+    # The actual data shape of pos is:
+    # [batch_size * max_len_in_batch, 1]
     pos = layers.data(
         name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
@@ -432,6 +437,8 @@ def make_inputs(input_data_names,
         # This input is used to remove attention weights on paddings for the
         # encoder and to remove attention weights on subsequent words for the
         # decoder.
+        # The actual data shape of slf_attn_bias is:
+        # [batch_size, n_head, max_len_in_batch, max_len_in_batch]
         slf_attn_bias = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
@@ -439,40 +446,56 @@ def make_inputs(input_data_names,
             append_batch_size=False)
         input_layers += [slf_attn_bias]
     if src_attn_bias_flag:
-        # This input is used to remove attention weights on paddings.
+        # This input is used to remove attention weights on paddings. It's used
+        # in encoder-decoder attention.
+        # The actual data shape of src_attn_bias is:
+        # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
         src_attn_bias = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [src_attn_bias]
+    if data_shape_flag:
+        # This input is used to reshape the output of embedding layer.
+        data_shape = layers.data(
+            name=input_data_names[len(input_layers)],
+            shape=[3],
+            dtype="int32",
+            append_batch_size=False)
+        input_layers += [data_shape]
     if slf_attn_shape_flag:
+        # This shape input is used to reshape before softmax in self attention.
         slf_attn_pre_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[2],
             dtype="int32",
             append_batch_size=False)
         input_layers += [slf_attn_pre_softmax_shape]
+        # This shape input is used to reshape after softmax in self attention.
         slf_attn_post_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[4],
             dtype="int32",
             append_batch_size=False)
         input_layers += [slf_attn_post_softmax_shape]
     if src_attn_shape_flag:
         src_attn_pre_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[2],
             dtype="int32",
             append_batch_size=False)
         input_layers += [src_attn_pre_softmax_shape]
         src_attn_post_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[4],
             dtype="int32",
             append_batch_size=False)
         input_layers += [src_attn_post_softmax_shape]
     if enc_output_flag:
+        # This input is used in the independent decoder program for inference.
+        # The actual data shape of enc_output is:
+        # [batch_size, max_len_in_batch, d_model]
         enc_output = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, max_length, d_model],
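
The [2]- and [4]-element shape inputs declared above feed the reshapes around softmax: attention weights of shape [batch, n_head, seq_len, seq_len] are flattened to 2-D so softmax normalizes the last axis, then restored. A numpy sketch of that round trip (illustrative sizes, plain numpy in place of the framework ops):

import numpy as np

batch, n_head, seq_len = 2, 4, 5
weights = np.random.rand(batch, n_head, seq_len, seq_len).astype("float32")

pre_softmax_shape = np.array([-1, seq_len], dtype="int32")                   # 2 elements
post_softmax_shape = np.array([batch, n_head, seq_len, seq_len], "int32")    # 4 elements

flat = weights.reshape(tuple(pre_softmax_shape))
probs = np.exp(flat) / np.exp(flat).sum(axis=1, keepdims=True)
restored = probs.reshape(tuple(post_softmax_shape))
assert restored.shape == (batch, n_head, seq_len, seq_len)
assert np.allclose(restored.sum(axis=-1), 1.0)
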
@@ -497,16 +520,16 @@ def transformer(
         src_pad_idx,
         trg_pad_idx,
         pos_pad_idx, ):
-    enc_input_layers = make_inputs(
+    enc_inputs = make_inputs(
         encoder_input_data_names,
         n_head,
         d_model,
-        batch_size,
         max_length,
         is_pos=True,
         slf_attn_bias_flag=True,
         src_attn_bias_flag=False,
         enc_output_flag=False,
+        data_shape_flag=True,
         slf_attn_shape_flag=True,
         src_attn_shape_flag=False)
@@ -522,18 +545,18 @@ def transformer(
         dropout_rate,
         src_pad_idx,
         pos_pad_idx,
-        enc_input_layers, )
+        enc_inputs, )

-    dec_input_layers = make_inputs(
+    dec_inputs = make_inputs(
         decoder_input_data_names,
         n_head,
         d_model,
-        batch_size,
         max_length,
         is_pos=True,
         slf_attn_bias_flag=True,
         src_attn_bias_flag=True,
         enc_output_flag=False,
+        data_shape_flag=True,
         slf_attn_shape_flag=True,
         src_attn_shape_flag=True)
@@ -549,7 +572,7 @@ def transformer(
         dropout_rate,
         trg_pad_idx,
         pos_pad_idx,
-        dec_input_layers,
+        dec_inputs,
         enc_output, )

     # Padding indices do not contribute to the total loss. The weights are used to
@@ -558,12 +581,12 @@ def transformer(
         label_data_names,
         n_head,
         d_model,
-        batch_size,
         max_length,
         is_pos=False,
         slf_attn_bias_flag=False,
         src_attn_bias_flag=False,
         enc_output_flag=False,
+        data_shape_flag=False,
         slf_attn_shape_flag=False,
         src_attn_shape_flag=False)
     cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
@@ -571,7 +594,7 @@ def transformer(
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
     avg_cost = sum_cost / token_num
-    return sum_cost, avg_cost, predict
+    return sum_cost, avg_cost, predict, token_num

 def wrap_encoder(src_vocab_size,
@@ -585,28 +608,30 @@ def wrap_encoder(src_vocab_size,
                  dropout_rate,
                  src_pad_idx,
                  pos_pad_idx,
-                 enc_input_layers=None):
+                 enc_inputs=None):
     """
     The wrapper assembles together all needed layers for the encoder.
     """
-    if enc_input_layers is None:
+    if enc_inputs is None:
         # This is used to implement independent encoder program in inference.
-        src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
-            slf_attn_post_softmax_shape = make_inputs(
+        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
+            make_inputs(
                 encoder_input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
                 is_pos=True,
                 slf_attn_bias_flag=True,
                 src_attn_bias_flag=False,
                 enc_output_flag=False,
+                data_shape_flag=True,
                 slf_attn_shape_flag=True,
                 src_attn_shape_flag=False)
     else:
-        src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
-            slf_attn_post_softmax_shape = enc_input_layers
+        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
+            enc_inputs
     enc_input = prepare_encoder(
         src_word,
         src_pos,
@@ -614,7 +639,9 @@ def wrap_encoder(src_vocab_size,
         d_model,
         src_pad_idx,
         max_length,
-        dropout_rate, )
+        dropout_rate,
+        pos_pad_idx,
+        src_data_shape, )
     enc_output = encoder(
         enc_input,
         src_slf_attn_bias,
@@ -641,33 +668,33 @@ def wrap_decoder(trg_vocab_size,
                  dropout_rate,
                  trg_pad_idx,
                  pos_pad_idx,
-                 dec_input_layers=None,
+                 dec_inputs=None,
                  enc_output=None):
     """
     The wrapper assembles together all needed layers for the decoder.
     """
-    if dec_input_layers is None:
+    if dec_inputs is None:
         # This is used to implement independent decoder program in inference.
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
-            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
-            enc_output = make_inputs(
+            trg_data_shape, slf_attn_pre_softmax_shape, \
+            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
+            src_attn_post_softmax_shape, enc_output = make_inputs(
                 decoder_input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
                 is_pos=True,
                 slf_attn_bias_flag=True,
                 src_attn_bias_flag=True,
                 enc_output_flag=True,
+                data_shape_flag=True,
                 slf_attn_shape_flag=True,
                 src_attn_shape_flag=True)
     else:
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
-            src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
-            dec_input_layers
+            trg_data_shape, slf_attn_pre_softmax_shape, \
+            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
+            src_attn_post_softmax_shape = dec_inputs
     dec_input = prepare_decoder(
         trg_word,
@@ -676,7 +703,9 @@ def wrap_decoder(trg_vocab_size,
         d_model,
         trg_pad_idx,
         max_length,
-        dropout_rate, )
+        dropout_rate,
+        pos_pad_idx,
+        trg_data_shape, )
     dec_output = decoder(
         dec_input,
         enc_output,
@@ -700,5 +729,5 @@ def wrap_decoder(trg_vocab_size,
             bias_attr=False,
             num_flatten_dims=2),
         shape=[-1, trg_vocab_size],
-        act="softmax" if dec_input_layers is None else None)
+        act="softmax" if dec_inputs is None else None)
     return predict
@@ -57,7 +57,7 @@ def pad_batch_data(insts,

 def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
-                        max_length, n_head):
+                        n_head, d_model):
     """
     Put all padded data needed by training into a dict.
     """
@@ -67,6 +67,10 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
         [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
+    # These shape tensors are used in reshape_op.
+    src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
+    trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
     src_slf_attn_pre_softmax_shape = np.array(
         [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
     src_slf_attn_post_softmax_shape = np.array(
@@ -79,17 +83,19 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
         [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
     trg_src_attn_post_softmax_shape = np.array(
         trg_src_attn_bias.shape, dtype="int32")
     lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
                               False, False, False, False)
     lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])

     input_dict = dict(
         zip(input_data_names, [
-            src_word, src_pos, src_slf_attn_bias,
+            src_word, src_pos, src_slf_attn_bias, src_data_shape,
             src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
             trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
-            lbl_word, lbl_weight
+            trg_data_shape, trg_slf_attn_pre_softmax_shape,
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
+            trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
         ]))
     return input_dict
@@ -98,7 +104,7 @@ def main():
     place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)

-    sum_cost, avg_cost, predict = transformer(
+    sum_cost, avg_cost, predict, token_num = transformer(
         ModelHyperParams.src_vocab_size + 1,
         ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
         ModelHyperParams.n_layer, ModelHyperParams.n_head,
@@ -134,21 +140,24 @@ def main():
         batch_size=TrainTaskConfig.batch_size)

     def test(exe):
-        test_sum_costs = []
-        test_avg_costs = []
+        test_total_cost = 0
+        test_total_token = 0
         for batch_id, data in enumerate(val_data()):
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
                 label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head)
-            test_sum_cost, test_avg_cost = exe.run(
-                test_program, feed=data_input, fetch_list=[sum_cost, avg_cost])
-            test_sum_costs.append(test_sum_cost)
-            test_avg_costs.append(test_avg_cost)
-        return np.mean(test_sum_costs), np.mean(test_avg_costs)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
+            test_sum_cost, test_token_num = exe.run(
+                test_program,
+                feed=data_input,
+                fetch_list=[sum_cost, token_num],
+                use_program_cache=True)
+            test_total_cost += test_sum_cost
+            test_total_token += test_token_num
+        test_avg_cost = test_total_cost / test_total_token
+        test_ppl = np.exp([min(test_avg_cost, 100)])
+        return test_avg_cost, test_ppl

     # Initialize the parameters.
     exe.run(fluid.framework.default_startup_program())
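
The reworked test() aggregates cross-entropy by total token count rather than averaging per-batch means, which matters now that validation batches can have different sizes. A numpy sketch of the aggregation with made-up numbers:

import numpy as np

# (sum of token losses, number of tokens) per validation batch -- illustrative.
batch_stats = [(120.0, 50), (90.0, 30), (33.0, 20)]

total_cost = sum(cost for cost, _ in batch_stats)
total_token = sum(tokens for _, tokens in batch_stats)
avg_cost = total_cost / total_token      # per-token cross entropy
ppl = np.exp(min(avg_cost, 100))         # clip to avoid overflow in exp
assert round(avg_cost, 3) == 2.43
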
@@ -162,15 +171,11 @@ def main():
     for pass_id in xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
-            # The current program desc is coupled with batch_size, thus all
-            # mini-batches must have the same number of instances currently.
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
                 label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
@@ -181,13 +186,11 @@ def main():
                    (pass_id, batch_id, sum_cost_val, avg_cost_val,
                     np.exp([min(avg_cost_val[0], 100)])))
         # Validate and save the model for inference.
-        val_sum_cost, val_avg_cost = test(exe)
+        val_avg_cost, val_ppl = test(exe)
         pass_end_time = time.time()
         time_consumed = pass_end_time - pass_start_time
-        print("epoch: %d, val sum loss: %f, val avg loss: %f, val ppl: %f, "
-              "consumed %fs" %
-              (pass_id, val_sum_cost, val_avg_cost,
-               np.exp([min(val_avg_cost, 100)]), time_consumed))
+        print("epoch: %d, val avg loss: %f, val ppl: %f, "
+              "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
         fluid.io.save_inference_model(
             os.path.join(TrainTaskConfig.model_dir,
                          "pass_" + str(pass_id) + ".infer.model"),
......