提交 b34e8c0e 编写于 作者: G guosheng

Skip the part batches in Transformer when training

上级 91a0b7c6
......@@ -43,9 +43,10 @@ class InferTaskConfig(object):
class ModelHyperParams(object):
# Dictionary size for source and target language. This model directly uses
# paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has
# alreay been added.
# This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
# <unk> token has alreay been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size = 10000
......@@ -61,9 +62,9 @@ class ModelHyperParams(object):
unk_idx = 2
# max length of sequences.
# The size of position encoding table should plus 1, since the sinusoid
# position encoding start from 1 and 0 can be used as the padding token
# for position encoding.
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
......
......@@ -476,12 +476,16 @@ def make_inputs(input_data_names,
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
......
......@@ -180,6 +180,8 @@ def main():
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.eos_idx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册