Commit b34e8c0e authored by guosheng

Skip partial batches in Transformer when training

Parent 91a0b7c6
@@ -43,9 +43,10 @@ class InferTaskConfig(object):
 class ModelHyperParams(object):
-    # Dictionary size for source and target language. This model directly uses
-    # paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has
-    # alreay been added.
+    # This model directly uses paddle.dataset.wmt16, in which the <bos>, <eos>
+    # and <unk> tokens have already been added. As for the <pad> token, any
+    # token included in the dict can be used for padding, since the paddings'
+    # loss will be masked out and have no effect on parameter gradients.
     # size of source word dictionary.
     src_vocab_size = 10000
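The new comment explains why no dedicated <pad> token is needed: padded target positions are given zero weight when the cost is computed, so whatever in-vocabulary token fills the padding contributes nothing to the loss or the gradients. A minimal NumPy sketch of such a masked loss (array names and shapes are illustrative, not the model's actual variables):

```python
import numpy as np

def masked_token_loss(token_losses, is_pad):
    """Zero out the per-token losses at padded positions before averaging.

    token_losses: [batch, seq_len] cross-entropy loss per target token.
    is_pad:       [batch, seq_len] boolean mask, True where the token is padding.
    """
    weights = (~is_pad).astype("float32")      # 1.0 for real tokens, 0.0 for pads
    total_loss = (token_losses * weights).sum()
    num_real_tokens = weights.sum()
    return total_loss / num_real_tokens        # pads contribute nothing

# Toy example: the second sequence has two padded positions.
losses = np.array([[2.0, 1.0, 0.5], [1.5, 3.0, 3.0]], dtype="float32")
pads = np.array([[False, False, False], [False, True, True]])
print(masked_token_loss(losses, pads))  # averages only the 4 real tokens -> 1.25
```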
@@ -61,9 +62,9 @@ class ModelHyperParams(object):
     unk_idx = 2
     # max length of sequences.
-    # The size of position encoding table should plus 1, since the sinusoid
-    # position encoding start from 1 and 0 can be used as the padding token
-    # for position encoding.
+    # The size of the position encoding table should be at least max_length + 1,
+    # since the sinusoid position encoding starts from 1, and 0 can be used as
+    # the padding token for position encoding.
     max_length = 50
     # the dimension for word embeddings, which is also the last dimension of
...
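On the max_length comment: the sinusoid position encoding numbers real positions from 1 and reserves 0 for the padding position, so the encoding table needs at least max_length + 1 rows. A rough sketch of building such a table with the standard sinusoid formula (the function name and the even d_model are assumptions for illustration):

```python
import numpy as np

def position_encoding_table(max_length, d_model):
    """Build a sinusoid position encoding table with row 0 reserved for padding.

    Rows 1..max_length hold the encodings of real positions, so the table has
    max_length + 1 rows in total.
    """
    table = np.zeros((max_length + 1, d_model), dtype="float32")
    positions = np.arange(1, max_length + 1)[:, None]   # real positions start at 1
    dims = np.arange(0, d_model, 2)[None, :]
    angles = positions / np.power(10000.0, dims / float(d_model))
    table[1:, 0::2] = np.sin(angles)
    table[1:, 1::2] = np.cos(angles)
    return table  # row 0 stays all-zero and acts as the padding position

table = position_encoding_table(max_length=50, d_model=8)
print(table.shape)  # (51, 8)
```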
@@ -476,12 +476,16 @@ def make_inputs(input_data_names,
             append_batch_size=False)
         input_layers += [slf_attn_post_softmax_shape]
     if src_attn_shape_flag:
+        # This shape input is used to reshape before softmax in encoder-decoder
+        # attention.
         src_attn_pre_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[2],
             dtype="int32",
             append_batch_size=False)
         input_layers += [src_attn_pre_softmax_shape]
+        # This shape input is used to reshape after softmax in encoder-decoder
+        # attention.
         src_attn_post_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[4],
...
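The added comments refer to the [2]-shaped and [4]-shaped inputs: they presumably carry the 2-D shape used to flatten the attention weights before softmax and the 4-D [batch, n_head, len_q, len_k] shape used to restore them afterwards. A NumPy sketch of that reshape-softmax-reshape pattern (a stand-in for the idea only, not the Fluid program itself):

```python
import numpy as np

def softmax_with_reshape(scores):
    """Apply a row-wise softmax to 4-D attention scores via a 2-D reshape.

    scores: [batch, n_head, len_q, len_k] scaled dot-product attention logits.
    """
    batch, n_head, len_q, len_k = scores.shape
    pre_softmax_shape = (batch * n_head * len_q, len_k)   # the 2-D "pre" shape
    post_softmax_shape = (batch, n_head, len_q, len_k)    # the 4-D "post" shape

    flat = scores.reshape(pre_softmax_shape)
    flat = flat - flat.max(axis=1, keepdims=True)         # numerical stability
    weights = np.exp(flat) / np.exp(flat).sum(axis=1, keepdims=True)
    return weights.reshape(post_softmax_shape)

scores = np.random.rand(2, 8, 5, 6).astype("float32")
weights = softmax_with_reshape(scores)
print(weights.shape, weights.sum(axis=-1)[0, 0, 0])       # (2, 8, 5, 6), ~1.0
```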
@@ -180,6 +180,8 @@ def main():
     for pass_id in xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
+            if len(data) != TrainTaskConfig.batch_size:
+                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
                 label_data_names, ModelHyperParams.eos_idx,
...
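This new check is the point of the commit: the batch-size-dependent shape inputs above are fed as fixed values, so a trailing batch smaller than TrainTaskConfig.batch_size would not match them and is simply skipped. A minimal sketch of the same guard around a made-up reader (BATCH_SIZE and toy_reader are illustrative only):

```python
BATCH_SIZE = 4  # stands in for TrainTaskConfig.batch_size

def toy_reader():
    """Yield 10 samples in batches of BATCH_SIZE; the last batch has only 2."""
    samples = list(range(10))
    for start in range(0, len(samples), BATCH_SIZE):
        yield samples[start:start + BATCH_SIZE]

for batch_id, data in enumerate(toy_reader()):
    if len(data) != BATCH_SIZE:
        # Shape inputs were built for full batches, so skip the partial one.
        continue
    print("training on batch", batch_id, data)
```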