Commit be5c8e56 authored by Yu Yang

Remove reshape op for embedding and softmax

Parent c34bb5f1
@@ -116,29 +116,23 @@ seq_len = ModelHyperParams.max_length
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(batch_size * seq_len, 1L), "int64", 2],
"src_word": [(batch_size, seq_len, 1L), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(batch_size * seq_len, 1L), "int64"],
"src_pos": [(batch_size, seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(batch_size * seq_len, 1L), "int64",
"trg_word": [(batch_size, seq_len, 1L), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(batch_size * seq_len, 1L), "int64"],
"trg_pos": [(batch_size, seq_len, 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
@@ -151,18 +145,6 @@ input_descs = {
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
"trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
"trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
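For reference, the `*_slf_attn_bias` inputs encode which key positions each query may attend to: real tokens get a bias of 0 and padding positions a large negative value, so their attention weights vanish after the softmax. A minimal NumPy sketch of the idea (the helper name and the -1e9 constant are illustrative, not part of this diff):

```python
import numpy as np

def build_padding_attn_bias(word_ids, pad_idx, n_head):
    # word_ids: [batch_size, max_len] integer ids, padded with pad_idx.
    batch_size, max_len = word_ids.shape
    # A large negative bias drives the softmax weight of pads towards 0.
    bias = np.where(word_ids == pad_idx, -1e9, 0.0).astype("float32")
    # Broadcast to [batch_size, n_head, max_len (query), max_len (key)].
    return np.tile(bias[:, np.newaxis, np.newaxis, :],
                   [1, n_head, max_len, 1])

ids = np.array([[5, 7, 0, 0]])                   # pad_idx == 0
print(build_padding_attn_bias(ids, 0, 2).shape)  # (1, 2, 4, 4)
```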
@@ -193,22 +175,12 @@ encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
encoder_util_input_fields = (
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
decoder_util_input_fields = (
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
@@ -218,6 +190,6 @@ fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"trg_src_attn_bias", )
fast_decoder_util_input_fields = decoder_util_input_fields + (
fast_decoder_util_input_fields = (
"trg_slf_attn_pre_softmax_shape_delta",
"trg_slf_attn_post_softmax_shape_delta", )
@@ -29,8 +29,6 @@ def multi_head_attention(queries,
d_model,
n_head=1,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None,
cache=None):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
@@ -101,14 +99,9 @@ def multi_head_attention(queries,
"""
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = layers.reshape(
x=layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product,
shape=[-1, product.shape[-1]],
actual_shape=pre_softmax_shape,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
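The rewrite above relies on softmax normalizing along the last (key) dimension, which makes the flatten-to-2-D reshape before the softmax and the restore-shape reshape after it mathematically redundant. A NumPy sketch of the equivalence (assuming, as the new code does, that layers.softmax acts on the last axis):

```python
import numpy as np

def softmax_last_axis(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

product = np.random.rand(2, 8, 5, 5)  # [batch, n_head, len_q, len_k]

# Old path: flatten to 2-D, softmax over rows, restore the 4-D shape.
flat = product.reshape(-1, product.shape[-1])
old = softmax_last_axis(flat).reshape(product.shape)

# New path: softmax directly on the last axis of the 4-D tensor.
new = softmax_last_axis(product)

assert np.allclose(old, new)
```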
@@ -191,7 +184,6 @@ def prepare_encoder(src_word,
src_emb_dim,
src_max_len,
dropout_rate=0.,
src_data_shape=None,
word_emb_param_name=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
@@ -212,10 +204,6 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
enc_input = layers.reshape(
x=enc_input,
shape=[batch_size, seq_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input,
dropout_prob=dropout_rate,
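With src_word fed as [batch_size, seq_len, 1] instead of the flattened [batch_size * seq_len, 1], the embedding lookup already yields a [batch_size, seq_len, emb_dim] tensor, so the reshape driven by the runtime src_data_shape tensor is no longer needed. A NumPy stand-in for the shape reasoning (illustrative only):

```python
import numpy as np

vocab, emb_dim = 100, 16
emb_table = np.random.rand(vocab, emb_dim)

ids = np.random.randint(vocab, size=(4, 10, 1))  # [batch, seq_len, 1]
emb = emb_table[ids.squeeze(-1)]                 # lookup drops the size-1 dim
print(emb.shape)                                 # (4, 10, 16): already 3-D
```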
@@ -236,18 +224,16 @@ def encoder_layer(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
dropout_rate=0.):
"""The encoder layers that can be stacked to form a deep encoder.
This module consists of multi-head (self) attention followed by
position-wise feed-forward networks, with both components wrapped by
post_process_layer to add residual connection, layer normalization
and dropout.
"""
attn_output = multi_head_attention(
enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
attn_output = multi_head_attention(enc_input, enc_input, enc_input,
attn_bias, d_key, d_value, d_model,
n_head, dropout_rate)
attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
@@ -262,25 +248,14 @@ def encoder(enc_input,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
dropout_rate=0.):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
pre_softmax_shape,
post_softmax_shape, )
enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
d_model, d_inner_hid, dropout_rate)
enc_input = enc_output
return enc_output
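The stacking here is plain function iteration: each layer consumes the previous layer's output. A toy sketch (illustrative layer_fn, not the real encoder_layer):

```python
def stack(layer_fn, x, n_layer):
    for _ in range(n_layer):
        x = layer_fn(x)  # output of layer i is the input of layer i + 1
    return x

print(stack(lambda v: v + 1, 0, 6))  # 6
```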
@@ -295,10 +270,6 @@ def decoder_layer(dec_input,
d_model,
d_inner_hid,
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None,
cache=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
@@ -314,8 +285,6 @@ def decoder_layer(dec_input,
d_model,
n_head,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
cache, )
slf_attn_output = post_process_layer(
dec_input,
@@ -331,9 +300,7 @@ def decoder_layer(dec_input,
d_value,
d_model,
n_head,
dropout_rate,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
dropout_rate, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
@@ -362,10 +329,6 @@ def decoder(dec_input,
d_model,
d_inner_hid,
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None,
caches=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
@@ -381,12 +344,7 @@ def decoder(dec_input,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
None if caches is None else caches[i], )
dropout_rate, )
dec_input = dec_output
return dec_output
@@ -425,8 +383,7 @@ def transformer(
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
enc_inputs = make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
enc_inputs = make_all_inputs(encoder_data_input_fields)
enc_output = wrap_encoder(
src_vocab_size,
@@ -441,8 +398,7 @@ def transformer(
weight_sharing,
enc_inputs, )
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
decoder_util_input_fields)
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])
predict = wrap_decoder(
trg_vocab_size,
@@ -466,8 +422,10 @@ def transformer(
label=layers.one_hot(
input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
logits=layers.reshape(
predict, shape=[-1, trg_vocab_size]),
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
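For context, layers.label_smooth above mixes the one-hot target with uniform probability mass over the vocabulary, which is why soft_label is enabled whenever label_smooth_eps is set. A NumPy sketch of that target (assuming uniform smoothing, the default when no prior distribution is supplied):

```python
import numpy as np

def smooth_one_hot(label_ids, vocab_size, eps):
    one_hot = np.eye(vocab_size)[label_ids]
    return one_hot * (1 - eps) + eps / vocab_size

soft = smooth_one_hot(np.array([2, 0]), vocab_size=5, eps=0.1)
print(soft.sum(axis=-1))  # rows still sum to 1.0
```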
@@ -494,13 +452,11 @@ def wrap_encoder(src_vocab_size,
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
src_word, src_pos, src_slf_attn_bias = \
make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
encoder_util_input_fields)
else:
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
src_word, src_pos, src_slf_attn_bias = \
enc_inputs
enc_input = prepare_encoder(
src_word,
@@ -509,20 +465,9 @@ def wrap_encoder(src_vocab_size,
d_model,
max_length,
dropout_rate,
src_data_shape,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, dropout_rate)
return enc_output
@@ -545,15 +490,10 @@ def wrap_decoder(trg_vocab_size,
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
enc_output, trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = make_all_inputs(
enc_output = make_all_inputs(
decoder_data_input_fields + decoder_util_input_fields)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = dec_inputs
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
@@ -562,7 +502,6 @@ def wrap_decoder(trg_vocab_size,
d_model,
max_length,
dropout_rate,
trg_data_shape,
word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1])
dec_output = decoder(
@@ -576,29 +515,20 @@ def wrap_decoder(trg_vocab_size,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape,
caches, )
dropout_rate, )
# Return logits for training and probs for inference.
if weight_sharing:
predict = layers.reshape(
x=layers.matmul(
x=dec_output,
y=fluid.get_var(word_emb_param_names[0]),
transpose_y=True),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
predict = layers.matmul(
x=dec_output,
y=fluid.get_var(word_emb_param_names[0]),
transpose_y=True)
predict = layers.softmax(predict)
else:
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
predict = layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2,
act='softmax')
return predict
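Both branches now keep the decoder output 3-D: with num_flatten_dims=2 the fc layer projects only the last dimension, which is equivalent to a matmul with a [d_model, trg_vocab_size] weight (in the weight-sharing branch that weight is the transposed embedding table). A NumPy stand-in with illustrative names:

```python
import numpy as np

batch, seq, d_model, vocab = 2, 7, 16, 100
dec_output = np.random.rand(batch, seq, d_model)
w_out = np.random.rand(d_model, vocab)  # or emb_table.T when sharing weights

logits = dec_output @ w_out  # [batch, seq, vocab]; no flattening reshape
print(logits.shape)          # (2, 7, 100)
```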
@@ -625,11 +555,11 @@ def fast_decode(
d_key, d_value, d_model, d_inner_hid,
dropout_rate, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
attn_pre_softmax_shape_delta, attn_post_softmax_shape_delta = \
make_all_inputs(fast_decoder_data_input_fields +
fast_decoder_util_input_fields)
fast_decoder_util_input_fields)
def beam_search():
max_len = layers.fill_constant(
@@ -180,34 +180,23 @@ def pad_batch_data(insts,
return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
trg_pad_idx, n_head, d_model):
def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
[-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
[-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
[-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
@@ -223,15 +212,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]))
util_input_dict = dict(
zip(util_input_names, [
src_data_shape, src_slf_attn_pre_softmax_shape,
src_slf_attn_post_softmax_shape, trg_data_shape,
trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
]))
return data_input_dict, util_input_dict, np.asarray(
[num_token], dtype="float32")
return data_input_dict, np.asarray([num_token], dtype="float32")
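The np.tile line above derives the encoder-decoder attention bias from the source self-attention bias: the stride ::src_max_len keeps a single query row (the padding mask is identical for every query position), which is then repeated for each target position. A NumPy sketch with illustrative sizes:

```python
import numpy as np

batch, n_head, src_len, trg_len = 2, 8, 5, 7
src_slf_attn_bias = np.zeros((batch, n_head, src_len, src_len), "float32")
src_slf_attn_bias[0, :, :, 3:] = -1e9  # e.g. last two source tokens are pads

# [:, :, ::src_len, :] keeps only query row 0 -> [batch, n_head, 1, src_len];
# tiling repeats it for every target query position.
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_len, :],
                            [1, 1, trg_len, 1]).astype("float32")
print(trg_src_attn_bias.shape)  # (2, 8, 7, 5)
```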
def read_multiple(reader, count, clip_last=True):
@@ -317,12 +298,11 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, _ = prepare_batch_input(
data_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
feed_list.append(
dict(data_input_dict.items() + util_input_dict.items()))
feed_list.append(data_input_dict)
outs = exe.run(feed=feed_list,
fetch_list=[sum_cost.name, token_num.name])
@@ -380,7 +360,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
if args.val_file_pattern is not None:
test = test_context(train_progm, avg_cost, train_exe, dev_count,
@@ -404,13 +383,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
total_num_token += num_token
feed_kv_pairs = data_input_dict.items() + util_input_dict.items(
)
feed_kv_pairs = data_input_dict.items()
if args.local:
feed_kv_pairs += {
lr_scheduler.learning_rate.name: lr_rate
@@ -460,7 +438,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
data_input_names[:-2] + util_input_names, [predict], exe)
data_input_names[:-2], [predict], exe)
if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))