Unverified commit 6966c992, authored by Yu Yang, committed by GitHub

Merge pull request #1150 from reyoung/feature/remove_embedding_softmax_reshape

Remove reshape op for embedding and softmax
...@@ -116,29 +116,23 @@ seq_len = ModelHyperParams.max_length ...@@ -116,29 +116,23 @@ seq_len = ModelHyperParams.max_length
input_descs = { input_descs = {
# The actual data shape of src_word is: # The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1] # [batch_size * max_src_len_in_batch, 1]
"src_word": [(batch_size * seq_len, 1L), "int64", 2], "src_word": [(batch_size, seq_len, 1L), "int64", 2],
# The actual data shape of src_pos is: # The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1] # [batch_size * max_src_len_in_batch, 1]
"src_pos": [(batch_size * seq_len, 1L), "int64"], "src_pos": [(batch_size, seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings in the # This input is used to remove attention weights on paddings in the
# encoder. # encoder.
# The actual data shape of src_slf_attn_bias is: # The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"], seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is: # The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1] # [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(batch_size * seq_len, 1L), "int64", "trg_word": [(batch_size, seq_len, 1L), "int64",
2], # lod_level is only used in fast decoder. 2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is: # The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1] # [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(batch_size * seq_len, 1L), "int64"], "trg_pos": [(batch_size, seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings and # This input is used to remove attention weights on paddings and
# subsequent words in the decoder. # subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is: # The actual data shape of trg_slf_attn_bias is:
...@@ -151,18 +145,6 @@ input_descs = { ...@@ -151,18 +145,6 @@ input_descs = {
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"], seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
"trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
"trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
# This input is used in independent decoder program for inference. # This input is used in independent decoder program for inference.
# The actual data shape of enc_output is: # The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model] # [batch_size, max_src_len_in_batch, d_model]
...@@ -193,22 +175,12 @@ encoder_data_input_fields = ( ...@@ -193,22 +175,12 @@ encoder_data_input_fields = (
"src_word", "src_word",
"src_pos", "src_pos",
"src_slf_attn_bias", ) "src_slf_attn_bias", )
encoder_util_input_fields = (
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
decoder_data_input_fields = ( decoder_data_input_fields = (
"trg_word", "trg_word",
"trg_pos", "trg_pos",
"trg_slf_attn_bias", "trg_slf_attn_bias",
"trg_src_attn_bias", "trg_src_attn_bias",
"enc_output", ) "enc_output", )
decoder_util_input_fields = (
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape", )
label_data_input_fields = ( label_data_input_fields = (
"lbl_word", "lbl_word",
"lbl_weight", ) "lbl_weight", )
...@@ -218,6 +190,6 @@ fast_decoder_data_input_fields = ( ...@@ -218,6 +190,6 @@ fast_decoder_data_input_fields = (
"trg_word", "trg_word",
"init_score", "init_score",
"trg_src_attn_bias", ) "trg_src_attn_bias", )
fast_decoder_util_input_fields = decoder_util_input_fields + ( # fast_decoder_util_input_fields = (
"trg_slf_attn_pre_softmax_shape_delta", # "trg_slf_attn_pre_softmax_shape_delta",
"trg_slf_attn_post_softmax_shape_delta", ) # "trg_slf_attn_post_softmax_shape_delta", )
...@@ -85,239 +85,6 @@ def parse_args(): ...@@ -85,239 +85,6 @@ def parse_args():
return args return args
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally. This is deprecated since a faster beam
search decoder based solely on Fluid operators has been added.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
]
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
enc_in_data[2].shape, dtype="int32")
]
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_inst_map to map beam idx to the instance idx in batch, since the
    # size of the fed batch is changing.
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
    # Use active_beams to record the beams that are still alive.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
seqs = []
for i in range(n_best):
k = i
seq = []
for j in range(len(prev_branchs) - 1, -1, -1):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
# Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
"""
Initialize the input data for decoder.
"""
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
-1], enc_in_data[2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_max_len, trg_max_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
[1, beam_size, 1, trg_max_len, 1]).reshape([
-1, src_slf_attn_bias.shape[1], trg_max_len,
src_slf_attn_bias.shape[-1]
])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
"""
Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
trg_cur_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_cur_len, trg_cur_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for beam_idx in range(batch_size):
if not beam_inst_map.has_key(beam_idx):
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
-1]]
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
if len(active_beams) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist()
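The removed `translate_batch` above builds the causal self-attention bias twice with `np.triu`; the same construction survives in `pad_batch_data` for training. A standalone numpy sketch of that bias (sizes are illustrative):

```python
import numpy as np

n_inst, n_head, trg_len = 2, 8, 4
mask = np.triu(np.ones((n_inst, trg_len, trg_len)), k=1)   # 1 above the diagonal
bias = mask.reshape([-1, 1, trg_len, trg_len])
bias = (np.tile(bias, [1, n_head, 1, 1]) * -1e9).astype("float32")
# bias[b, h, i, j] == -1e9 whenever j > i, so softmax assigns ~0 weight
# to words after position i.
```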
def post_process_seq(seq, def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx, bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx, eos_idx=ModelHyperParams.eos_idx,
...@@ -339,93 +106,8 @@ def post_process_seq(seq, ...@@ -339,93 +106,8 @@ def post_process_seq(seq,
seq) seq)
def py_infer(test_data, trg_idx2word, use_wordpiece): def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
""" d_model, place):
Inference by beam search implemented in Python, while the calculations from
symbols to probabilities are executed by Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
encoder_var_names = []
for op in encoder_program.block(0).ops:
encoder_var_names += op.input_arg_names
encoder_param_names = filter(
lambda var_name: isinstance(encoder_program.block(0).var(var_name),
fluid.framework.Parameter),
encoder_var_names)
encoder_params = map(encoder_program.block(0).var, encoder_param_names)
decoder_var_names = []
for op in decoder_program.block(0).ops:
decoder_var_names += op.input_arg_names
decoder_param_names = filter(
lambda var_name: isinstance(decoder_program.block(0).var(var_name),
fluid.framework.Parameter),
decoder_var_names)
decoder_params = map(decoder_program.block(0).var, decoder_param_names)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
# This is used here to set dropout to the test mode.
encoder_program = encoder_program.inference_optimize()
decoder_program = decoder_program.inference_optimize()
for batch_id, data in enumerate(test_data.batch_generator()):
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
encoder_program,
encoder_data_input_fields + encoder_util_input_fields,
[enc_output.name],
decoder_program,
decoder_data_input_fields[:-1] + decoder_util_input_fields +
(decoder_data_input_fields[-1], ),
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_out_len,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
if use_wordpiece:
print(util.subword_ids_to_str(seq, trg_idx2word))
else:
print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
bos_idx, n_head, d_model, place):
""" """
Put all padded data needed by beam search decoder into a dict. Put all padded data needed by beam search decoder into a dict.
""" """
...@@ -435,25 +117,9 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -435,25 +117,9 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64") trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32") [1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1, 1)
# These shape tensors are used in reshape_op. src_word = src_word.reshape(-1, src_max_len, 1)
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32") src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, 1], dtype="int32") # only the first time step
trg_slf_attn_post_softmax_shape = np.array(
[-1, n_head, 1, 1], dtype="int32") # only the first time step
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
# These inputs are used to change the shapes in the loop of while op.
attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
def to_lodtensor(data, place, lod=None): def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor() data_tensor = fluid.LoDTensor()
...@@ -465,7 +131,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -465,7 +131,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
# beamsearch_op must use tensors with lod # beamsearch_op must use tensors with lod
init_score = to_lodtensor( init_score = to_lodtensor(
np.zeros_like( np.zeros_like(
trg_word, dtype="float32"), trg_word, dtype="float32").reshape(-1, 1),
place, [range(trg_word.shape[0] + 1)] * 2) place, [range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2) trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
...@@ -474,16 +140,8 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -474,16 +140,8 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
src_word, src_pos, src_slf_attn_bias, trg_word, init_score, src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
trg_src_attn_bias trg_src_attn_bias
])) ]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
]))
input_dict = dict(data_input_dict.items() + util_input_dict.items()) input_dict = dict(data_input_dict.items())
return input_dict return input_dict
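For the fast-decoder feed above, `trg_word` stays rank-3 (`[batch, 1, 1]`) while `init_score` is flattened to rank-2 before being wrapped as a LoDTensor; the 2-level LoD says each instance currently holds a single beam. A small numpy sketch of those initial tensors (the `fluid.LoDTensor` wrapping itself is omitted):

```python
import numpy as np

batch_size, bos_idx = 3, 0
trg_word = np.full((batch_size, 1, 1), bos_idx, dtype="int64")   # one <bos> per instance
init_score = np.zeros((batch_size, 1), dtype="float32")          # rank-2, as fed above
lod = [list(range(batch_size + 1))] * 2   # 2-level lod: one beam per instance at step 0
# to_lodtensor(init_score, place, lod) and to_lodtensor(trg_word, place, lod)
# would wrap these for beam_search_op; that framework-specific step is omitted here.
```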
...@@ -515,7 +173,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece): ...@@ -515,7 +173,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
for batch_id, data in enumerate(test_data.batch_generator()): for batch_id, data in enumerate(test_data.batch_generator()):
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields, data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
ModelHyperParams.eos_idx, ModelHyperParams.bos_idx, ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place) ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program, seq_ids, seq_scores = exe.run(infer_program,
......
...@@ -29,8 +29,6 @@ def multi_head_attention(queries, ...@@ -29,8 +29,6 @@ def multi_head_attention(queries,
d_model, d_model,
n_head=1, n_head=1,
dropout_rate=0., dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None,
cache=None): cache=None):
""" """
Multi-Head Attention. Note that attn_bias is added to the logit before Multi-Head Attention. Note that attn_bias is added to the logit before
...@@ -101,14 +99,9 @@ def multi_head_attention(queries, ...@@ -101,14 +99,9 @@ def multi_head_attention(queries,
""" """
scaled_q = layers.scale(x=q, scale=d_model**-0.5) scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True) product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = layers.reshape( if attn_bias:
x=layers.elementwise_add( product += attn_bias
x=product, y=attn_bias) if attn_bias else product, weights = layers.softmax(product)
shape=[-1, product.shape[-1]],
actual_shape=pre_softmax_shape,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
if dropout_rate: if dropout_rate:
weights = layers.dropout( weights = layers.dropout(
weights, weights,
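The pre/post-softmax reshapes disappear because softmax only ever has to normalize over the last (key-length) axis of the 4-D attention logits, which the direct `layers.softmax(product)` call in the new code does. A numpy sketch of the equivalent computation, with illustrative sizes and `d_key ** -0.5` scaling assumed:

```python
import numpy as np

def softmax_last_dim(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

batch, n_head, q_len, k_len, d_key = 2, 8, 5, 5, 16
q = np.random.rand(batch, n_head, q_len, d_key)
k = np.random.rand(batch, n_head, k_len, d_key)
attn_bias = np.zeros((batch, n_head, q_len, k_len), dtype="float32")

product = np.matmul(q * d_key ** -0.5, np.swapaxes(k, -1, -2))  # [batch, n_head, q_len, k_len]
weights = softmax_last_dim(product + attn_bias)                 # no reshape before or after
assert np.allclose(weights.sum(axis=-1), 1.0)
```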
...@@ -191,7 +184,6 @@ def prepare_encoder(src_word, ...@@ -191,7 +184,6 @@ def prepare_encoder(src_word,
src_emb_dim, src_emb_dim,
src_max_len, src_max_len,
dropout_rate=0., dropout_rate=0.,
src_data_shape=None,
word_emb_param_name=None, word_emb_param_name=None,
pos_enc_param_name=None): pos_enc_param_name=None):
"""Add word embeddings and position encodings. """Add word embeddings and position encodings.
...@@ -205,6 +197,7 @@ def prepare_encoder(src_word, ...@@ -205,6 +197,7 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=word_emb_param_name, name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding( src_pos_enc = layers.embedding(
src_pos, src_pos,
...@@ -212,10 +205,6 @@ def prepare_encoder(src_word, ...@@ -212,10 +205,6 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False)) name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc enc_input = src_word_emb + src_pos_enc
enc_input = layers.reshape(
x=enc_input,
shape=[batch_size, seq_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout( return layers.dropout(
enc_input, enc_input,
dropout_prob=dropout_rate, dropout_prob=dropout_rate,
...@@ -236,18 +225,16 @@ def encoder_layer(enc_input, ...@@ -236,18 +225,16 @@ def encoder_layer(enc_input,
d_value, d_value,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate=0., dropout_rate=0.):
pre_softmax_shape=None,
post_softmax_shape=None):
"""The encoder layers that can be stacked to form a deep encoder. """The encoder layers that can be stacked to form a deep encoder.
This module consists of a multi-head (self) attention followed by This module consists of a multi-head (self) attention followed by
position-wise feed-forward networks, and both components are accompanied position-wise feed-forward networks, and both components are accompanied
by post_process_layer to add residual connection, layer normalization by post_process_layer to add residual connection, layer normalization
and dropout. and dropout.
""" """
attn_output = multi_head_attention( attn_output = multi_head_attention(enc_input, enc_input, enc_input,
enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model, attn_bias, d_key, d_value, d_model,
n_head, dropout_rate, pre_softmax_shape, post_softmax_shape) n_head, dropout_rate)
attn_output = post_process_layer(enc_input, attn_output, "dan", attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate) dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model) ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
...@@ -262,25 +249,14 @@ def encoder(enc_input, ...@@ -262,25 +249,14 @@ def encoder(enc_input,
d_value, d_value,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate=0., dropout_rate=0.):
pre_softmax_shape=None,
post_softmax_shape=None):
""" """
The encoder is composed of a stack of identical layers returned by calling The encoder is composed of a stack of identical layers returned by calling
encoder_layer. encoder_layer.
""" """
for i in range(n_layer): for i in range(n_layer):
enc_output = encoder_layer( enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
enc_input, d_model, d_inner_hid, dropout_rate)
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
pre_softmax_shape,
post_softmax_shape, )
enc_input = enc_output enc_input = enc_output
return enc_output return enc_output
...@@ -295,10 +271,6 @@ def decoder_layer(dec_input, ...@@ -295,10 +271,6 @@ def decoder_layer(dec_input,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate=0., dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None,
cache=None): cache=None):
""" The layer to be stacked in decoder part. """ The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except The structure of this module is similar to that in the encoder part except
...@@ -314,8 +286,6 @@ def decoder_layer(dec_input, ...@@ -314,8 +286,6 @@ def decoder_layer(dec_input,
d_model, d_model,
n_head, n_head,
dropout_rate, dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
cache, ) cache, )
slf_attn_output = post_process_layer( slf_attn_output = post_process_layer(
dec_input, dec_input,
...@@ -331,9 +301,7 @@ def decoder_layer(dec_input, ...@@ -331,9 +301,7 @@ def decoder_layer(dec_input,
d_value, d_value,
d_model, d_model,
n_head, n_head,
dropout_rate, dropout_rate, )
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
enc_attn_output = post_process_layer( enc_attn_output = post_process_layer(
slf_attn_output, slf_attn_output,
enc_attn_output, enc_attn_output,
...@@ -362,15 +330,15 @@ def decoder(dec_input, ...@@ -362,15 +330,15 @@ def decoder(dec_input,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate=0., dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None,
caches=None): caches=None):
""" """
The decoder is composed of a stack of identical decoder_layer layers. The decoder is composed of a stack of identical decoder_layer layers.
""" """
for i in range(n_layer): for i in range(n_layer):
cache = None
if caches is not None:
cache = caches[i]
dec_output = decoder_layer( dec_output = decoder_layer(
dec_input, dec_input,
enc_output, enc_output,
...@@ -382,11 +350,7 @@ def decoder(dec_input, ...@@ -382,11 +350,7 @@ def decoder(dec_input,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
slf_attn_pre_softmax_shape, cache=cache)
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
None if caches is None else caches[i], )
dec_input = dec_output dec_input = dec_output
return dec_output return dec_output
...@@ -425,8 +389,7 @@ def transformer( ...@@ -425,8 +389,7 @@ def transformer(
assert src_vocab_size == trg_vocab_size, ( assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing." "Vocabularies in source and target should be same for weight sharing."
) )
enc_inputs = make_all_inputs(encoder_data_input_fields + enc_inputs = make_all_inputs(encoder_data_input_fields)
encoder_util_input_fields)
enc_output = wrap_encoder( enc_output = wrap_encoder(
src_vocab_size, src_vocab_size,
...@@ -441,8 +404,7 @@ def transformer( ...@@ -441,8 +404,7 @@ def transformer(
weight_sharing, weight_sharing,
enc_inputs, ) enc_inputs, )
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] + dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])
decoder_util_input_fields)
predict = wrap_decoder( predict = wrap_decoder(
trg_vocab_size, trg_vocab_size,
...@@ -466,8 +428,10 @@ def transformer( ...@@ -466,8 +428,10 @@ def transformer(
label=layers.one_hot( label=layers.one_hot(
input=label, depth=trg_vocab_size), input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps) epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy( cost = layers.softmax_with_cross_entropy(
logits=predict, logits=layers.reshape(
predict, shape=[-1, trg_vocab_size]),
label=label, label=label,
soft_label=True if label_smooth_eps else False) soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights weighted_cost = cost * weights
...@@ -494,13 +458,10 @@ def wrap_encoder(src_vocab_size, ...@@ -494,13 +458,10 @@ def wrap_encoder(src_vocab_size,
""" """
if enc_inputs is None: if enc_inputs is None:
# This is used to implement independent encoder program in inference. # This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias, src_data_shape, \ src_word, src_pos, src_slf_attn_bias = \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \ make_all_inputs(encoder_data_input_fields)
make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
else: else:
src_word, src_pos, src_slf_attn_bias, src_data_shape, \ src_word, src_pos, src_slf_attn_bias = \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
enc_inputs enc_inputs
enc_input = prepare_encoder( enc_input = prepare_encoder(
src_word, src_word,
...@@ -509,20 +470,9 @@ def wrap_encoder(src_vocab_size, ...@@ -509,20 +470,9 @@ def wrap_encoder(src_vocab_size,
d_model, d_model,
max_length, max_length,
dropout_rate, dropout_rate,
src_data_shape,
word_emb_param_name=word_emb_param_names[0]) word_emb_param_name=word_emb_param_names[0])
enc_output = encoder( enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
enc_input, d_value, d_model, d_inner_hid, dropout_rate)
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
return enc_output return enc_output
...@@ -545,15 +495,10 @@ def wrap_decoder(trg_vocab_size, ...@@ -545,15 +495,10 @@ def wrap_decoder(trg_vocab_size,
if dec_inputs is None: if dec_inputs is None:
# This is used to implement independent decoder program in inference. # This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
enc_output, trg_data_shape, slf_attn_pre_softmax_shape, \ enc_output = make_all_inputs(
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = make_all_inputs(
decoder_data_input_fields + decoder_util_input_fields) decoder_data_input_fields + decoder_util_input_fields)
else: else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = dec_inputs
dec_input = prepare_decoder( dec_input = prepare_decoder(
trg_word, trg_word,
...@@ -562,7 +507,6 @@ def wrap_decoder(trg_vocab_size, ...@@ -562,7 +507,6 @@ def wrap_decoder(trg_vocab_size,
d_model, d_model,
max_length, max_length,
dropout_rate, dropout_rate,
trg_data_shape,
word_emb_param_name=word_emb_param_names[0] word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1]) if weight_sharing else word_emb_param_names[1])
dec_output = decoder( dec_output = decoder(
...@@ -577,28 +521,20 @@ def wrap_decoder(trg_vocab_size, ...@@ -577,28 +521,20 @@ def wrap_decoder(trg_vocab_size,
d_model, d_model,
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
slf_attn_pre_softmax_shape, caches=caches)
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
caches, )
# Return logits for training and probs for inference. # Return logits for training and probs for inference.
if weight_sharing: if weight_sharing:
predict = layers.reshape( predict = layers.matmul(
x=layers.matmul( x=dec_output,
x=dec_output, y=fluid.get_var(word_emb_param_names[0]),
y=fluid.get_var(word_emb_param_names[0]), transpose_y=True)
transpose_y=True),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
else: else:
predict = layers.reshape( predict = layers.fc(input=dec_output,
x=layers.fc(input=dec_output, size=trg_vocab_size,
size=trg_vocab_size, bias_attr=False,
bias_attr=False, num_flatten_dims=2)
num_flatten_dims=2), if dec_inputs is None:
shape=[-1, trg_vocab_size], predict = layers.softmax(predict)
act="softmax" if dec_inputs is None else None)
return predict return predict
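With weight sharing, the projection now keeps the `[batch, trg_len, d_model]` decoder output as-is: logits come from a matmul against the transposed word-embedding table, and softmax is applied only in the inference branch. A rough numpy equivalent (shapes illustrative):

```python
import numpy as np

batch, trg_len, d_model, trg_vocab_size = 2, 4, 16, 100
dec_output = np.random.rand(batch, trg_len, d_model).astype("float32")
word_emb = np.random.rand(trg_vocab_size, d_model).astype("float32")  # shared embedding table

logits = np.matmul(dec_output, word_emb.T)   # [batch, trg_len, trg_vocab_size]
# Training keeps the raw logits (softmax_with_cross_entropy reshapes them to
# [-1, trg_vocab_size]); inference applies softmax over the vocabulary axis.
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = probs / probs.sum(axis=-1, keepdims=True)
```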
...@@ -624,12 +560,8 @@ def fast_decode( ...@@ -624,12 +560,8 @@ def fast_decode(
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head, enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, d_key, d_value, d_model, d_inner_hid,
dropout_rate, weight_sharing) dropout_rate, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \ start_tokens, init_scores, trg_src_attn_bias = \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \ make_all_inputs(fast_decoder_data_input_fields )
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
make_all_inputs(fast_decoder_data_input_fields +
fast_decoder_util_input_fields)
def beam_search(): def beam_search():
max_len = layers.fill_constant( max_len = layers.fill_constant(
...@@ -639,7 +571,8 @@ def fast_decode( ...@@ -639,7 +571,8 @@ def fast_decode(
cond = layers.less_than(x=step_idx, y=max_len) cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond) while_op = layers.While(cond)
# array states will be stored for each step. # array states will be stored for each step.
ids = layers.array_write(start_tokens, step_idx) ids = layers.array_write(
layers.reshape(start_tokens, (-1, 1)), step_idx)
scores = layers.array_write(init_scores, step_idx) scores = layers.array_write(init_scores, step_idx)
# cell states will be overwritten at each step. # cell states will be overwritten at each step.
# caches contains states of history steps to reduce redundant # caches contains states of history steps to reduce redundant
...@@ -658,6 +591,7 @@ def fast_decode( ...@@ -658,6 +591,7 @@ def fast_decode(
} for i in range(n_layer)] } for i in range(n_layer)]
with while_op.block(): with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx) pre_ids = layers.array_read(array=ids, i=step_idx)
pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
pre_scores = layers.array_read(array=scores, i=step_idx) pre_scores = layers.array_read(array=scores, i=step_idx)
# sequence_expand can gather sequences according to lod thus can be # sequence_expand can gather sequences according to lod thus can be
# used in beam search to sift states corresponding to selected ids. # used in beam search to sift states corresponding to selected ids.
...@@ -674,7 +608,7 @@ def fast_decode( ...@@ -674,7 +608,7 @@ def fast_decode(
x=layers.fill_constant_batch_size_like( x=layers.fill_constant_batch_size_like(
input=pre_enc_output, # can't use pre_ids here since it has lod input=pre_enc_output, # can't use pre_ids here since it has lod
value=1, value=1,
shape=[-1, 1], shape=[-1, 1, 1],
dtype=pre_ids.dtype), dtype=pre_ids.dtype),
y=layers.increment( y=layers.increment(
x=step_idx, value=1.0, in_place=False), x=step_idx, value=1.0, in_place=False),
...@@ -690,12 +624,11 @@ def fast_decode( ...@@ -690,12 +624,11 @@ def fast_decode(
d_inner_hid, d_inner_hid,
dropout_rate, dropout_rate,
weight_sharing, weight_sharing,
dec_inputs=( dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
enc_output=pre_enc_output, enc_output=pre_enc_output,
caches=pre_caches) caches=pre_caches)
logits = layers.reshape(logits, (-1, trg_vocab_size))
topk_scores, topk_indices = layers.topk( topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size) input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add( accu_scores = layers.elementwise_add(
...@@ -712,6 +645,7 @@ def fast_decode( ...@@ -712,6 +645,7 @@ def fast_decode(
scores=accu_scores, scores=accu_scores,
beam_size=beam_size, beam_size=beam_size,
end_id=eos_idx) end_id=eos_idx)
layers.increment(x=step_idx, value=1.0, in_place=True) layers.increment(x=step_idx, value=1.0, in_place=True)
# update states # update states
layers.array_write(selected_ids, i=step_idx, array=ids) layers.array_write(selected_ids, i=step_idx, array=ids)
...@@ -721,17 +655,6 @@ def fast_decode( ...@@ -721,17 +655,6 @@ def fast_decode(
for i in range(n_layer): for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"]) layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"]) layers.assign(pre_caches[i]["v"], caches[i]["v"])
layers.assign(
layers.elementwise_add(
x=slf_attn_pre_softmax_shape,
y=attn_pre_softmax_shape_delta),
slf_attn_pre_softmax_shape)
layers.assign(
layers.elementwise_add(
x=slf_attn_post_softmax_shape,
y=attn_post_softmax_shape_delta),
slf_attn_post_softmax_shape)
length_cond = layers.less_than(x=step_idx, y=max_len) length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids)) finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond) layers.logical_and(x=length_cond, y=finish_cond, out=cond)
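Inside the while loop, the ranking that feeds `beam_search` is unchanged: per-step log-probabilities of the top-k candidates are added to the accumulated scores carried over from the previous step. A numpy sketch of that score update (beam and vocabulary sizes are illustrative):

```python
import numpy as np

beam_size, vocab_size = 4, 12
logits = np.random.rand(beam_size, vocab_size).astype("float32")
pre_scores = np.zeros((beam_size, 1), dtype="float32")            # accumulated log-probs so far

probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = probs / probs.sum(axis=-1, keepdims=True)
topk_idx = np.argsort(-probs, axis=-1)[:, :beam_size]             # stand-in for layers.topk
topk_scores = probs[np.arange(beam_size)[:, None], topk_idx]
accu_scores = np.log(topk_scores) + pre_scores                    # what beam_search ranks on
```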
......
...@@ -96,7 +96,6 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count, ...@@ -96,7 +96,6 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count,
data_input_names = encoder_data_input_fields + decoder_data_input_fields[: data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields -1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
start_time = time.time() start_time = time.time()
exec_time = 0.0 exec_time = 0.0
...@@ -108,12 +107,12 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count, ...@@ -108,12 +107,12 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count,
for place_id, data_buffer in enumerate( for place_id, data_buffer in enumerate(
split_data( split_data(
data, num_part=dev_count)): data, num_part=dev_count)):
data_input_dict, util_input_dict, num_token = prepare_batch_input( data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names, data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_model) ModelHyperParams.d_model)
total_num_token += num_token total_num_token += num_token
feed_kv_pairs = data_input_dict.items() + util_input_dict.items() feed_kv_pairs = data_input_dict.items()
lr_rate = lr_scheduler.update_learning_rate() lr_rate = lr_scheduler.update_learning_rate()
feed_kv_pairs += {lr_scheduler.learning_rate.name: lr_rate}.items() feed_kv_pairs += {lr_scheduler.learning_rate.name: lr_rate}.items()
feed_list.append(dict(feed_kv_pairs)) feed_list.append(dict(feed_kv_pairs))
......
...@@ -180,34 +180,23 @@ def pad_batch_data(insts, ...@@ -180,34 +180,23 @@ def pad_batch_data(insts,
return return_list if len(return_list) > 1 else return_list[0] return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
trg_pad_idx, n_head, d_model): n_head, d_model):
""" """
Put all padded data needed by training into a dict. Put all padded data needed by training into a dict.
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32") [1, 1, trg_max_len, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
[-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
lbl_word, lbl_weight, num_token = pad_batch_data( lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts], [inst[2] for inst in insts],
trg_pad_idx, trg_pad_idx,
...@@ -223,15 +212,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -223,15 +212,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
])) ]))
util_input_dict = dict( return data_input_dict, np.asarray([num_token], dtype="float32")
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
]))
return data_input_dict, util_input_dict, np.asarray(
[num_token], dtype="float32")
def read_multiple(reader, count, clip_last=True): def read_multiple(reader, count, clip_last=True):
...@@ -317,12 +298,11 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, ...@@ -317,12 +298,11 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
for place_id, data_buffer in enumerate( for place_id, data_buffer in enumerate(
split_data( split_data(
data, num_part=dev_count)): data, num_part=dev_count)):
data_input_dict, util_input_dict, _ = prepare_batch_input( data_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names, data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model) ModelHyperParams.n_head, ModelHyperParams.d_model)
feed_list.append( feed_list.append(data_input_dict)
dict(data_input_dict.items() + util_input_dict.items()))
outs = exe.run(feed=feed_list, outs = exe.run(feed=feed_list,
fetch_list=[sum_cost.name, token_num.name]) fetch_list=[sum_cost.name, token_num.name])
...@@ -380,7 +360,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -380,7 +360,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
data_input_names = encoder_data_input_fields + decoder_data_input_fields[: data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields -1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
if args.val_file_pattern is not None: if args.val_file_pattern is not None:
test = test_context(train_progm, avg_cost, train_exe, dev_count, test = test_context(train_progm, avg_cost, train_exe, dev_count,
...@@ -404,13 +383,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -404,13 +383,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for place_id, data_buffer in enumerate( for place_id, data_buffer in enumerate(
split_data( split_data(
data, num_part=dev_count)): data, num_part=dev_count)):
data_input_dict, util_input_dict, num_token = prepare_batch_input( data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names, data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_model) ModelHyperParams.d_model)
total_num_token += num_token total_num_token += num_token
feed_kv_pairs = data_input_dict.items() + util_input_dict.items( feed_kv_pairs = data_input_dict.items()
)
if args.local: if args.local:
feed_kv_pairs += { feed_kv_pairs += {
lr_scheduler.learning_rate.name: lr_rate lr_scheduler.learning_rate.name: lr_rate
...@@ -460,7 +438,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -460,7 +438,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
fluid.io.save_inference_model( fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir, os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"), "pass_" + str(pass_id) + ".infer.model"),
data_input_names[:-2] + util_input_names, [predict], exe) data_input_names[:-2], [predict], exe)
if args.enable_ce: # For CE if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost)) print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost)) print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
......