未验证 提交 64eb26ae 编写于 作者: G Guo Sheng 提交者: GitHub

Add validation for dygraph Transformer. (#4628)

Add cross-attention cache for dygraph Transformer.
Add greedy search for dygraph Transformer.
上级 2746e74b
......@@ -76,6 +76,7 @@ python -u train.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096
```
......@@ -91,6 +92,7 @@ python -u train.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--n_head 16 \
--d_model 1024 \
......@@ -121,10 +123,11 @@ Paddle动态图支持多进程多卡进行模型训练,启动训练的方式
```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 100 \
--use_cuda True \
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
    """
    Settings for the training run and the Adam optimizer.
    """
    # Number of epochs to train for.
    pass_num = 20

    # Number of sequences in one mini-batch.
    # Deprecated: set batch_size through the args instead.
    batch_size = 32

    # Hyper-parameters of the Adam optimizer.
    # learning_rate is a static factor; it is multiplied with the learning
    # rate derived from the LearningRateScheduler to get the final learning
    # rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9

    # Parameter for learning rate scheduling.
    warmup_steps = 8000

    # Weight used to mix the ground-truth distribution with the fixed uniform
    # distribution for label smoothing during training.
    # Use zero when label smoothing is not wanted.
    label_smooth_eps = 0.1
class InferTaskConfig(object):
    """
    Settings for sequence generation (inference).
    """
    # Number of examples handled in one generation run.
    batch_size = 4

    # Beam search parameters.
    beam_size = 4
    alpha = 0.6

    # Maximum decoded length; should be less than ModelHyperParams.max_length.
    max_out_len = 30
class ModelHyperParams(object):
    """
    Hyper-parameters describing the Transformer model itself.
    """
    # The next five vocabulary-related settings are set automatically from
    # the vocabulary path and special tokens that are passed in.
    # Size of the source word dictionary.
    src_vocab_size = 10000
    # Size of the target word dictionary.
    trg_vocab_size = 10000
    # Index of the <bos> token.
    bos_idx = 0
    # Index of the <eos> token.
    eos_idx = 1
    # Index of the <unk> token.
    unk_idx = 2

    # Maximum sequence length; it decides the size of the position encoding
    # table.
    max_length = 50

    # Dimension of the word embeddings. It is also the last dimension of the
    # input and output of multi-head attention, the position-wise
    # feed-forward networks, the encoder and the decoder.
    d_model = 512
    # Size of the hidden layer in the position-wise feed-forward networks.
    d_inner_hid = 2048
    # Dimension that keys are projected to for dot-product attention.
    d_key = 64
    # Dimension that values are projected to for dot-product attention.
    d_value = 64
    # Number of heads used in multi-head attention.
    n_head = 8
    # Number of sub-layers stacked in the encoder and in the decoder.
    n_layer = 6

    # Dropout rates of the different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1

    # Processing applied before each sub-layer.
    preprocess_cmd = "n"  # layer normalization
    # Processing applied after each sub-layer.
    postprocess_cmd = "da"  # dropout + residual connection

    # Whether to share the embedding and softmax weights. The source and
    # target vocabularies should be the same for weight sharing.
    weight_sharing = False
# Compile-time placeholder for the batch size. It currently has to be -1 to
# stay consistent with the infer-shape output of some ops at compile time,
# such as the sequence_expand op used in the beam-search decoder.
batch_size = -1
# Compile-time placeholder for the sequence length.
seq_len = ModelHyperParams.max_length

# Data shape and data type of every input, keyed by input name.
# The shapes are placeholders whose only job is to pass infer-shape at
# compile time; the comment on each entry gives the actual data shape.
input_descs = {
    # Actual shape: [batch_size, max_src_len_in_batch, 1]
    "src_word": [
        (batch_size, seq_len, 1),
        "int64",
        2,
    ],
    # Actual shape: [batch_size, max_src_len_in_batch, 1]
    "src_pos": [
        (batch_size, seq_len, 1),
        "int64",
    ],
    # Removes attention weights on paddings in the encoder.
    # Actual shape:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [
        (batch_size, ModelHyperParams.n_head, seq_len, seq_len),
        "float32",
    ],
    # Actual shape: [batch_size, max_trg_len_in_batch, 1]
    "trg_word": [
        (batch_size, seq_len, 1),
        "int64",
        2,  # lod_level, only used in the fast decoder
    ],
    # Actual shape: [batch_size, max_trg_len_in_batch, 1]
    "trg_pos": [
        (batch_size, seq_len, 1),
        "int64",
    ],
    # Removes attention weights on paddings and on subsequent words in the
    # decoder.
    # Actual shape:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [
        (batch_size, ModelHyperParams.n_head, seq_len, seq_len),
        "float32",
    ],
    # Removes attention weights on paddings of the source input in the
    # encoder-decoder attention.
    # Actual shape:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [
        (batch_size, ModelHyperParams.n_head, seq_len, seq_len),
        "float32",
    ],
    # Used by the independent decoder program for inference.
    # Actual shape: [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [
        (batch_size, seq_len, ModelHyperParams.d_model),
        "float32",
    ],
    # Actual shape: [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [
        (batch_size * seq_len, 1),
        "int64",
    ],
    # Masks out the loss of padding tokens.
    # Actual shape: [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [
        (batch_size * seq_len, 1),
        "float32",
    ],
    # Used by the beam-search decoder.
    "init_score": [
        (batch_size, 1),
        "float32",
        2,
    ],
    # Used by the beam-search decoder for the first gather (cell states
    # update).
    "init_idx": [
        (batch_size, ),
        "int32",
    ],
}
# Word embedding table names; these tables might be reused for weight
# sharing.
word_emb_param_names = ("src_word_emb_table", "trg_word_emb_table")

# Position encoding table names; these tables are initialized externally.
pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")

# Input names, separated by usage.
encoder_data_input_fields = ("src_word", "src_pos", "src_slf_attn_bias")
decoder_data_input_fields = (
    "trg_word", "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias",
    "enc_output")
label_data_input_fields = ("lbl_word", "lbl_weight")

# For the fast decoder, trg_pos (holding only the current time step) is
# generated by ops and trg_slf_attn_bias is not needed. "init_score" and
# "init_idx" are currently left out of this tuple as well.
fast_decoder_data_input_fields = ("trg_word", "trg_src_attn_bias")
def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Override attributes of the configuration objects using cfg_list.

    Each value is evaluated as a Python expression when possible (so "1024"
    becomes 1024 and "True" becomes True); a value that does not evaluate,
    such as a file path, is stored unchanged.

    Args:
        cfg_list: Flat sequence of alternating keys and values, i.e.
            [key1, value1, key2, value2, ...].
        g_cfgs: Iterable of configuration objects. A key is set on the first
            object that already has an attribute of that name; a key that
            matches none of the objects is ignored.

    Raises:
        AssertionError: If cfg_list holds an odd number of items.
    """
    assert len(cfg_list) % 2 == 0, (
        "cfg_list must hold key/value pairs, got %d items" % len(cfg_list))
    # NOTE(review): values are still interpreted with eval(). Emptying the
    # namespace only keeps bare words (e.g. a path called "dir", or "key")
    # from resolving to builtins or to this function's own variables; it is
    # NOT a sandbox, so cfg_list must come from a trusted source.
    eval_globals = {
        "__builtins__": {},
        # True/False/None are builtins rather than keywords on Python 2.
        "True": True,
        "False": False,
        "None": None,
    }
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if not hasattr(g_cfg, key):
                continue
            try:
                value = eval(value, eval_globals, {})
            except Exception:  # not an expression, e.g. a file path
                # Deliberate best effort: keep the raw value as given.
                pass
            setattr(g_cfg, key, value)
            # Only the first matching configuration object is updated.
            break
此差异已折叠。
......@@ -24,7 +24,6 @@ import paddle.fluid as fluid
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from utils.load import load_dygraph
# include task-specific libs
import reader
......@@ -58,6 +57,25 @@ def do_train(args):
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="train")
if args.validation_file:
val_processor = reader.DataProcessor(
fpattern=args.validation_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=trainer_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
val_batch_generator = val_processor.data_generator(phase="train")
if trainer_count > 1: # for multi-process gpu training
batch_generator = fluid.contrib.reader.distributed_batch_reader(
batch_generator)
......@@ -74,6 +92,9 @@ def do_train(args):
# define data loader
train_loader = fluid.io.DataLoader.from_generator(capacity=10)
train_loader.set_batch_generator(batch_generator, places=place)
if args.validation_file:
val_loader = fluid.io.DataLoader.from_generator(capacity=10)
val_loader.set_batch_generator(val_batch_generator, places=place)
# define model
transformer = Transformer(
......@@ -98,13 +119,13 @@ def do_train(args):
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
model_dict, opt_dict = load_dygraph(
model_dict, opt_dict = fluid.load_dygraph(
os.path.join(args.init_from_checkpoint, "transformer"))
transformer.load_dict(model_dict)
optimizer.set_dict(opt_dict)
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
model_dict, _ = load_dygraph(
model_dict, _ = fluid.load_dygraph(
os.path.join(args.init_from_pretrain_model, "transformer"))
transformer.load_dict(model_dict)
......@@ -174,13 +195,38 @@ def do_train(args):
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
ce_ppl.append(np.exp([min(total_avg_cost, 100)]))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0 and (
trainer_count == 1
or fluid.dygraph.parallel.Env().dev_id == 0):
if args.save_model:
if step_idx % args.save_step == 0 and step_idx != 0:
# validation
if args.validation_file:
transformer.eval()
total_sum_cost = 0
total_token_num = 0
for input_data in val_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
lbl_word, lbl_weight) = input_data
logits = transformer(src_word, src_pos,
src_slf_attn_bias, trg_word,
trg_pos, trg_slf_attn_bias,
trg_src_attn_bias)
sum_cost, avg_cost, token_num = criterion(
logits, lbl_word, lbl_weight)
total_sum_cost += sum_cost.numpy()
total_token_num += token_num.numpy()
total_avg_cost = total_sum_cost / total_token_num
logging.info("validation, step_idx: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
transformer.train()
if args.save_model and (
trainer_count == 1
or fluid.dygraph.parallel.Env().dev_id == 0):
model_dir = os.path.join(args.save_model,
"step_" + str(step_idx))
if not os.path.exists(model_dir):
......
......@@ -19,6 +19,8 @@ inference_model_dir: "infer_model"
random_seed: None
# The pattern to match training data files.
training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match validation data files.
validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
# The pattern to match test data files.
predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册