未验证 提交 f14db82d 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #783 from guoshengCS/fix-transformer-batchsize

Decouple the program desc with batch_size in Transformer.
......@@ -25,8 +25,7 @@ class TrainTaskConfig(object):
class InferTaskConfig(object):
use_gpu = False
# the number of examples in one run for sequence generation.
# currently the batch size can only be set to 1.
batch_size = 1
batch_size = 10
# the parameters for beam search.
beam_size = 5
......@@ -103,6 +102,7 @@ encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
......@@ -112,6 +112,7 @@ decoder_input_data_names = (
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
......
......@@ -24,6 +24,7 @@ def translate_batch(exe,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
......@@ -43,6 +44,11 @@ def translate_batch(exe,
return_pos=True,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
]
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
......@@ -59,9 +65,14 @@ def translate_batch(exe,
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_map to map the instance idx in batch to beam idx, since the
# Use beam_inst_map to map beam idx to the instance idx in batch, since the
# size of feeded batch is changing.
beam_map = range(batch_size)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
# Use active_beams to recode the alive.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
......@@ -98,8 +109,14 @@ def translate_batch(exe,
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :],
[beam_size, 1, trg_max_len, 1])
src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
[1, beam_size, 1, trg_max_len, 1]).reshape([
-1, src_slf_attn_bias.shape[1], trg_max_len,
src_slf_attn_bias.shape[-1]
])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
......@@ -112,22 +129,24 @@ def translate_batch(exe,
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(enc_output, [beam_size, 1, 1])
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
enc_output
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams):
def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
"""
Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
enc_output = dec_in_data
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
......@@ -138,6 +157,7 @@ def translate_batch(exe,
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
......@@ -152,6 +172,10 @@ def translate_batch(exe,
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
......@@ -166,9 +190,9 @@ def translate_batch(exe,
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
enc_output
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
......@@ -177,15 +201,18 @@ def translate_batch(exe,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-1, :])
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
for beam_idx in range(batch_size):
if not beam_inst_map.has_key(beam_idx):
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
......@@ -198,10 +225,14 @@ def translate_batch(exe,
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
beam_map = active_beams
if len(beam_map) == 0:
if len(active_beams) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [
......@@ -215,10 +246,8 @@ def translate_batch(exe,
def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# The current program desc is coupled with batch_size and the only
# supported batch size is 1 currently.
encoder_program = fluid.Program()
model.batch_size = InferTaskConfig.batch_size
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size + 1,
......@@ -228,7 +257,6 @@ def main():
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
......@@ -273,6 +301,9 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
# Append the <pad> token since the dict provided by dataset.wmt16 does
# not include it.
trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
......@@ -306,6 +337,7 @@ def main():
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx,
......
......@@ -7,9 +7,6 @@ import paddle.fluid.layers as layers
from config import TrainTaskConfig, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
def position_encoding_init(n_position, d_pos_vec):
"""
......@@ -85,9 +82,10 @@ def multi_head_attention(queries,
return x
hidden_size = x.shape[-1]
# FIXME(guosheng): Decouple the program desc with batch_size.
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
x=x, shape=[0, -1, n_head, hidden_size // n_head])
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
......@@ -103,11 +101,11 @@ def multi_head_attention(queries,
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# FIXME(guosheng): Decouple the program desc with batch_size.
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=map(int,
[batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
......@@ -205,6 +203,7 @@ def prepare_encoder(src_word,
src_max_len,
dropout_rate=0.,
pos_pad_idx=0,
src_data_shape=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
......@@ -224,9 +223,10 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
# FIXME(guosheng): Decouple the program desc with batch_size.
enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
enc_input = layers.reshape(
x=enc_input,
shape=[-1, src_max_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input, dropout_prob=dropout_rate,
is_test=False) if dropout_rate else enc_input
......@@ -401,20 +401,23 @@ def decoder(dec_input,
def make_inputs(input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True):
"""
Define the input data layers for the transformer model.
"""
input_layers = []
# The shapes here act as placeholder.
# The shapes set here is to pass the infer-shape in compile time.
batch_size = 1 # Only for the infer-shape in compile time.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
# The actual data shape of word is:
# [batch_size * max_len_in_batch, 1]
word = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
......@@ -422,6 +425,8 @@ def make_inputs(input_data_names,
append_batch_size=False)
input_layers += [word]
# This is used for position data or label weight.
# The actual data shape of pos is:
# [batch_size * max_len_in_batch, 1]
pos = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
......@@ -432,6 +437,8 @@ def make_inputs(input_data_names,
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, max_len_in_batch, max_len_in_batch]
slf_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
......@@ -439,40 +446,56 @@ def make_inputs(input_data_names,
append_batch_size=False)
input_layers += [slf_attn_bias]
if src_attn_bias_flag:
# This input is used to remove attention weights on paddings.
# This input is used to remove attention weights on paddings. It's used
# in encoder-decoder attention.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
src_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [src_attn_bias]
if data_shape_flag:
# This input is used to reshape the output of embedding layer.
data_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [data_shape]
if slf_attn_shape_flag:
# This shape input is used to reshape before softmax in self attention.
slf_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in self attention.
slf_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
shape=[4],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
shape=[4],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_post_softmax_shape]
if enc_output_flag:
# This input is used in independent decoder program for inference.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, max_len_in_batch, d_model]
enc_output = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, max_length, d_model],
......@@ -497,16 +520,16 @@ def transformer(
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
enc_input_layers = make_inputs(
enc_inputs = make_inputs(
encoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
......@@ -522,18 +545,18 @@ def transformer(
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_input_layers, )
enc_inputs, )
dec_input_layers = make_inputs(
dec_inputs = make_inputs(
decoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
......@@ -549,7 +572,7 @@ def transformer(
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_input_layers,
dec_inputs,
enc_output, )
# Padding index do not contribute to the total loss. The weights is used to
......@@ -558,12 +581,12 @@ def transformer(
label_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=False,
slf_attn_bias_flag=False,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=False,
slf_attn_shape_flag=False,
src_attn_shape_flag=False)
cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
......@@ -571,7 +594,7 @@ def transformer(
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict
return sum_cost, avg_cost, predict, token_num
def wrap_encoder(src_vocab_size,
......@@ -585,28 +608,30 @@ def wrap_encoder(src_vocab_size,
dropout_rate,
src_pad_idx,
pos_pad_idx,
enc_input_layers=None):
enc_inputs=None):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_input_layers is None:
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape = make_inputs(
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
make_inputs(
encoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
else:
src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape = enc_input_layers
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
......@@ -614,7 +639,9 @@ def wrap_encoder(src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate, )
dropout_rate,
pos_pad_idx,
src_data_shape, )
enc_output = encoder(
enc_input,
src_slf_attn_bias,
......@@ -641,33 +668,33 @@ def wrap_decoder(trg_vocab_size,
dropout_rate,
trg_pad_idx,
pos_pad_idx,
dec_input_layers=None,
dec_inputs=None,
enc_output=None):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_input_layers is None:
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
enc_output = make_inputs(
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape, enc_output = make_inputs(
decoder_input_data_names,
n_head,
d_model,
batch_size,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=True,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
dec_input_layers
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = dec_inputs
dec_input = prepare_decoder(
trg_word,
......@@ -676,7 +703,9 @@ def wrap_decoder(trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate, )
dropout_rate,
pos_pad_idx,
trg_data_shape, )
dec_output = decoder(
dec_input,
enc_output,
......@@ -700,5 +729,5 @@ def wrap_decoder(trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax" if dec_input_layers is None else None)
act="softmax" if dec_inputs is None else None)
return predict
......@@ -57,7 +57,7 @@ def pad_batch_data(insts,
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
max_length, n_head):
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
......@@ -67,6 +67,10 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
......@@ -79,17 +83,19 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
False, False, False, False)
lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
input_dict = dict(
zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias,
src_word, src_pos, src_slf_attn_bias, src_data_shape,
src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
lbl_word, lbl_weight
trg_data_shape, trg_slf_attn_pre_softmax_shape,
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
]))
return input_dict
......@@ -98,7 +104,7 @@ def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
sum_cost, avg_cost, predict = transformer(
sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
......@@ -134,21 +140,24 @@ def main():
batch_size=TrainTaskConfig.batch_size)
def test(exe):
test_sum_costs = []
test_avg_costs = []
test_total_cost = 0
test_total_token = 0
for batch_id, data in enumerate(val_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
test_sum_cost, test_avg_cost = exe.run(
test_program, feed=data_input, fetch_list=[sum_cost, avg_cost])
test_sum_costs.append(test_sum_cost)
test_avg_costs.append(test_avg_cost)
return np.mean(test_sum_costs), np.mean(test_avg_costs)
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run(
test_program,
feed=data_input,
fetch_list=[sum_cost, token_num],
use_program_cache=True)
test_total_cost += test_sum_cost
test_total_token += test_token_num
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
# Initialize the parameters.
exe.run(fluid.framework.default_startup_program())
......@@ -162,15 +171,11 @@ def main():
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
......@@ -181,13 +186,11 @@ def main():
(pass_id, batch_id, sum_cost_val, avg_cost_val,
np.exp([min(avg_cost_val[0], 100)])))
# Validate and save the model for inference.
val_sum_cost, val_avg_cost = test(exe)
val_avg_cost, val_ppl = test(exe)
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
print("epoch: %d, val sum loss: %f, val avg loss: %f, val ppl: %f, "
"consumed %fs" %
(pass_id, val_sum_cost, val_avg_cost,
np.exp([min(val_avg_cost, 100)]), time_consumed))
print("epoch: %d, val avg loss: %f, val ppl: %f, "
"consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册