Commit be5c8e56 authored by Yu Yang

Remove reshape op for embedding and softmax

Parent c34bb5f1
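At its core, this commit replaces the reshape-softmax-reshape chain in `multi_head_attention`, which was driven by runtime shape tensors fed from Python, with a single softmax over the last axis of the 4-D attention logits, and then drops the now-unneeded `*_shape` utility inputs everywhere else. A minimal sketch of the two variants, assembled from the hunks below (only `paddle.fluid` calls that appear in this diff; not code from the commit itself):

```python
import paddle.fluid.layers as layers

def attn_weights_old(product, attn_bias, pre_softmax_shape, post_softmax_shape):
    # Flatten the 4-D logits to 2-D so softmax can be applied via reshape's
    # `act`, with the true runtime shape fed in through extra int32 tensors.
    weights = layers.reshape(
        x=layers.elementwise_add(
            x=product, y=attn_bias) if attn_bias else product,
        shape=[-1, product.shape[-1]],
        actual_shape=pre_softmax_shape,
        act="softmax")
    return layers.reshape(
        x=weights, shape=product.shape, actual_shape=post_softmax_shape)

def attn_weights_new(product, attn_bias):
    # softmax now normalizes the last dimension of the N-D logits directly,
    # so the shape feed tensors disappear.
    if attn_bias:
        product += attn_bias
    return layers.softmax(product)
```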
@@ -116,29 +116,23 @@ seq_len = ModelHyperParams.max_length
 input_descs = {
     # The actual data shape of src_word is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_word": [(batch_size * seq_len, 1L), "int64", 2],
+    "src_word": [(batch_size, seq_len, 1L), "int64", 2],
     # The actual data shape of src_pos is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_pos": [(batch_size * seq_len, 1L), "int64"],
+    "src_pos": [(batch_size, seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
     # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
     "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                            seq_len), "float32"],
-    # This shape input is used to reshape the output of embedding layer.
-    "src_data_shape": [(3L, ), "int32"],
-    # This shape input is used to reshape before softmax in self attention.
-    "src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
-    # This shape input is used to reshape after softmax in self attention.
-    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
     # The actual data shape of trg_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_word": [(batch_size * seq_len, 1L), "int64",
+    "trg_word": [(batch_size, seq_len, 1L), "int64",
                  2],  # lod_level is only used in fast decoder.
     # The actual data shape of trg_pos is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_pos": [(batch_size * seq_len, 1L), "int64"],
+    "trg_pos": [(batch_size, seq_len, 1L), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
@@ -151,18 +145,6 @@ input_descs = {
     # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
     "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                            seq_len), "float32"],
-    # This shape input is used to reshape the output of embedding layer.
-    "trg_data_shape": [(3L, ), "int32"],
-    # This shape input is used to reshape before softmax in self attention.
-    "trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
-    # This shape input is used to reshape after softmax in self attention.
-    "trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
-    # This shape input is used to reshape before softmax in encoder-decoder
-    # attention.
-    "trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
-    # This shape input is used to reshape after softmax in encoder-decoder
-    # attention.
-    "trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
     # This input is used in independent decoder program for inference.
     # The actual data shape of enc_output is:
     # [batch_size, max_src_len_in_batch, d_model]
@@ -193,22 +175,12 @@ encoder_data_input_fields = (
     "src_word",
     "src_pos",
     "src_slf_attn_bias", )
-encoder_util_input_fields = (
-    "src_data_shape",
-    "src_slf_attn_pre_softmax_shape",
-    "src_slf_attn_post_softmax_shape", )
 decoder_data_input_fields = (
     "trg_word",
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
     "enc_output", )
-decoder_util_input_fields = (
-    "trg_data_shape",
-    "trg_slf_attn_pre_softmax_shape",
-    "trg_slf_attn_post_softmax_shape",
-    "trg_src_attn_pre_softmax_shape",
-    "trg_src_attn_post_softmax_shape", )
 label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )
@@ -218,6 +190,6 @@ fast_decoder_data_input_fields = (
     "trg_word",
     "init_score",
     "trg_src_attn_bias", )
-fast_decoder_util_input_fields = decoder_util_input_fields + (
+fast_decoder_util_input_fields = (
     "trg_slf_attn_pre_softmax_shape_delta",
     "trg_slf_attn_post_softmax_shape_delta", )
@@ -29,8 +29,6 @@ def multi_head_attention(queries,
                          d_model,
                          n_head=1,
                          dropout_rate=0.,
-                         pre_softmax_shape=None,
-                         post_softmax_shape=None,
                          cache=None):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
@@ -101,14 +99,9 @@ def multi_head_attention(queries,
         """
         scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
-        weights = layers.reshape(
-            x=layers.elementwise_add(
-                x=product, y=attn_bias) if attn_bias else product,
-            shape=[-1, product.shape[-1]],
-            actual_shape=pre_softmax_shape,
-            act="softmax")
-        weights = layers.reshape(
-            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
+        if attn_bias:
+            product += attn_bias
+        weights = layers.softmax(product)
         if dropout_rate:
             weights = layers.dropout(
                 weights,
@@ -191,7 +184,6 @@ def prepare_encoder(src_word,
                     src_emb_dim,
                     src_max_len,
                     dropout_rate=0.,
-                    src_data_shape=None,
                     word_emb_param_name=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.
@@ -212,10 +204,6 @@
         param_attr=fluid.ParamAttr(
             name=pos_enc_param_name, trainable=False))
     enc_input = src_word_emb + src_pos_enc
-    enc_input = layers.reshape(
-        x=enc_input,
-        shape=[batch_size, seq_len, src_emb_dim],
-        actual_shape=src_data_shape)
     return layers.dropout(
         enc_input,
         dropout_prob=dropout_rate,
@@ -236,18 +224,16 @@ def encoder_layer(enc_input,
                   d_value,
                   d_model,
                   d_inner_hid,
-                  dropout_rate=0.,
-                  pre_softmax_shape=None,
-                  post_softmax_shape=None):
+                  dropout_rate=0.):
     """The encoder layers that can be stacked to form a deep encoder.
     This module consits of a multi-head (self) attention followed by
     position-wise feed-forward networks and both the two components companied
     with the post_process_layer to add residual connection, layer normalization
     and droput.
     """
-    attn_output = multi_head_attention(
-        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
-        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
+    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
+                                       attn_bias, d_key, d_value, d_model,
+                                       n_head, dropout_rate)
     attn_output = post_process_layer(enc_input, attn_output, "dan",
                                      dropout_rate)
     ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
@@ -262,25 +248,14 @@ def encoder(enc_input,
             d_value,
             d_model,
             d_inner_hid,
-            dropout_rate=0.,
-            pre_softmax_shape=None,
-            post_softmax_shape=None):
+            dropout_rate=0.):
     """
     The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
     """
     for i in range(n_layer):
-        enc_output = encoder_layer(
-            enc_input,
-            attn_bias,
-            n_head,
-            d_key,
-            d_value,
-            d_model,
-            d_inner_hid,
-            dropout_rate,
-            pre_softmax_shape,
-            post_softmax_shape, )
+        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
+                                   d_model, d_inner_hid, dropout_rate)
         enc_input = enc_output
     return enc_output
@@ -295,10 +270,6 @@ def decoder_layer(dec_input,
                   d_model,
                   d_inner_hid,
                   dropout_rate=0.,
-                  slf_attn_pre_softmax_shape=None,
-                  slf_attn_post_softmax_shape=None,
-                  src_attn_pre_softmax_shape=None,
-                  src_attn_post_softmax_shape=None,
                   cache=None):
     """ The layer to be stacked in decoder part.
     The structure of this module is similar to that in the encoder part except
@@ -314,8 +285,6 @@ def decoder_layer(dec_input,
         d_model,
         n_head,
         dropout_rate,
-        slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape,
         cache, )
     slf_attn_output = post_process_layer(
         dec_input,
@@ -331,9 +300,7 @@ def decoder_layer(dec_input,
         d_value,
         d_model,
         n_head,
-        dropout_rate,
-        src_attn_pre_softmax_shape,
-        src_attn_post_softmax_shape, )
+        dropout_rate, )
     enc_attn_output = post_process_layer(
         slf_attn_output,
         enc_attn_output,
@@ -362,10 +329,6 @@ def decoder(dec_input,
             d_model,
             d_inner_hid,
             dropout_rate=0.,
-            slf_attn_pre_softmax_shape=None,
-            slf_attn_post_softmax_shape=None,
-            src_attn_pre_softmax_shape=None,
-            src_attn_post_softmax_shape=None,
             caches=None):
     """
     The decoder is composed of a stack of identical decoder_layer layers.
@@ -381,12 +344,7 @@ def decoder(dec_input,
             d_value,
             d_model,
             d_inner_hid,
-            dropout_rate,
-            slf_attn_pre_softmax_shape,
-            slf_attn_post_softmax_shape,
-            src_attn_pre_softmax_shape,
-            src_attn_post_softmax_shape,
-            None if caches is None else caches[i], )
+            dropout_rate, )
         dec_input = dec_output
     return dec_output
@@ -425,8 +383,7 @@ def transformer(
     assert src_vocab_size == src_vocab_size, (
         "Vocabularies in source and target should be same for weight sharing."
     )
-    enc_inputs = make_all_inputs(encoder_data_input_fields +
-                                 encoder_util_input_fields)
+    enc_inputs = make_all_inputs(encoder_data_input_fields)
     enc_output = wrap_encoder(
         src_vocab_size,
@@ -441,8 +398,7 @@ def transformer(
         weight_sharing,
         enc_inputs, )
-    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
-                                 decoder_util_input_fields)
+    dec_inputs = make_all_inputs(decoder_data_input_fields[:-1])
     predict = wrap_decoder(
         trg_vocab_size,
@@ -466,8 +422,10 @@ def transformer(
             label=layers.one_hot(
                 input=label, depth=trg_vocab_size),
             epsilon=label_smooth_eps)
     cost = layers.softmax_with_cross_entropy(
-        logits=predict,
+        logits=layers.reshape(
+            predict, shape=[-1, trg_vocab_size]),
         label=label,
         soft_label=True if label_smooth_eps else False)
     weighted_cost = cost * weights
@@ -494,13 +452,11 @@ def wrap_encoder(src_vocab_size,
     """
     if enc_inputs is None:
         # This is used to implement independent encoder program in inference.
-        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
+        src_word, src_pos, src_slf_attn_bias = \
             make_all_inputs(encoder_data_input_fields +
                             encoder_util_input_fields)
     else:
-        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
+        src_word, src_pos, src_slf_attn_bias = \
             enc_inputs
     enc_input = prepare_encoder(
         src_word,
@@ -509,20 +465,9 @@ def wrap_encoder(src_vocab_size,
         d_model,
         max_length,
         dropout_rate,
-        src_data_shape,
         word_emb_param_name=word_emb_param_names[0])
-    enc_output = encoder(
-        enc_input,
-        src_slf_attn_bias,
-        n_layer,
-        n_head,
-        d_key,
-        d_value,
-        d_model,
-        d_inner_hid,
-        dropout_rate,
-        slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape, )
+    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
+                         d_value, d_model, d_inner_hid, dropout_rate)
     return enc_output
@@ -545,15 +490,10 @@ def wrap_decoder(trg_vocab_size,
     if dec_inputs is None:
         # This is used to implement independent decoder program in inference.
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            enc_output, trg_data_shape, slf_attn_pre_softmax_shape, \
-            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
-            src_attn_post_softmax_shape = make_all_inputs(
+            enc_output = make_all_inputs(
             decoder_data_input_fields + decoder_util_input_fields)
     else:
-        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_data_shape, slf_attn_pre_softmax_shape, \
-            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
-            src_attn_post_softmax_shape = dec_inputs
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
     dec_input = prepare_decoder(
         trg_word,
@@ -562,7 +502,6 @@ def wrap_decoder(trg_vocab_size,
         d_model,
         max_length,
         dropout_rate,
-        trg_data_shape,
         word_emb_param_name=word_emb_param_names[0]
         if weight_sharing else word_emb_param_names[1])
     dec_output = decoder(
@@ -576,29 +515,20 @@ def wrap_decoder(trg_vocab_size,
         d_value,
         d_model,
         d_inner_hid,
-        dropout_rate,
-        slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape,
-        src_attn_pre_softmax_shape,
-        src_attn_post_softmax_shape,
-        caches, )
+        dropout_rate, )
     # Return logits for training and probs for inference.
     if weight_sharing:
-        predict = layers.reshape(
-            x=layers.matmul(
-                x=dec_output,
-                y=fluid.get_var(word_emb_param_names[0]),
-                transpose_y=True),
-            shape=[-1, trg_vocab_size],
-            act="softmax" if dec_inputs is None else None)
+        predict = layers.matmul(
+            x=dec_output,
+            y=fluid.get_var(word_emb_param_names[0]),
+            transpose_y=True)
+        predict = layers.softmax(predict)
     else:
-        predict = layers.reshape(
-            x=layers.fc(input=dec_output,
-                        size=trg_vocab_size,
-                        bias_attr=False,
-                        num_flatten_dims=2),
-            shape=[-1, trg_vocab_size],
-            act="softmax" if dec_inputs is None else None)
+        predict = layers.fc(input=dec_output,
+                            size=trg_vocab_size,
+                            bias_attr=False,
+                            num_flatten_dims=2,
+                            act='softmax')
     return predict
...
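One consequence visible in `transformer()` above: `wrap_decoder` now returns a 3-D tensor of shape `[batch, trg_len, trg_vocab_size]`, while `softmax_with_cross_entropy` pairs 2-D logits with the flattened `[batch * trg_len, 1]` labels, hence the added `layers.reshape` in the loss. A shape-only sketch with made-up sizes (numpy stand-ins, not fluid code):

```python
import numpy as np

batch, trg_len, vocab = 2, 4, 10
predict = np.random.rand(batch, trg_len, vocab).astype("float32")  # decoder output
label = np.random.randint(vocab, size=(batch * trg_len, 1))        # flattened labels

# What the added layers.reshape in the diff accomplishes: one row of logits
# per label row.
logits_2d = predict.reshape(-1, vocab)
assert logits_2d.shape[0] == label.shape[0]
```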
@@ -180,34 +180,23 @@ def pad_batch_data(insts,
     return return_list if len(return_list) > 1 else return_list[0]
-def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
-                        trg_pad_idx, n_head, d_model):
+def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
+                        n_head, d_model):
     """
     Put all padded data needed by training into a dict.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
         [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+    src_word = src_word.reshape(-1, src_max_len, 1)
+    src_pos = src_pos.reshape(-1, src_max_len, 1)
     trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
         [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
+    trg_word = trg_word.reshape(-1, trg_max_len, 1)
+    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
-    # These shape tensors are used in reshape_op.
-    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
-    trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
-    src_slf_attn_pre_softmax_shape = np.array(
-        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
-    src_slf_attn_post_softmax_shape = np.array(
-        [-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
-    trg_slf_attn_pre_softmax_shape = np.array(
-        [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
-    trg_slf_attn_post_softmax_shape = np.array(
-        [-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
-    trg_src_attn_pre_softmax_shape = np.array(
-        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
-    trg_src_attn_post_softmax_shape = np.array(
-        [-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")
     lbl_word, lbl_weight, num_token = pad_batch_data(
         [inst[2] for inst in insts],
         trg_pad_idx,
@@ -223,15 +212,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
             src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
             trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
         ]))
-    util_input_dict = dict(
-        zip(util_input_names, [
-            src_data_shape, src_slf_attn_pre_softmax_shape,
-            src_slf_attn_post_softmax_shape, trg_data_shape,
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
-        ]))
-    return data_input_dict, util_input_dict, np.asarray(
-        [num_token], dtype="float32")
+    return data_input_dict, np.asarray([num_token], dtype="float32")
 def read_multiple(reader, count, clip_last=True):
@@ -317,12 +298,11 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
             for place_id, data_buffer in enumerate(
                     split_data(
                         data, num_part=dev_count)):
-                data_input_dict, util_input_dict, _ = prepare_batch_input(
+                data_input_dict, _ = prepare_batch_input(
                     data_buffer, data_input_names, util_input_names,
                     ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                     ModelHyperParams.n_head, ModelHyperParams.d_model)
-                feed_list.append(
-                    dict(data_input_dict.items() + util_input_dict.items()))
+                feed_list.append(data_input_dict)
             outs = exe.run(feed=feed_list,
                            fetch_list=[sum_cost.name, token_num.name])
@@ -380,7 +360,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
     data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
         -1] + label_data_input_fields
-    util_input_names = encoder_util_input_fields + decoder_util_input_fields
     if args.val_file_pattern is not None:
         test = test_context(train_progm, avg_cost, train_exe, dev_count,
@@ -404,13 +383,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
             for place_id, data_buffer in enumerate(
                     split_data(
                         data, num_part=dev_count)):
-                data_input_dict, util_input_dict, num_token = prepare_batch_input(
-                    data_buffer, data_input_names, util_input_names,
-                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
-                    ModelHyperParams.n_head, ModelHyperParams.d_model)
+                data_input_dict, num_token = prepare_batch_input(
+                    data_buffer, data_input_names, ModelHyperParams.eos_idx,
+                    ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                    ModelHyperParams.d_model)
                 total_num_token += num_token
-                feed_kv_pairs = data_input_dict.items() + util_input_dict.items(
-                )
+                feed_kv_pairs = data_input_dict.items()
                 if args.local:
                     feed_kv_pairs += {
                         lr_scheduler.learning_rate.name: lr_rate
@@ -460,7 +438,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                 fluid.io.save_inference_model(
                     os.path.join(TrainTaskConfig.model_dir,
                                  "pass_" + str(pass_id) + ".infer.model"),
-                    data_input_names[:-2] + util_input_names, [predict], exe)
+                    data_input_names[:-2], [predict], exe)
             if args.enable_ce:  # For CE
                 print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
                 print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
...
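The net effect on the training loop: `prepare_batch_input` returns a single feed dict, and the old two-dict merge per device disappears. A toy, self-contained sketch of that simplification (placeholder values; keys chosen from the field names in this diff):

```python
# Toy stand-ins for the real feed tensors.
data_input_dict = {"src_word": 0, "trg_word": 0, "lbl_word": 0}
util_input_dict = {"src_data_shape": 0, "src_slf_attn_pre_softmax_shape": 0}

# Before: data and util shape tensors were merged into one feed
# (the original code used Python 2's list-returning dict.items()).
feed_before = dict(list(data_input_dict.items()) + list(util_input_dict.items()))

# After: the data dict alone is the whole feed.
feed_after = data_input_dict
print(len(feed_before), len(feed_after))  # 5 3
```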