Commit f3c247d3 authored by guosheng

Decouple the program desc from batch_size in Transformer.

Parent 35308832
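The gist of the change: instead of burning batch_size into the program desc through reshape attributes, each program takes an extra int32 data input (src_data_shape / trg_data_shape) carrying the runtime shape, and reshape reads it through the actual_shape argument. A minimal sketch of the pattern, built only from the fluid calls that appear in this diff (the names x and data_shape are illustrative, not the model code):

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    d_model, max_len = 16, 8
    # Flattened input, [batch_size * max_len_in_batch, d_model] at run time.
    x = layers.data(
        name="x", shape=[max_len, d_model], dtype="float32",
        append_batch_size=False)
    # The runtime shape arrives as data, so it is not part of the program desc.
    data_shape = layers.data(
        name="data_shape", shape=[3], dtype="int32", append_batch_size=False)
    # The static shape only satisfies infer-shape at compile time; at run time
    # the reshape takes its real target shape from data_shape.
    out = layers.reshape(
        x=x, shape=[-1, max_len, d_model], actual_shape=data_shape)
    # Feed for any batch size B and current max length L:
    #   feed={"x": ..., "data_shape": np.array([B, L, d_model], dtype="int32")}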
@@ -92,7 +92,8 @@ pos_enc_param_names = (
encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
"src_slf_attn_bias",
"src_data_shape", )
# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
@@ -100,6 +101,7 @@ decoder_input_data_names = (
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"trg_data_shape",
"enc_output", )
# Names of label related data layers listed in order.
......
@@ -13,8 +13,8 @@ from train import pad_batch_data
def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
bos_idx, eos_idx):
n_best, batch_size, n_head, d_model, src_pad_idx,
trg_pad_idx, bos_idx, eos_idx):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
@@ -28,6 +28,10 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
return_pos=True,
return_attn_bias=True,
return_max_len=True)
enc_in_data = enc_in_data[:-1] + [
np.array(
[batch_size, enc_in_data[-1], d_model], dtype="int32")
] # Append the data shape input.
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
@@ -35,11 +39,16 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[]] * batch_size
next_ids = [[]] * batch_size
# Use beam_map to map the instance idx in batch to beam idx, since the
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_inst_map to map beam idx to the instance idx in batch, since the
# size of the fed batch is changing.
beam_map = range(batch_size)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
# Use active_beams to record the alive beams.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
"""
@@ -64,8 +73,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-1], enc_in_data[-2], 1
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[-1][
1], enc_in_data[-2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
@@ -77,16 +86,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :],
[beam_size, 1, trg_max_len, 1])
enc_output = np.tile(enc_output, [beam_size, 1, 1])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams):
def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
"""
Update the input data of the decoder mainly by slicing from the previous
input data and dropping the beams of finished instances.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(
@@ -98,6 +111,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
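The broadcasting above turns surviving instance indices into flat row indices over the [batch_size * beam_size, ...] layout; with beam_size 3 and instances [0, 2] alive it selects rows [0, 1, 2, 6, 7, 8]. A self-contained check:

    import numpy as np

    beam_size = 3
    active_beams = [0, 2]  # instances still decoding
    indice = ((np.array(active_beams) * beam_size)[:, np.newaxis] +
              np.array(range(beam_size))[np.newaxis, :]).flatten()
    print(indice)  # [0 1 2 6 7 8]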
@@ -112,8 +126,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
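Note the change above in how enc_output is replicated for beams: the old np.tile(enc_output, [beam_size, 1, 1]) repeats the whole batch, interleaving instances, while the new newaxis + reshape keeps each instance's beam_size copies contiguous, matching the instance-major layout assumed by active_beams_indice. A quick numpy comparison:

    import numpy as np

    enc_output = np.arange(4).reshape([2, 1, 2])  # 2 instances
    beam_size = 2
    old = np.tile(enc_output, [beam_size, 1, 1])
    new = np.tile(enc_output[:, np.newaxis],
                  [1, beam_size, 1, 1]).reshape(
                      [-1, enc_output.shape[-2], enc_output.shape[-1]])
    print(old[:, 0, 0])  # [0 2 0 2] -- inst0, inst1, inst0, inst1
    print(new[:, 0, 0])  # [0 0 2 2] -- inst0's beams first, then inst1's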
@@ -122,13 +139,16 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-1, :])
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
for beam_idx in range(batch_size):
if beam_idx not in beam_inst_map:
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
@@ -141,13 +161,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
beam_map = active_beams
if len(beam_map) == 0:
if len(active_beams) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist()
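The body of beam_backtrace is elided in this diff; what it implements is the standard walk from the last step back to the first, following the prev_branchs pointers to recover each hypothesis. A hedged sketch of that logic (not the exact code from this repo):

    def beam_backtrace_sketch(prev_branchs, next_ids, n_best, add_bos=True,
                              bos_idx=0):
        # prev_branchs[step][k]: which beam at step-1 the k-th beam extends.
        # next_ids[step][k]: the token chosen by the k-th beam at this step.
        seqs = []
        for k in range(n_best):
            seq, branch = [], k
            for step in range(len(prev_branchs) - 1, -1, -1):
                seq.append(next_ids[step][branch])
                branch = prev_branchs[step][branch]
            seq = seq[::-1]
            seqs.append([bos_idx] + seq if add_bos else seq)
        return seqs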
@@ -155,10 +182,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# The current program desc is coupled with batch_size and the only
# supported batch size is 1 currently.
encoder_program = fluid.Program()
model.batch_size = InferTaskConfig.batch_size
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size + 1,
@@ -168,7 +193,6 @@ def main():
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
@@ -213,16 +237,15 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
len(data), ModelHyperParams.n_head, ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
scores = batch_scores[i]
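With trg_idx2word loaded via reverse=True, turning an n-best id sequence back into text is a plain lookup; an illustrative snippet with a toy dictionary:

    trg_idx2word = {0: "<s>", 1: "</s>", 7: "hallo", 9: "welt"}  # toy dict
    seq = [0, 7, 9, 1]  # one entry of batch_seqs[i]: <bos> ... <eos>
    print(" ".join(trg_idx2word[idx] for idx in seq[1:-1]))  # hallo welt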
......
@@ -7,9 +7,6 @@ import paddle.fluid.layers as layers
from config import TrainTaskConfig, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
def position_encoding_init(n_position, d_pos_vec):
"""
@@ -83,9 +80,10 @@ def multi_head_attention(queries,
return x
hidden_size = x.shape[-1]
# FIXME(guosheng): Decouple the program desc from batch_size.
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
x=x, shape=[0, -1, n_head, hidden_size // n_head])
# permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
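The 0 in the shape attribute copies that dimension from the input at run time, which is exactly what removes batch_size from the program desc here; the numpy equivalent of the reshape above:

    import numpy as np

    x = np.random.rand(5, 6, 8)  # batch size 5 known only at run time
    n_head, hidden_size = 4, x.shape[-1]
    # Like layers.reshape(x, shape=[0, -1, n_head, hidden_size // n_head]):
    # the leading 0 means "keep x.shape[0]"; -1 infers the sequence length.
    out = x.reshape([x.shape[0], -1, n_head, hidden_size // n_head])
    print(out.shape)  # (5, 6, 4, 2)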
@@ -101,26 +99,20 @@ def multi_head_attention(queries,
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# FIXME(guosheng): Decouple the program desc from batch_size.
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=map(int,
[batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
Scaled Dot-Product Attention
"""
# FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
# The current implementation of softmax_op only supports 2D tensors,
# consequently it cannot be directly used here.
# As for reshape_op, the shape of product inferred at compile time is
# not the actual shape at run time, so it cannot be used to set the
# shape attribute of reshape_op.
# So, the softmax is defined here as a temporary solution.
# FIXME(guosheng): Remove __softmax once softmax_op supports high-rank
# tensors; softmax_op only supports 2D tensors currently. Otherwise,
# extra input data would have to be added for the reshape.
def __softmax(x, eps=1e-9):
exp_out = layers.exp(x=x)
sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
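__softmax is exp followed by a reduce-sum divide over the last dimension (the dividing op falls just outside this hunk, and eps presumably guards that divide). It skips the usual max-subtraction, relying on the large negative attention bias to make exp underflow to zero on masked positions. A numpy equivalent, assuming that structure:

    import numpy as np

    def softmax_np(x, eps=1e-9):
        # exp, reduce-sum over the last dim, then divide; no max-subtraction.
        exp_out = np.exp(x)
        sum_out = np.sum(exp_out, axis=-1, keepdims=True)
        return exp_out / (sum_out + eps)

    print(softmax_np(np.array([[0.0, 0.0, -1e9]])))  # ~[[0.5, 0.5, 0.0]]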
@@ -131,6 +123,7 @@ def multi_head_attention(queries,
weights = __softmax(
layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product)
# weights = __softmax(product)
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
@@ -177,7 +170,7 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
@@ -195,8 +188,9 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout:
out = layers.dropout(out, dropout_prob=dropout, is_test=False)
if dropout_rate:
out = layers.dropout(
out, dropout_prob=dropout_rate, is_test=False)
return out
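pre_post_process_layer walks the characters of process_cmd; the hunk above shows the "n" (layer norm) and "d" (dropout) branches, and the docstring implies an "a" branch for the residual connection. A runnable numpy sketch of the dispatch, with the "a" branch inferred and the layer norm shown without learned scale/bias:

    import numpy as np

    def pre_post_process_np(prev_out, out, process_cmd, dropout_rate=0.):
        for cmd in process_cmd:
            if cmd == "a" and prev_out is not None:  # add residual connection
                out = out + prev_out
            elif cmd == "n":                         # layer normalization
                mean = out.mean(axis=-1, keepdims=True)
                std = out.std(axis=-1, keepdims=True)
                out = (out - mean) / (std + 1e-5)
            elif cmd == "d" and dropout_rate:        # dropout (mask only here)
                out = out * (np.random.rand(*out.shape) >= dropout_rate)
        return out

    x = np.ones([2, 4])
    print(pre_post_process_np(x, x * 2, "an").shape)  # (2, 4)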
@@ -210,8 +204,9 @@ def prepare_encoder(src_word,
src_emb_dim,
src_pad_idx,
src_max_len,
dropout=0.,
dropout_rate=0.,
pos_pad_idx=0,
src_data_shape=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
@@ -231,12 +226,13 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
# FIXME(guosheng): Decouple the program desc from batch_size.
enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
enc_input = layers.reshape(
x=enc_input,
shape=[-1, src_max_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input, dropout_prob=dropout,
is_test=False) if dropout else enc_input
enc_input, dropout_prob=dropout_rate,
is_test=False) if dropout_rate else enc_input
prepare_encoder = partial(
@@ -386,18 +382,21 @@ def decoder(dec_input,
def make_inputs(input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
enc_output_flag=False):
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=False,
data_shape_flag=True):
"""
Define the input data layers for the transformer model.
"""
input_layers = []
# The shapes here act as placeholders.
# The shapes set here are only to pass the infer-shape in compile time.
batch_size = 1 # Only for the infer-shape in compile time.
# The shapes here act as placeholders and are set to pass the infer-shape in
# compile time.
# The actual data shape of word is:
# [batch_size * max_len_in_batch, 1]
word = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
@@ -405,6 +404,8 @@ def make_inputs(input_data_names,
append_batch_size=False)
input_layers += [word]
# This is used for position data or label weight.
# The actual data shape of pos is:
# [batch_size * max_len_in_batch, 1]
pos = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
@@ -415,6 +416,8 @@ def make_inputs(input_data_names,
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
# The actual data shape of slf_attn_bias is:
# [batch_size, n_head, max_len_in_batch, max_len_in_batch]
slf_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
@@ -423,13 +426,26 @@ def make_inputs(input_data_names,
input_layers += [slf_attn_bias]
if src_attn_bias_flag:
# This input is used to remove attention weights on paddings.
# The actual data shape of src_attn_bias is:
# [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
src_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [src_attn_bias]
if data_shape_flag:
# This input is used as the actual data shape in reshape.
data_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [data_shape]
if enc_output_flag:
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_len_in_batch, d_model]
enc_output = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, max_length, d_model],
@@ -453,8 +469,8 @@ def transformer(
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
batch_size, max_length, True, True, False)
enc_inputs = make_inputs(encoder_input_data_names, n_head, d_model,
max_length, True, True, False)
enc_output = wrap_encoder(
src_vocab_size,
@@ -468,10 +484,10 @@ def transformer(
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_input_layers, )
enc_inputs, )
dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
batch_size, max_length, True, True, True)
dec_inputs = make_inputs(decoder_input_data_names, n_head, d_model,
max_length, True, True, True)
predict = wrap_decoder(
trg_vocab_size,
@@ -485,13 +501,13 @@ def transformer(
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_input_layers,
dec_inputs,
enc_output, )
# Padding indices do not contribute to the total loss. The weights are
# used to cancel padding indices out when calculating the loss.
gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
max_length, False, False, False)
gold, weights = make_inputs(label_data_names, n_head, d_model, max_length,
False, False, False, False, False)
cost = layers.cross_entropy(input=predict, label=gold)
weighted_cost = cost * weights
return layers.reduce_sum(weighted_cost), predict
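The weights input zeroes out the loss terms at padding positions before the reduce_sum; a numpy illustration of the cancellation:

    import numpy as np

    # Toy per-token cross-entropy costs for 4 flattened tokens, the last
    # of which is padding (weight 0).
    cost = np.array([[1.2], [0.7], [0.9], [3.0]], dtype="float32")
    weights = np.array([[1.], [1.], [1.], [0.]], dtype="float32")
    print((cost * weights).sum())  # 2.8 -- padding contributes nothing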
@@ -508,17 +524,18 @@ def wrap_encoder(src_vocab_size,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_input_layers=None):
enc_inputs=None):
"""
The wrapper assembles all the layers needed for the encoder.
"""
if enc_input_layers is None:
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_inputs(
encoder_input_data_names, n_head, d_model, batch_size, max_length,
True, True, False)
src_word, src_pos, src_slf_attn_bias, src_data_shape = make_inputs(
encoder_input_data_names, n_head, d_model, max_length, True, True,
False)
else:
src_word, src_pos, src_slf_attn_bias = enc_input_layers
src_word, src_pos, src_slf_attn_bias, src_data_shape = enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
@@ -526,7 +543,9 @@ def wrap_encoder(src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate, )
dropout_rate,
pos_pad_idx,
src_data_shape, )
enc_output = encoder(
enc_input,
src_slf_attn_bias,
@@ -551,18 +570,18 @@ def wrap_decoder(trg_vocab_size,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_input_layers=None,
dec_inputs=None,
enc_output=None):
"""
The wrapper assembles all the layers needed for the decoder.
"""
if dec_input_layers is None:
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
decoder_input_data_names, n_head, d_model, batch_size, max_length,
True, True, True, True)
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output = make_inputs(
decoder_input_data_names, n_head, d_model, max_length, True, True,
True, True)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape = dec_inputs
dec_input = prepare_decoder(
trg_word,
@@ -571,7 +590,9 @@ def wrap_decoder(trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate, )
dropout_rate,
pos_pad_idx,
trg_data_shape, )
dec_output = decoder(
dec_input,
enc_output,
......
@@ -56,7 +56,7 @@ def pad_batch_data(insts,
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
max_length, n_head):
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
@@ -69,10 +69,13 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
input_dict = dict(
zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
src_word, src_pos, src_slf_attn_bias, src_data_shape, trg_word,
trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape,
lbl_word, lbl_weight
]))
return input_dict
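The two new shape tensors are plain int32 vectors derived from the padded batch; for example, with 2 instances padded to source length 5 and target length 7 at d_model 512 (toy numbers):

    import numpy as np

    insts = [None, None]  # stands in for 2 (src, trg, lbl) training triples
    src_max_len, trg_max_len, d_model = 5, 7, 512
    src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
    trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
    print(src_data_shape, trg_data_shape)  # vectors [2 5 512] and [2 7 512]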
@@ -119,13 +122,11 @@ def main():
def test(exe):
test_costs = []
for batch_id, data in enumerate(val_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
test_cost = exe.run(test_program,
feed=data_input,
fetch_list=[cost])[0]
@@ -143,15 +144,11 @@ def main():
for pass_id in xrange(TrainTaskConfig.pass_num):
for batch_id, data in enumerate(train_data()):
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
......