未验证 提交 486f84ff 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #990 from guoshengCS/refine-transformer-wmt14

Refine Transformer for wmt14_en-de
class TrainTaskConfig(object):
# only support GPU currently
use_gpu = True
# the epoch number to train.
pass_num = 30
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
......@@ -13,8 +15,6 @@ class TrainTaskConfig(object):
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 4000
# the flag indicating to use average loss or sum loss when training.
use_avg_cost = True
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
......@@ -38,22 +38,20 @@ class InferTaskConfig(object):
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_length = 30
max_length = 256
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
output_unk = True
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
class ModelHyperParams(object):
# This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
# <unk> token has alreay been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionay
......@@ -68,13 +66,13 @@ class ModelHyperParams(object):
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length = 50
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
......@@ -85,6 +83,9 @@ class ModelHyperParams(object):
n_layer = 6
# dropout rate used by all dropout layers.
dropout = 0.1
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = True
def merge_cfg_from_list(cfg_list, g_cfgs):
......@@ -97,7 +98,7 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
if hasattr(g_cfg, key):
try:
value = eval(value)
except SyntaxError: # for file path
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
......@@ -172,6 +173,10 @@ input_descs = {
"lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
......
......@@ -308,7 +308,7 @@ def infer(args):
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
......@@ -317,7 +317,7 @@ def infer(args):
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
......@@ -359,6 +359,7 @@ def infer(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
......
......@@ -46,26 +46,14 @@ def multi_head_attention(queries,
"""
q = layers.fc(input=queries,
size=d_key * n_head,
param_attr=fluid.initializer.Xavier(
uniform=False,
fan_in=d_model * d_key,
fan_out=n_head * d_key),
bias_attr=False,
num_flatten_dims=2)
k = layers.fc(input=keys,
size=d_key * n_head,
param_attr=fluid.initializer.Xavier(
uniform=False,
fan_in=d_model * d_key,
fan_out=n_head * d_key),
bias_attr=False,
num_flatten_dims=2)
v = layers.fc(input=values,
size=d_value * n_head,
param_attr=fluid.initializer.Xavier(
uniform=False,
fan_in=d_model * d_value,
fan_out=n_head * d_value),
bias_attr=False,
num_flatten_dims=2)
return q, k, v
......@@ -84,7 +72,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, -1, n_head, hidden_size // n_head])
x=x, shape=[0, 0, n_head, hidden_size // n_head])
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
......@@ -104,7 +92,7 @@ def multi_head_attention(queries,
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
......@@ -140,7 +128,6 @@ def multi_head_attention(queries,
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
param_attr=fluid.initializer.Xavier(uniform=False),
bias_attr=False,
num_flatten_dims=2)
return proj_out
......@@ -155,14 +142,8 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
param_attr=fluid.initializer.Uniform(
low=-(d_hid**-0.5), high=(d_hid**-0.5)),
act="relu")
out = layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=2,
param_attr=fluid.initializer.Uniform(
low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
......@@ -200,6 +181,7 @@ def prepare_encoder(src_word,
src_max_len,
dropout_rate=0.,
src_data_shape=None,
word_emb_param_name=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
......@@ -209,7 +191,10 @@ def prepare_encoder(src_word,
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
param_attr=fluid.initializer.Normal(0., 1.))
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
......@@ -415,7 +400,12 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
label_smooth_eps, ):
if weight_sharing:
assert src_vocab_size == src_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
enc_inputs = make_all_inputs(encoder_data_input_fields +
encoder_util_input_fields)
......@@ -429,6 +419,7 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
enc_inputs, )
dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
......@@ -444,6 +435,7 @@ def transformer(
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
dec_inputs,
enc_output, )
......@@ -459,7 +451,6 @@ def transformer(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
# cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
......@@ -476,6 +467,7 @@ def wrap_encoder(src_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
enc_inputs=None):
"""
The wrapper assembles together all needed layers for the encoder.
......@@ -497,7 +489,8 @@ def wrap_encoder(src_vocab_size,
d_model,
max_length,
dropout_rate,
src_data_shape, )
src_data_shape,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
enc_input,
src_slf_attn_bias,
......@@ -522,6 +515,7 @@ def wrap_decoder(trg_vocab_size,
d_model,
d_inner_hid,
dropout_rate,
weight_sharing,
dec_inputs=None,
enc_output=None):
"""
......@@ -547,7 +541,9 @@ def wrap_decoder(trg_vocab_size,
d_model,
max_length,
dropout_rate,
trg_data_shape, )
trg_data_shape,
word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1])
dec_output = decoder(
dec_input,
enc_output,
......@@ -565,11 +561,20 @@ def wrap_decoder(trg_vocab_size,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
# Return logits for training and probs for inference.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
if weight_sharing:
predict = layers.reshape(
x=layers.matmul(
x=dec_output,
y=fluid.get_var(word_emb_param_names[0]),
transpose_y=True),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
else:
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
return predict
......@@ -43,9 +43,11 @@ def parse_args():
parser.add_argument(
"--batch_size",
type=int,
default=2000,
default=2048,
help="The number of sequences contained in a mini-batch, or the maximum "
"number of tokens (include paddings) contained in a mini-batch.")
"number of tokens (include paddings) contained in a mini-batch. Note "
"that this represents the number on single device and the actual batch "
"size for multi-devices will multiply the device number.")
parser.add_argument(
"--pool_size",
type=int,
......@@ -203,50 +205,50 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
[num_token], dtype="float32")
def train(args):
dev_count = fluid.core.get_cuda_device_count()
def read_multiple(reader, count, clip_last=True):
"""
Stack data from reader for multi-devices.
"""
def read_multiple(reader,
count=dev_count if args.use_token_batch else 1,
clip_last=True):
"""
Stack data from reader for multi-devices.
"""
def __impl__():
res = []
for item in reader():
res.append(item)
if len(res) == count:
yield res
res = []
def __impl__():
res = []
for item in reader():
res.append(item)
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(count)
]
return __impl__
def split_data(data, num_part=dev_count):
"""
Split data for each device.
"""
if len(data) == num_part:
return data
data = data[0]
inst_num_per_part = len(data) // num_part
return [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(num_part)
]
res = []
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(count)
]
return __impl__
def split_data(data, num_part):
"""
Split data for each device.
"""
if len(data) == num_part:
return data
data = data[0]
inst_num_per_part = len(data) // num_part
return [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(num_part)
]
def train(args):
dev_count = fluid.core.get_cuda_device_count()
sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
......@@ -254,7 +256,7 @@ def train(args):
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
TrainTaskConfig.label_smooth_eps)
ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps,
......@@ -288,9 +290,12 @@ def train(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
train_data = read_multiple(
reader=train_data.batch_generator,
count=dev_count if args.use_token_batch else 1)
train_data = read_multiple(reader=train_data.batch_generator)
build_strategy = fluid.BuildStrategy()
# Since the token number differs among devices, customize gradient scale to
# use token average cost among multi-devices. and the gradient scale is
......@@ -303,9 +308,11 @@ def train(args):
def test_context():
# Context to do validation.
test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost])
test_program = fluid.default_main_program().clone(for_test=True)
test_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=test_program,
share_vars_from=train_exe)
val_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
......@@ -319,22 +326,22 @@ def train(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False,
shuffle=False,
shuffle_batch=False)
test_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=test_program,
share_vars_from=train_exe)
def test(exe=test_exe):
test_total_cost = 0
test_total_token = 0
test_data = read_multiple(reader=val_data.batch_generator)
test_data = read_multiple(
reader=val_data.batch_generator,
count=dev_count if args.use_token_batch else 1)
for batch_id, data in enumerate(test_data()):
feed_list = []
for place_id, data_buffer in enumerate(split_data(data)):
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
......@@ -367,7 +374,9 @@ def train(args):
feed_list = []
total_num_token = 0
lr_rate = lr_scheduler.update_learning_rate()
for place_id, data_buffer in enumerate(split_data(data)):
for place_id, data_buffer in enumerate(
split_data(
data, num_part=dev_count)):
data_input_dict, util_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
......@@ -377,17 +386,14 @@ def train(args):
dict(data_input_dict.items() + util_input_dict.items() +
{lr_scheduler.learning_rate.name: lr_rate}.items()))
if not init:
if not init: # init the position encoding table
for pos_enc_param_name in pos_enc_param_names:
pos_enc = position_encoding_init(
ModelHyperParams.max_length + 1,
ModelHyperParams.d_model)
feed_list[place_id][pos_enc_param_name] = pos_enc
for feed_dict in feed_list:
feed_dict[
sum_cost.name +
"@GRAD"] = 1. / total_num_token if TrainTaskConfig.use_avg_cost else np.asarray(
[1.], dtype="float32")
feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
feed=feed_list)
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册