Commit f3c247d3 authored by guosheng

Decouple the program desc from batch_size in Transformer.

Parent 35308832
@@ -92,7 +92,8 @@ pos_enc_param_names = (
 encoder_input_data_names = (
     "src_word",
     "src_pos",
-    "src_slf_attn_bias", )
+    "src_slf_attn_bias",
+    "src_data_shape", )

 # Names of all data layers in decoder listed in order.
 decoder_input_data_names = (
@@ -100,6 +101,7 @@ decoder_input_data_names = (
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "trg_data_shape",
     "enc_output", )

 # Names of label related data layers listed in order.
...
@@ -13,8 +13,8 @@ from train import pad_batch_data
 def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
                     decoder, dec_in_names, dec_out_names, beam_size, max_length,
-                    n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
-                    bos_idx, eos_idx):
+                    n_best, batch_size, n_head, d_model, src_pad_idx,
+                    trg_pad_idx, bos_idx, eos_idx):
     """
     Run the encoder program once and run the decoder program multiple times to
     implement beam search externally.
@@ -28,6 +28,10 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         return_pos=True,
         return_attn_bias=True,
         return_max_len=True)
+    enc_in_data = enc_in_data[:-1] + [
+        np.array(
+            [batch_size, enc_in_data[-1], d_model], dtype="int32")
+    ]  # Append the data shape input.
     enc_output = exe.run(encoder,
                          feed=dict(zip(enc_in_names, enc_in_data)),
                          fetch_list=enc_out_names)[0]
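The hunk above replaces the max-length scalar that `pad_batch_data` returns (with `return_max_len=True`) by an int32 tensor carrying the runtime shape of the encoder input. A toy NumPy sketch of that substitution, with made-up sizes and string stand-ins for the other feeds:

```python
import numpy as np

# Toy illustration (made-up sizes) of the extra "src_data_shape" feed: the
# batch's max length returned by pad_batch_data (return_max_len=True) is
# replaced by an int32 tensor holding the runtime shape of the encoder input.
batch_size, d_model = 4, 512
enc_in_data = ["src_word", "src_pos", "src_slf_attn_bias", 17]  # 17 = max len
enc_in_data = enc_in_data[:-1] + [
    np.array([batch_size, enc_in_data[-1], d_model], dtype="int32")
]
print(enc_in_data[-1])  # [  4  17 512]
```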
@@ -35,11 +39,16 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
     # Beam Search.
     # To store the beam info.
     scores = np.zeros((batch_size, beam_size), dtype="float32")
-    prev_branchs = [[]] * batch_size
-    next_ids = [[]] * batch_size
-    # Use beam_map to map the instance idx in batch to beam idx, since the
+    prev_branchs = [[] for i in range(batch_size)]
+    next_ids = [[] for i in range(batch_size)]
+    # Use beam_inst_map to map beam idx to the instance idx in batch, since the
     # size of the fed batch is changing.
-    beam_map = range(batch_size)
+    beam_inst_map = {
+        beam_idx: inst_idx
+        for inst_idx, beam_idx in enumerate(range(batch_size))
+    }
+    # Use active_beams to record the alive instances.
+    active_beams = range(batch_size)
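Since finished instances are dropped from the batch fed to the decoder, `beam_inst_map` translates a stable beam index (one per source sentence) into that instance's current position in the shrinking batch. A toy illustration (values hypothetical) of how the map is rebuilt after one instance finishes:

```python
# Toy illustration (not from the patch) of rebuilding beam_inst_map: beam
# indices stay fixed per source sentence, while instance indices index into
# the shrinking batch actually fed to the decoder.
batch_size = 4
beam_inst_map = {beam_idx: inst_idx
                 for inst_idx, beam_idx in enumerate(range(batch_size))}
print(beam_inst_map)  # {0: 0, 1: 1, 2: 2, 3: 3}

active_beams = [0, 2, 3]  # suppose instance 1 emitted <eos> and was dropped
beam_inst_map = {beam_idx: inst_idx
                 for inst_idx, beam_idx in enumerate(active_beams)}
print(beam_inst_map)  # {0: 0, 2: 1, 3: 2}
```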
     def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
         """
@@ -64,8 +73,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_words = np.array(
             [[bos_idx]] * batch_size * beam_size, dtype="int64")
         trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
-        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-            -1], enc_in_data[-2], 1
+        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[-1][
+            1], enc_in_data[-2], 1
         # This is used to remove attention on subsequent words.
         trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
                                      trg_max_len))
@@ -77,16 +86,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_src_attn_bias = np.tile(
             src_slf_attn_bias[:, :, ::src_max_length, :],
             [beam_size, 1, trg_max_len, 1])
-        enc_output = np.tile(enc_output, [beam_size, 1, 1])
-        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+        trg_data_shape = np.array(
+            [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
+        enc_output = np.tile(
+            enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
+                [-1, enc_output.shape[-2], enc_output.shape[-1]])
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output
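The removed `np.tile(enc_output, [beam_size, 1, 1])` repeats the whole batch block-wise, which matches the layout the decoder expects only when the batch holds a single instance; the new tiling repeats each instance's encoder output `beam_size` times contiguously, so that row `inst_idx * beam_size + k` belongs to beam `k` of instance `inst_idx`. A NumPy check with toy sizes:

```python
import numpy as np

# Toy check (hypothetical sizes) of the new per-instance tiling.
batch_size, beam_size, seq_len, d_model = 2, 3, 4, 5
enc_output = np.arange(
    batch_size * seq_len * d_model, dtype="float32").reshape(
        [batch_size, seq_len, d_model])

tiled = np.tile(enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
    [-1, enc_output.shape[-2], enc_output.shape[-1]])
assert tiled.shape == (batch_size * beam_size, seq_len, d_model)
# Rows 0..2 are copies of instance 0, rows 3..5 of instance 1.
assert (tiled[0] == tiled[2]).all() and (tiled[3] == enc_output[1]).all()
```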
-    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+    def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
         """
         Update the input data of decoder mainly by slicing from the previous
         input data and dropping the finished instance beams.
         """
-        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
-        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
+        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output = dec_in_data
+        trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
         trg_words = np.array(
             [
                 beam_backtrace(
@@ -98,6 +111,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_pos = np.array(
             [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
+        active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
@@ -112,8 +126,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
                                     [1, 1, trg_cur_len, 1])
+        trg_data_shape = np.array(
+            [len(active_beams) * beam_size, trg_cur_len, d_model],
+            dtype="int32")
         enc_output = enc_output[active_beams_indice, :, :]
-        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output

     dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                    enc_output)
@@ -122,13 +139,16 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
                               feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
         predict_all = np.log(
-            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-                                                                        -1, :])
-        predict_all = (predict_all + scores[beam_map].reshape(
-            [len(beam_map) * beam_size, -1])).reshape(
-                [len(beam_map), beam_size, -1])
+            predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
+            [:, -1, :])
+        predict_all = (predict_all + scores[active_beams].reshape(
+            [len(beam_inst_map) * beam_size, -1])).reshape(
+                [len(beam_inst_map), beam_size, -1])
         active_beams = []
-        for inst_idx, beam_idx in enumerate(beam_map):
+        for beam_idx in range(batch_size):
+            if not beam_inst_map.has_key(beam_idx):
+                continue
+            inst_idx = beam_inst_map[beam_idx]
             predict = (predict_all[inst_idx, :, :]
                        if i != 0 else predict_all[inst_idx, 0, :]).flatten()
             top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
@@ -141,13 +161,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
             next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
             if next_ids[beam_idx][-1][0] != eos_idx:
                 active_beams.append(beam_idx)
-        beam_map = active_beams
-        if len(beam_map) == 0:
+        if len(active_beams) == 0:
             break
-        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
+                                         beam_inst_map)
+        beam_inst_map = {
+            beam_idx: inst_idx
+            for inst_idx, beam_idx in enumerate(active_beams)
+        }

     # Decode beams and select n_best sequences for each instance by backtrace.
-    seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
+    seqs = [
+        beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
+        for beam_idx in range(batch_size)
+    ]
     return seqs, scores[:, :n_best].tolist()
@@ -155,10 +182,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
 def main():
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # The current program desc is coupled with batch_size and the only
-    # supported batch size is 1 currently.
     encoder_program = fluid.Program()
-    model.batch_size = InferTaskConfig.batch_size
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
             ModelHyperParams.src_vocab_size + 1,
@@ -168,7 +193,6 @@ def main():
             ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
             ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

-    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
         predict = decoder(
@@ -213,16 +237,15 @@ def main():
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
     for batch_id, data in enumerate(test_data()):
         batch_seqs, batch_scores = translate_batch(
             exe, [item[0] for item in data], encoder_program,
             encoder_input_data_names, [enc_output.name], decoder_program,
             decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
             InferTaskConfig.max_length, InferTaskConfig.n_best,
-            len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
-            ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
-            ModelHyperParams.eos_idx)
+            len(data), ModelHyperParams.n_head, ModelHyperParams.d_model,
+            ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
         for i in range(len(batch_seqs)):
             seqs = batch_seqs[i]
             scores = batch_scores[i]
...
@@ -7,9 +7,6 @@ import paddle.fluid.layers as layers
 from config import TrainTaskConfig, pos_enc_param_names, \
     encoder_input_data_names, decoder_input_data_names, label_data_names

-# FIXME(guosheng): Remove out the batch_size from the model.
-batch_size = TrainTaskConfig.batch_size

 def position_encoding_init(n_position, d_pos_vec):
     """
@@ -83,9 +80,10 @@ def multi_head_attention(queries,
             return x

         hidden_size = x.shape[-1]
-        # FIXME(guosheng): Decouple the program desc with batch_size.
+        # The value 0 in the shape attr means copying the corresponding
+        # dimension size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
+            x=x, shape=[0, -1, n_head, hidden_size // n_head])

         # permute the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
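The new comment refers to the 0-semantics of `reshape`'s shape attribute: a 0 keeps the corresponding input dimension at run time, so the program desc no longer bakes in `batch_size`. A minimal sketch of this usage, assuming the same Fluid version as the patch (names and sizes illustrative):

```python
import paddle.fluid.layers as layers

# 0 in the shape attr keeps that input dimension (here the batch dimension),
# -1 infers one remaining dimension; neither requires knowing batch_size at
# compile time. Sizes are illustrative only.
x = layers.data(
    name="x", shape=[8, 6, 16], dtype="float32", append_batch_size=False)
y = layers.reshape(x=x, shape=[0, -1, 4, 4])  # run-time shape: [8, 6, 4, 4]
```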
@@ -101,26 +99,20 @@ def multi_head_attention(queries,
             raise ValueError("Input(x) should be a 4-D Tensor.")

         trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
-        # FIXME(guosheng): Decouple the program desc with batch_size.
+        # The value 0 in the shape attr means copying the corresponding
+        # dimension size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int,
-                      [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))

     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
         Scaled Dot-Product Attention
         """
-        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
-        # The current implementation of softmax_op only supports 2D tensor,
-        # consequently it cannot be directly used here.
-        # If to use the reshape_op, Besides, the shape of product inferred in
-        # compile-time is not the actual shape in run-time. It cann't be used
-        # to set the attribute of reshape_op.
-        # So, here define the softmax for temporary solution.
+        # FIXME(guosheng): Remove __softmax once softmax_op supports high-rank
+        # tensors. softmax_op only supports 2D tensors currently; otherwise,
+        # extra input data would have to be added for the reshape.
         def __softmax(x, eps=1e-9):
             exp_out = layers.exp(x=x)
             sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
@@ -131,6 +123,7 @@ def multi_head_attention(queries,
         weights = __softmax(
             layers.elementwise_add(
                 x=product, y=attn_bias) if attn_bias else product)
+        # weights = __softmax(product)
         if dropout_rate:
             weights = layers.dropout(
                 weights, dropout_prob=dropout_rate, is_test=False)
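For intuition, here is a NumPy analogue (not part of the patch) of the manual `__softmax` above: a softmax over the last axis applied after adding the attention bias, where masked positions carry a large negative bias and so end up with near-zero weight:

```python
import numpy as np

# NumPy sanity check (illustrative sizes) of the high-rank softmax with bias.
product = np.random.rand(2, 8, 4, 4).astype("float32")
attn_bias = np.zeros_like(product)
attn_bias[..., -1] = -1e9  # mask the last key position

x = product + attn_bias
exp_out = np.exp(x)
weights = exp_out / exp_out.sum(axis=-1, keepdims=True)
assert np.allclose(weights.sum(-1), 1.0, atol=1e-5)
assert weights[..., -1].max() < 1e-6  # masked positions get ~0 weight
```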
@@ -177,7 +170,7 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
     return out

-def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
+def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
     """
     Add residual connection, layer normalization and dropout to the out tensor
     optionally according to the value of process_cmd.
@@ -195,8 +188,9 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
                 param_attr=fluid.initializer.Constant(1.),
                 bias_attr=fluid.initializer.Constant(0.))
         elif cmd == "d":  # add dropout
-            if dropout:
-                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
+            if dropout_rate:
+                out = layers.dropout(
+                    out, dropout_prob=dropout_rate, is_test=False)
     return out

@@ -210,8 +204,9 @@ def prepare_encoder(src_word,
                     src_emb_dim,
                     src_pad_idx,
                     src_max_len,
-                    dropout=0.,
+                    dropout_rate=0.,
                     pos_pad_idx=0,
+                    src_data_shape=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.
     The output tensor has a shape of:
@@ -231,12 +226,13 @@ def prepare_encoder(src_word,
         param_attr=fluid.ParamAttr(
             name=pos_enc_param_name, trainable=False))
     enc_input = src_word_emb + src_pos_enc
-    # FIXME(guosheng): Decouple the program desc with batch_size.
-    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
+    enc_input = layers.reshape(
+        x=enc_input,
+        shape=[-1, src_max_len, src_emb_dim],
+        actual_shape=src_data_shape)
     return layers.dropout(
-        enc_input, dropout_prob=dropout,
-        is_test=False) if dropout else enc_input
+        enc_input, dropout_prob=dropout_rate,
+        is_test=False) if dropout_rate else enc_input

 prepare_encoder = partial(
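`actual_shape` is the other half of the decoupling: the static `shape` argument only satisfies compile-time infer-shape, while the int32 tensor fed as `src_data_shape` supplies the real shape of each mini-batch at run time. A sketch of the pattern under the same version assumptions (placeholder names and sizes):

```python
import paddle.fluid.layers as layers

# The embedding output arrives flattened as [batch * len_in_batch, d_model];
# `shape` below only passes compile-time infer-shape, and the fed int32
# tensor overrides it with the true [batch, len_in_batch, d_model].
max_len, d_model = 256, 512
emb = layers.data(
    name="emb", shape=[1 * max_len, d_model], dtype="float32",
    append_batch_size=False)
data_shape = layers.data(
    name="data_shape", shape=[3], dtype="int32", append_batch_size=False)
reshaped = layers.reshape(
    x=emb, shape=[-1, max_len, d_model], actual_shape=data_shape)
# At run time, feed e.g. np.array([batch, len_in_batch, d_model], "int32").
```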
@@ -386,18 +382,21 @@ def decoder(dec_input,
 def make_inputs(input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
-                is_pos,
-                slf_attn_bias_flag,
-                src_attn_bias_flag,
-                enc_output_flag=False):
+                is_pos=True,
+                slf_attn_bias_flag=True,
+                src_attn_bias_flag=True,
+                enc_output_flag=False,
+                data_shape_flag=True):
     """
     Define the input data layers for the transformer model.
     """
     input_layers = []
-    # The shapes here act as placeholder.
-    # The shapes set here is to pass the infer-shape in compile time.
+    batch_size = 1  # Only for the infer-shape in compile time.
+    # The shapes here act as placeholders and are set to pass the infer-shape
+    # in compile time.
+    # The actual data shape of word is:
+    # [batch_size * max_len_in_batch, 1]
     word = layers.data(
         name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
@@ -405,6 +404,8 @@ def make_inputs(input_data_names,
         append_batch_size=False)
     input_layers += [word]
     # This is used for position data or label weight.
+    # The actual data shape of pos is:
+    # [batch_size * max_len_in_batch, 1]
     pos = layers.data(
         name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
@@ -415,6 +416,8 @@ def make_inputs(input_data_names,
         # This input is used to remove attention weights on paddings for the
         # encoder and to remove attention weights on subsequent words for the
         # decoder.
+        # The actual data shape of slf_attn_bias is:
+        # [batch_size, n_head, max_len_in_batch, max_len_in_batch]
         slf_attn_bias = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
@@ -423,13 +426,26 @@ def make_inputs(input_data_names,
         input_layers += [slf_attn_bias]
     if src_attn_bias_flag:
         # This input is used to remove attention weights on paddings.
+        # The actual data shape of src_attn_bias is:
+        # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
         src_attn_bias = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [src_attn_bias]
+    if data_shape_flag:
+        # This input is used to set the actual data shape in reshape.
+        data_shape = layers.data(
+            name=input_data_names[len(input_layers)],
+            shape=[3],
+            dtype="int32",
+            append_batch_size=False)
+        input_layers += [data_shape]
     if enc_output_flag:
+        # This input is used in the independent decoder program for inference.
+        # The actual data shape of enc_output is:
+        # [batch_size, max_len_in_batch, d_model]
         enc_output = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, max_length, d_model],
@@ -453,8 +469,8 @@ def transformer(
         src_pad_idx,
         trg_pad_idx,
         pos_pad_idx, ):
-    enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, True, True, False)
+    enc_inputs = make_inputs(encoder_input_data_names, n_head, d_model,
+                             max_length, True, True, False)

     enc_output = wrap_encoder(
         src_vocab_size,
@@ -468,10 +484,10 @@ def transformer(
         dropout_rate,
         src_pad_idx,
         pos_pad_idx,
-        enc_input_layers, )
+        enc_inputs, )

-    dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, True, True, True)
+    dec_inputs = make_inputs(decoder_input_data_names, n_head, d_model,
+                             max_length, True, True, True)

     predict = wrap_decoder(
         trg_vocab_size,
@@ -485,13 +501,13 @@ def transformer(
         dropout_rate,
         trg_pad_idx,
         pos_pad_idx,
-        dec_input_layers,
+        dec_inputs,
         enc_output, )

     # The padding index does not contribute to the total loss. These weights
     # are used to cancel the padding index in calculating the loss.
-    gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
-                                max_length, False, False, False)
+    gold, weights = make_inputs(label_data_names, n_head, d_model, max_length,
+                                False, False, False, False, False)
     cost = layers.cross_entropy(input=predict, label=gold)
     weighted_cost = cost * weights
     return layers.reduce_sum(weighted_cost), predict
@@ -508,17 +524,18 @@ def wrap_encoder(src_vocab_size,
                  dropout_rate,
                  src_pad_idx,
                  pos_pad_idx,
-                 enc_input_layers=None):
+                 enc_inputs=None):
     """
     The wrapper assembles together all needed layers for the encoder.
     """
-    if enc_input_layers is None:
+    if enc_inputs is None:
         # This is used to implement independent encoder program in inference.
-        src_word, src_pos, src_slf_attn_bias = make_inputs(
-            encoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, True, False)
+        src_word, src_pos, src_slf_attn_bias, src_data_shape = make_inputs(
+            encoder_input_data_names, n_head, d_model, max_length, True, True,
+            False)
     else:
-        src_word, src_pos, src_slf_attn_bias = enc_input_layers
+        src_word, src_pos, src_slf_attn_bias, src_data_shape = enc_inputs
     enc_input = prepare_encoder(
         src_word,
         src_pos,
@@ -526,7 +543,9 @@ def wrap_encoder(src_vocab_size,
         d_model,
         src_pad_idx,
         max_length,
-        dropout_rate, )
+        dropout_rate,
+        pos_pad_idx,
+        src_data_shape, )
     enc_output = encoder(
         enc_input,
         src_slf_attn_bias,
@@ -551,18 +570,18 @@ def wrap_decoder(trg_vocab_size,
                  dropout_rate,
                  trg_pad_idx,
                  pos_pad_idx,
-                 dec_input_layers=None,
+                 dec_inputs=None,
                  enc_output=None):
     """
     The wrapper assembles together all needed layers for the decoder.
     """
-    if dec_input_layers is None:
+    if dec_inputs is None:
         # This is used to implement independent decoder program in inference.
-        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
-            decoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, True, True, True)
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output = make_inputs(
+            decoder_input_data_names, n_head, d_model, max_length, True, True,
+            True, True)
     else:
-        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape = dec_inputs
     dec_input = prepare_decoder(
         trg_word,
@@ -571,7 +590,9 @@ def wrap_decoder(trg_vocab_size,
         d_model,
         trg_pad_idx,
         max_length,
-        dropout_rate, )
+        dropout_rate,
+        pos_pad_idx,
+        trg_data_shape, )
     dec_output = decoder(
         dec_input,
         enc_output,
...
@@ -56,7 +56,7 @@ def pad_batch_data(insts,
 def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
-                        max_length, n_head):
+                        n_head, d_model):
     """
     Put all padded data needed by training into a dict.
     """
@@ -69,10 +69,13 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
     lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
                               False, False, False, False)
     lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
+    src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
+    trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
     input_dict = dict(
         zip(input_data_names, [
-            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
-            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+            src_word, src_pos, src_slf_attn_bias, src_data_shape, trg_word,
+            trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape,
+            lbl_word, lbl_weight
         ]))
     return input_dict
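The two added feeds are plain int32 vectors; everything the reshape ops need at run time is recomputed per mini-batch from the padded lengths. A toy run of the shape computation (hypothetical instances):

```python
import numpy as np

# Toy run (hypothetical instances) of the per-batch shape computation.
insts = [([1, 2, 3], [4, 5], [4, 5]),      # (src_ids, trg_ids, lbl_ids)
         ([6, 7], [8, 9, 10], [8, 9, 10])]
d_model = 512
src_max_len = max(len(inst[0]) for inst in insts)  # 3
trg_max_len = max(len(inst[1]) for inst in insts)  # 3
src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
print(src_data_shape)  # [  2   3 512]
```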
@@ -119,13 +122,11 @@ def main():
     def test(exe):
         test_costs = []
         for batch_id, data in enumerate(val_data()):
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
                 label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             test_cost = exe.run(test_program,
                                 feed=data_input,
                                 fetch_list=[cost])[0]
@@ -143,15 +144,11 @@ def main():
     for pass_id in xrange(TrainTaskConfig.pass_num):
         for batch_id, data in enumerate(train_data()):
-            # The current program desc is coupled with batch_size, thus all
-            # mini-batches must have the same number of instances currently.
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
                 label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
...