Unverified · Commit 6966c992 · Authored by Yu Yang · Committed by GitHub

Merge pull request #1150 from reyoung/feature/remove_embedding_softmax_reshape

Remove reshape op for embedding and softmax
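This PR drops the shape inputs that fed reshape_op around the embedding output and the attention softmax, relying on softmax operating directly over the last dimension of the 4-D logits. A minimal NumPy sketch of the equivalence this rests on (toy shapes, assumed values, not part of the diff):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over one axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

batch, n_head, q_len, k_len = 2, 8, 5, 5
logits = np.random.rand(batch, n_head, q_len, k_len).astype("float32")

# Old pattern: flatten to 2-D, softmax over rows, reshape back.
old = softmax(logits.reshape(-1, k_len)).reshape(logits.shape)
# New pattern: softmax directly over the last axis of the 4-D tensor.
new = softmax(logits)

assert np.allclose(old, new)
```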
......@@ -116,29 +116,23 @@ seq_len = ModelHyperParams.max_length
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(batch_size * seq_len, 1L), "int64", 2],
"src_word": [(batch_size, seq_len, 1L), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(batch_size * seq_len, 1L), "int64"],
"src_pos": [(batch_size, seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(batch_size * seq_len, 1L), "int64",
"trg_word": [(batch_size, seq_len, 1L), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(batch_size * seq_len, 1L), "int64"],
"trg_pos": [(batch_size, seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
......@@ -151,18 +145,6 @@ input_descs = {
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
"trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
"trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
......@@ -193,22 +175,12 @@ encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
encoder_util_input_fields = (
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
decoder_util_input_fields = (
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
......@@ -218,6 +190,6 @@ fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"trg_src_attn_bias", )
fast_decoder_util_input_fields = decoder_util_input_fields + (
"trg_slf_attn_pre_softmax_shape_delta",
"trg_slf_attn_post_softmax_shape_delta", )
# fast_decoder_util_input_fields = (
# "trg_slf_attn_pre_softmax_shape_delta",
# "trg_slf_attn_post_softmax_shape_delta", )
......@@ -85,239 +85,6 @@ def parse_args():
return args
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally. This is deprecated since a faster beam
search decoder based solely on Fluid operators has been added.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
]
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
enc_in_data[2].shape, dtype="int32")
]
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_inst_map to map beam idx to the instance idx in batch, since the
# size of the fed batch changes.
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
# Use active_beams to record the beams that are still alive.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
seqs = []
for i in range(n_best):
k = i
seq = []
for j in range(len(prev_branchs) - 1, -1, -1):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
# Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
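`beam_backtrace` recovers token sequences by walking `prev_branchs`/`next_ids` backwards from the last step to the first. The same idea on toy data (assumed values, not from a real run):

```python
# Toy backtrace: 2 steps, beam width 2.
# next_ids[t][k]     -> token chosen by beam k at step t
# prev_branchs[t][k] -> which beam at step t-1 beam k extended
bos_idx = 0
next_ids = [[7, 3], [5, 9]]
prev_branchs = [[0, 0], [1, 0]]

seqs = []
for k in range(2):               # recover each of the 2 beams
    seq = []
    for t in range(len(prev_branchs) - 1, -1, -1):
        seq.append(next_ids[t][k])
        k = prev_branchs[t][k]   # hop to the parent beam
    seqs.append([bos_idx] + seq[::-1])

print(seqs)  # [[0, 3, 5], [0, 7, 9]]
```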
def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
"""
Initialize the input data for decoder.
"""
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
-1], enc_in_data[2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_max_len, trg_max_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
[1, beam_size, 1, trg_max_len, 1]).reshape([
-1, src_slf_attn_bias.shape[1], trg_max_len,
src_slf_attn_bias.shape[-1]
])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
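The self-attention bias built above uses `np.triu` so that a position cannot attend to subsequent words; a tiny standalone example of the mask:

```python
import numpy as np

trg_len, n_head = 3, 2
bias = np.ones((1, trg_len, trg_len))
bias = np.triu(bias, 1).reshape([-1, 1, trg_len, trg_len])
bias = (np.tile(bias, [1, n_head, 1, 1]) * -1e9).astype("float32")

# bias[0, 0] is 0 on and below the diagonal and -1e9 above it, so
# adding it to the logits zeroes out future positions after softmax.
print(bias[0, 0])
```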
def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
"""
Update the decoder input data, mainly by slicing from the previous
input data and dropping beams of finished instances.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
trg_cur_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_cur_len, trg_cur_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for beam_idx in range(batch_size):
if not beam_inst_map.has_key(beam_idx):
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
-1]]
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
if len(active_beams) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist()
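Inside the search loop, the top-k runs over the flattened `beam_size * vocab_size` scores, and the parent beam and token id are recovered by division and modulo (the code above uses Python 2 integer `/`; `//` in Python 3):

```python
import numpy as np

beam_size, vocab_size = 4, 10
predict = np.random.rand(beam_size * vocab_size)  # accumulated scores

top_k = np.argpartition(predict, -beam_size)[-beam_size:]
top_k = top_k[np.argsort(predict[top_k])[::-1]]   # best first

prev_branch = top_k // vocab_size  # which beam each candidate extends
next_id = top_k % vocab_size       # which token it appends
```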
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
......@@ -339,93 +106,8 @@ def post_process_seq(seq,
seq)
def py_infer(test_data, trg_idx2word, use_wordpiece):
"""
Inference by beam search implemented in Python, while the calculations
from symbols to probabilities are executed by Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
encoder_var_names = []
for op in encoder_program.block(0).ops:
encoder_var_names += op.input_arg_names
encoder_param_names = filter(
lambda var_name: isinstance(encoder_program.block(0).var(var_name),
fluid.framework.Parameter),
encoder_var_names)
encoder_params = map(encoder_program.block(0).var, encoder_param_names)
decoder_var_names = []
for op in decoder_program.block(0).ops:
decoder_var_names += op.input_arg_names
decoder_param_names = filter(
lambda var_name: isinstance(decoder_program.block(0).var(var_name),
fluid.framework.Parameter),
decoder_var_names)
decoder_params = map(decoder_program.block(0).var, decoder_param_names)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
# This is used here to set dropout to the test mode.
encoder_program = encoder_program.inference_optimize()
decoder_program = decoder_program.inference_optimize()
for batch_id, data in enumerate(test_data.batch_generator()):
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
encoder_program,
encoder_data_input_fields + encoder_util_input_fields,
[enc_output.name],
decoder_program,
decoder_data_input_fields[:-1] + decoder_util_input_fields +
(decoder_data_input_fields[-1], ),
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_out_len,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
if use_wordpiece:
print(util.subword_ids_to_str(seq, trg_idx2word))
else:
print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
bos_idx, n_head, d_model, place):
def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
d_model, place):
"""
Put all padded data needed by beam search decoder into a dict.
"""
......@@ -435,25 +117,9 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([-1, 1, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, 1], dtype="int32") # only the first time step
trg_slf_attn_post_softmax_shape = np.array(
[-1, n_head, 1, 1], dtype="int32") # only the first time step
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
# These inputs are used to change the shapes in the loop of while op.
attn_pre_softmax_shape_delta = np.array([0, 1], dtype="int32")
attn_post_softmax_shape_delta = np.array([0, 0, 0, 1], dtype="int32")
trg_word = trg_word.reshape(-1, 1, 1)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor()
......@@ -465,7 +131,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
# beamsearch_op must use tensors with lod
init_score = to_lodtensor(
np.zeros_like(
trg_word, dtype="float32"),
trg_word, dtype="float32").reshape(-1, 1),
place, [range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
......@@ -474,16 +140,8 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
trg_src_attn_bias
]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta
]))
input_dict = dict(data_input_dict.items() + util_input_dict.items())
input_dict = dict(data_input_dict.items())
return input_dict
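`beam_search_op` wants `trg_word`/`init_score` as LoDTensors whose two-level lod initially marks every instance as a single one-token sequence. The body of `to_lodtensor` is truncated in this diff, so the following is an assumed reconstruction:

```python
import numpy as np
import paddle.fluid as fluid

def to_lodtensor(data, place, lod=None):
    # Wrap a numpy array as a LoDTensor, optionally attaching lod.
    tensor = fluid.LoDTensor()
    tensor.set(data, place)
    if lod is not None:
        tensor.set_lod(lod)
    return tensor

place = fluid.CPUPlace()
batch = 3
trg_word = np.zeros((batch, 1, 1), dtype="int64")
lod = [list(range(batch + 1))] * 2  # [[0,1,2,3], [0,1,2,3]]
init_score = to_lodtensor(
    np.zeros((batch, 1), dtype="float32"), place, lod)
```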
......@@ -515,7 +173,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
for batch_id, data in enumerate(test_data.batch_generator()):
data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program,
......
......@@ -29,8 +29,6 @@ def multi_head_attention(queries,
d_model,
n_head=1,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None,
cache=None):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
......@@ -101,14 +99,9 @@ def multi_head_attention(queries,
"""
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = layers.reshape(
x=layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product,
shape=[-1, product.shape[-1]],
actual_shape=pre_softmax_shape,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
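With the reshapes gone, the per-head computation reduces to plain scaled dot-product attention; in NumPy form (the `d_model ** -0.5` scale mirrors the code above):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, bias=None, d_model=64):
    # q, k, v: [batch, n_head, seq_len, d_per_head]
    product = np.matmul(q * d_model ** -0.5,
                        k.transpose(0, 1, 3, 2))
    if bias is not None:
        product = product + bias  # e.g. -1e9 on masked positions
    product = product - product.max(axis=-1, keepdims=True)
    weights = np.exp(product)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)
```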
......@@ -191,7 +184,6 @@ def prepare_encoder(src_word,
src_emb_dim,
src_max_len,
dropout_rate=0.,
src_data_shape=None,
word_emb_param_name=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
......@@ -205,6 +197,7 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding(
src_pos,
......@@ -212,10 +205,6 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
enc_input = layers.reshape(
x=enc_input,
shape=[batch_size, seq_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input,
dropout_prob=dropout_rate,
......@@ -236,18 +225,16 @@ def encoder_layer(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
dropout_rate=0.):
"""The encoder layers that can be stacked to form a deep encoder.
This module consists of a multi-head (self) attention sublayer followed by
a position-wise feed-forward network, with both components wrapped by
post_process_layer to add residual connection, layer normalization and
dropout.
"""
attn_output = multi_head_attention(
enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
attn_output = multi_head_attention(enc_input, enc_input, enc_input,
attn_bias, d_key, d_value, d_model,
n_head, dropout_rate)
attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
......@@ -262,25 +249,14 @@ def encoder(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
dropout_rate=0.):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
pre_softmax_shape,
post_softmax_shape, )
enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
d_model, d_inner_hid, dropout_rate)
enc_input = enc_output
return enc_output
......@@ -295,10 +271,6 @@ def decoder_layer(dec_input,
d_model,
d_inner_hid,
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None,
cache=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
......@@ -314,8 +286,6 @@ def decoder_layer(dec_input,
d_model,
n_head,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
cache, )
slf_attn_output = post_process_layer(
dec_input,
......@@ -331,9 +301,7 @@ def decoder_layer(dec_input,
d_value,
d_model,
n_head,
dropout_rate,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
dropout_rate, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
......@@ -362,15 +330,15 @@ def decoder(dec_input,
d_model,
d_inner_hid,
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None,
caches=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
cache = None
if caches is not None:
cache = caches[i]
dec_output = decoder_layer(
dec_input,
enc_output,
......@@ -382,11 +350,7 @@ def decoder(dec_input,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
None if caches is None else caches[i], )
cache=cache)
dec_input = dec_output
return dec_output
......@@ -425,8 +389,7 @@ def transformer(
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be the same for weight sharing."
)
enc_inputs = make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
enc_inputs = make_all_inputs(encoder_data_input_fields)
enc_output = wrap_encoder(
src_vocab_size,
......@@ -441,8 +404,7 @@ def transformer(
weight_sharing,
enc_inputs, )
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
decoder_util_input_fields)
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])
predict = wrap_decoder(
trg_vocab_size,
......@@ -466,8 +428,10 @@ def transformer(
label=layers.one_hot(
input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
logits=layers.reshape(
predict, shape=[-1, trg_vocab_size]),
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
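`label_smooth` softens the one-hot target before the soft-label cross entropy; the arithmetic, assuming uniform smoothing over the vocabulary:

```python
import numpy as np

trg_vocab_size, eps = 5, 0.1
one_hot = np.eye(trg_vocab_size)[[2]]           # true token id 2

# Uniform label smoothing: (1 - eps) * one_hot + eps / V.
smoothed = (1 - eps) * one_hot + eps / trg_vocab_size

logits = np.random.rand(1, trg_vocab_size)
log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
cost = -(smoothed * log_probs).sum(-1)          # soft-label cross entropy
```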
......@@ -494,13 +458,10 @@ def wrap_encoder(src_vocab_size,
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
src_word, src_pos, src_slf_attn_bias = \
make_all_inputs(encoder_data_input_fields)
else:
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
src_word, src_pos, src_slf_attn_bias = \
enc_inputs
enc_input = prepare_encoder(
src_word,
......@@ -509,20 +470,9 @@ def wrap_encoder(src_vocab_size,
d_model,
max_length,
dropout_rate,
src_data_shape,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, dropout_rate)
return enc_output
......@@ -545,15 +495,10 @@ def wrap_decoder(trg_vocab_size,
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
enc_output, trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = make_all_inputs(
enc_output = make_all_inputs(
decoder_data_input_fields)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = dec_inputs
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
......@@ -562,7 +507,6 @@ def wrap_decoder(trg_vocab_size,
d_model,
max_length,
dropout_rate,
trg_data_shape,
word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1])
dec_output = decoder(
......@@ -577,28 +521,20 @@ def wrap_decoder(trg_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
caches, )
caches=caches)
# Return logits for training and probs for inference.
if weight_sharing:
predict = layers.reshape(
x=layers.matmul(
x=dec_output,
y=fluid.get_var(word_emb_param_names[0]),
transpose_y=True),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
predict = layers.matmul(
x=dec_output,
y=fluid.get_var(word_emb_param_names[0]),
transpose_y=True)
else:
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
predict = layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2)
if dec_inputs is None:
predict = layers.softmax(predict)
return predict
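When `weight_sharing` is on, the output projection reuses the transposed word-embedding matrix instead of a separate FC layer (the `matmul` with `transpose_y=True` above); conceptually:

```python
import numpy as np

d_model, vocab = 4, 10
word_emb = np.random.rand(vocab, d_model).astype("float32")   # [V, d]
dec_output = np.random.rand(2, 3, d_model).astype("float32")  # [b, t, d]

# Tied projection: logits share parameters with the embedding lookup.
logits = np.matmul(dec_output, word_emb.T)                    # [b, t, V]
```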
......@@ -624,12 +560,8 @@ def fast_decode(
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
dropout_rate, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
make_all_inputs(fast_decoder_data_input_fields +
fast_decoder_util_input_fields)
start_tokens, init_scores, trg_src_attn_bias = \
make_all_inputs(fast_decoder_data_input_fields)
def beam_search():
max_len = layers.fill_constant(
......@@ -639,7 +571,8 @@ def fast_decode(
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
# array states will be stored for each step.
ids = layers.array_write(start_tokens, step_idx)
ids = layers.array_write(
layers.reshape(start_tokens, (-1, 1)), step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states will be overwritten at each step.
# caches contain states of history steps to reduce redundant
......@@ -658,6 +591,7 @@ def fast_decode(
} for i in range(n_layer)]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
pre_scores = layers.array_read(array=scores, i=step_idx)
# sequence_expand can gather sequences according to lod, and thus can be
# used in beam search to sift out states corresponding to the selected ids.
......@@ -674,7 +608,7 @@ def fast_decode(
x=layers.fill_constant_batch_size_like(
input=pre_enc_output, # can't use pre_ids here since it has lod
value=1,
shape=[-1, 1],
shape=[-1, 1, 1],
dtype=pre_ids.dtype),
y=layers.increment(
x=step_idx, value=1.0, in_place=False),
......@@ -690,12 +624,11 @@ def fast_decode(
d_inner_hid,
dropout_rate,
weight_sharing,
dec_inputs=(
pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=pre_enc_output,
caches=pre_caches)
logits = layers.reshape(logits, (-1, trg_vocab_size))
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add(
......@@ -712,6 +645,7 @@ def fast_decode(
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx)
layers.increment(x=step_idx, value=1.0, in_place=True)
# update states
layers.array_write(selected_ids, i=step_idx, array=ids)
......@@ -721,17 +655,6 @@ def fast_decode(
for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"])
layers.assign(
layers.elementwise_add(
x=slf_attn_pre_softmax_shape,
y=attn_pre_softmax_shape_delta),
slf_attn_pre_softmax_shape)
layers.assign(
layers.elementwise_add(
x=slf_attn_post_softmax_shape,
y=attn_post_softmax_shape_delta),
slf_attn_post_softmax_shape)
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
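Stripped of Fluid ops, the while-loop bookkeeping above reads the previous step's ids/scores, scores one more token, keeps the best `beam_size` continuations, and writes them back at the incremented step index. A plain-Python paraphrase with a stand-in scorer (assumed values):

```python
import numpy as np

max_len, beam_size, vocab = 4, 2, 10
rng = np.random.RandomState(0)

ids = [np.zeros(beam_size, dtype="int64")]      # step 0: <bos> tokens
scores = [np.zeros(beam_size, dtype="float32")]

for step in range(1, max_len + 1):
    pre_scores = scores[step - 1]
    probs = rng.rand(beam_size, vocab)          # stand-in for the decoder
    probs = probs / probs.sum(-1, keepdims=True)
    accu = np.log(probs) + pre_scores[:, None]  # accumulated scores
    flat = accu.ravel()
    top = np.argsort(flat)[::-1][:beam_size]    # beam_search selection
    ids.append((top % vocab).astype("int64"))   # array_write at step_idx
    scores.append(flat[top].astype("float32"))
```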
......
......@@ -96,7 +96,6 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count,
data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
start_time = time.time()
exec_time = 0.0
......@@ -108,12 +107,12 @@ def train_loop(exe, train_progm, init, num_iters, train_data, dev_count,
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
total_num_token += num_token
feed_kv_pairs = data_input_dict.items() + util_input_dict.items()
feed_kv_pairs = data_input_dict.items()
lr_rate = lr_scheduler.update_learning_rate()
feed_kv_pairs += {lr_scheduler.learning_rate.name: lr_rate}.items()
feed_list.append(dict(feed_kv_pairs))
......
......@@ -180,34 +180,23 @@ def pad_batch_data(insts,
return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
trg_pad_idx, n_head, d_model):
def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
[-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
......@@ -223,15 +212,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
]))
return data_input_dict, util_input_dict, np.asarray(
[num_token], dtype="float32")
return data_input_dict, np.asarray([num_token], dtype="float32")
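`pad_batch_data` (defined above, partially elided in this diff) pads every instance to the batch max length and builds position indices; a condensed sketch of just the padding, with padded slots assumed to get `pad_idx` and position 0:

```python
import numpy as np

def pad_words_and_pos(insts, pad_idx):
    # Pad token ids to the batch max length and build 1-based positions;
    # padded slots get pad_idx / position 0 (assumed convention).
    max_len = max(len(inst) for inst in insts)
    word = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64").reshape(-1, max_len, 1)
    pos = np.array(
        [list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
         for inst in insts],
        dtype="int64").reshape(-1, max_len, 1)
    return word, pos, max_len
```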
def read_multiple(reader, count, clip_last=True):
......@@ -317,12 +298,11 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, _ = prepare_batch_input(
data_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
feed_list.append(
dict(data_input_dict.items() + util_input_dict.items()))
feed_list.append(data_input_dict)
outs = exe.run(feed=feed_list,
fetch_list=[sum_cost.name, token_num.name])
......@@ -380,7 +360,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
if args.val_file_pattern is not None:
test = test_context(train_progm, avg_cost, train_exe, dev_count,
......@@ -404,13 +383,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
total_num_token += num_token
feed_kv_pairs = data_input_dict.items() + util_input_dict.items(
)
feed_kv_pairs = data_input_dict.items()
if args.local:
feed_kv_pairs += {
lr_scheduler.learning_rate.name: lr_rate
......@@ -460,7 +438,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
data_input_names[:-2] + util_input_names, [predict], exe)
data_input_names[:-2], [predict], exe)
if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
......