提交 9828b49e 编写于 作者: G guosheng

Add label smoothing in Transformer

上级 d300e5e4
......@@ -18,6 +18,11 @@ class TrainTaskConfig(object):
# the flag indicating to use average loss or sum loss when training.
use_avg_cost = False
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
# the directory for saving trained models.
model_dir = "trained_models"
......@@ -70,7 +75,6 @@ class ModelHyperParams(object):
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
......
......@@ -516,7 +516,8 @@ def transformer(
d_value,
d_model,
d_inner_hid,
dropout_rate, ):
dropout_rate,
label_smooth_eps, ):
enc_inputs = make_inputs(
encoder_input_data_names,
n_head,
......@@ -570,7 +571,7 @@ def transformer(
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
gold, weights = make_inputs(
label, weights = make_inputs(
label_data_names,
n_head,
d_model,
......@@ -582,7 +583,15 @@ def transformer(
data_shape_flag=False,
slf_attn_shape_flag=False,
src_attn_shape_flag=False)
cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
if label_smooth_eps:
label = layers.label_smooth(
label=layers.one_hot(
input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
......
......@@ -120,7 +120,8 @@ def main():
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
TrainTaskConfig.label_smooth_eps)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册