提交 866d3e03 编写于 作者: G guosheng

Add validation for dygraph Transformer.

Add cross-attention cache for dygraph Transformer.
Add greedy search for dygraph Transformer.
上级 923722de
......@@ -76,6 +76,7 @@ python -u train.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096
```
......@@ -91,6 +92,7 @@ python -u train.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--n_head 16 \
--d_model 1024 \
......@@ -121,10 +123,11 @@ Paddle动态图支持多进程多卡进行模型训练,启动训练的方式
```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 100 \
--use_cuda True \
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
"""
TrainTaskConfig
"""
# the epoch number to train.
pass_num = 20
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate = 2.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
class InferTaskConfig(object):
# the number of examples in one run for sequence generation.
batch_size = 4
# the parameters for beam search.
beam_size = 4
alpha=0.6
# max decoded length, should be less than ModelHyperParams.max_length
max_out_len = 30
class ModelHyperParams(object):
"""
ModelHyperParams
"""
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionay
trg_vocab_size = 10000
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = False
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for squence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table",
)
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table",
)
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
)
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output",
)
label_data_input_fields = (
"lbl_word",
"lbl_weight",
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
# "init_score",
# "init_idx",
"trg_src_attn_bias",
)
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
......@@ -18,7 +18,6 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers.utils import map_structure
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
......@@ -128,29 +127,46 @@ class MultiHeadAttention(Layer):
output_dim=d_model,
bias_attr=False)
def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v
keys = queries if keys is None else keys
values = keys if values is None else values
def _prepare_qkv(self, queries, keys, values, cache=None):
if keys is None: # self-attention
keys, values = queries, queries
static_kv = False
else: # cross-attention
static_kv = True
q = self.q_fc(queries)
k = self.k_fc(keys)
v = self.v_fc(values)
# split head
q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
q = layers.transpose(x=q, perm=[0, 2, 1, 3])
if cache is not None and static_kv and cache.has_key("static_k"):
# for encoder-decoder attention in inference and has cached
k = cache["static_k"]
v = cache["static_v"]
else:
k = self.k_fc(keys)
v = self.v_fc(values)
k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
k = layers.transpose(x=k, perm=[0, 2, 1, 3])
v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
if cache is not None:
if static_kv and not cache.has_key("static_k"):
# for encoder-decoder attention in inference and has not cached
cache["static_k"], cache["static_v"] = k, v
elif not static_kv:
# for decoder self-attention in inference
cache_k, cache_v = cache["k"], cache["v"]
k = layers.concat([cache_k, k], axis=2)
v = layers.concat([cache_v, v], axis=2)
cache["k"], cache["v"] = k, v
return q, k, v
def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v
q, k, v = self._prepare_qkv(queries, keys, values, cache)
# scale dot product attention
product = layers.matmul(x=q,
y=k,
......@@ -381,7 +397,7 @@ class DecoderLayer(Layer):
cross_attn_output = self.cross_attn(
self.preprocesser2(self_attn_output), enc_output, enc_output,
cross_attn_bias)
cross_attn_bias, cache)
cross_attn_output = self.postprocesser2(cross_attn_output,
self_attn_output)
......@@ -810,6 +826,36 @@ class Transformer(Layer):
eos_id=1,
beam_size=4,
max_len=256):
if beam_size == 1:
return self._greedy_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
max_len=max_len)
else:
return self._beam_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
beam_size=beam_size,
max_len=max_len)
def _beam_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=256):
def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:])
......@@ -822,22 +868,30 @@ class Transformer(Layer):
tensor.shape[2:])
def split_batch_beams(tensor):
return fluid.layers.reshape(tensor,
return layers.reshape(tensor,
shape=[-1, beam_size] +
list(tensor.shape[1:]))
def mask_probs(probs, finished, noend_mask_tensor):
# TODO: use where_op
finished = layers.cast(finished, dtype=probs.dtype)
probs = layers.elementwise_mul(
layers.expand(layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]),
noend_mask_tensor, axis=-1) - layers.elementwise_mul(probs, (finished - 1), axis=0)
probs = layers.elementwise_mul(layers.expand(
layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]),
noend_mask_tensor,
axis=-1) - layers.elementwise_mul(
probs, (finished - 1), axis=0)
return probs
def gather(x, indices, batch_pos):
topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2)
topk_coordinates = layers.stack([batch_pos, indices], axis=2)
return layers.gather_nd(x, topk_coordinates)
def update_states(func, caches):
for cache in caches: # no need to update static_kv
cache["k"] = func(cache["k"])
cache["v"] = func(cache["v"])
return caches
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
......@@ -893,13 +947,13 @@ class Transformer(Layer):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
caches = map_structure( # can not be reshaped since the 0 size
caches = update_states( # can not be reshaped since the 0 size
lambda x: x if i == 0 else merge_batch_beams(x), caches)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
caches = map_structure(split_batch_beams, caches)
caches = update_states(split_batch_beams, caches)
step_log_probs = split_batch_beams(
fluid.layers.log(fluid.layers.softmax(logits)))
layers.log(layers.softmax(logits)))
step_log_probs = mask_probs(step_log_probs, finished,
noend_mask_tensor)
log_probs = layers.elementwise_add(x=step_log_probs,
......@@ -908,15 +962,14 @@ class Transformer(Layer):
log_probs = layers.reshape(log_probs,
[-1, beam_size * self.trg_vocab_size])
scores = log_probs
topk_scores, topk_indices = fluid.layers.topk(input=scores,
k=beam_size)
beam_indices = fluid.layers.elementwise_floordiv(
topk_scores, topk_indices = layers.topk(input=scores, k=beam_size)
beam_indices = layers.elementwise_floordiv(
topk_indices, vocab_size_tensor)
token_indices = fluid.layers.elementwise_mod(
token_indices = layers.elementwise_mod(
topk_indices, vocab_size_tensor)
# update states
caches = map_structure(lambda x: gather(x, beam_indices, batch_pos),
caches = update_states(lambda x: gather(x, beam_indices, batch_pos),
caches)
log_probs = gather(log_probs, topk_indices, batch_pos)
finished = gather(finished, beam_indices, batch_pos)
......@@ -937,3 +990,72 @@ class Transformer(Layer):
finished_scores = topk_scores
return finished_seq, finished_scores
def _greedy_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
max_len=256):
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number
batch_size = enc_output.shape[0]
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
end_token_tensor = layers.fill_constant(shape=[batch_size, 1],
dtype="int64",
value=eos_id)
predict_ids = []
log_probs = layers.fill_constant(shape=[batch_size, 1],
dtype="float32",
value=0)
trg_word = layers.fill_constant(shape=[batch_size, 1],
dtype="int64",
value=bos_id)
## init states (caches) for transformer
caches = [{
"k":
layers.fill_constant(
shape=[batch_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
for i in range(max_len):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
step_log_probs = layers.log(layers.softmax(logits))
log_probs = layers.elementwise_add(x=step_log_probs,
y=log_probs,
axis=0)
scores = log_probs
topk_scores, topk_indices = layers.topk(input=scores, k=1)
finished = layers.equal(topk_indices, end_token_tensor)
trg_word = topk_indices
log_probs = topk_scores
predict_ids.append(topk_indices)
if layers.reduce_all(finished).numpy():
break
predict_ids = layers.stack(predict_ids, axis=0)
finished_seq = layers.transpose(predict_ids, [1, 2, 0])
finished_scores = topk_scores
return finished_seq, finished_scores
......@@ -57,6 +57,25 @@ def do_train(args):
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="train")
if args.validation_file:
val_processor = reader.DataProcessor(
fpattern=args.validation_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=trainer_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
val_batch_generator = val_processor.data_generator(phase="train")
if trainer_count > 1: # for multi-process gpu training
batch_generator = fluid.contrib.reader.distributed_batch_reader(
batch_generator)
......@@ -73,6 +92,9 @@ def do_train(args):
# define data loader
train_loader = fluid.io.DataLoader.from_generator(capacity=10)
train_loader.set_batch_generator(batch_generator, places=place)
if args.validation_file:
val_loader = fluid.io.DataLoader.from_generator(capacity=10)
val_loader.set_batch_generator(val_batch_generator, places=place)
# define model
transformer = Transformer(
......@@ -123,6 +145,7 @@ def do_train(args):
# train loop
for pass_id in range(args.epoch):
pass_start_time = time.time()
avg_batch_time = time.time()
batch_id = 0
for input_data in train_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
......@@ -155,7 +178,6 @@ def do_train(args):
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
......@@ -164,12 +186,37 @@ def do_train(args):
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0 and (
if step_idx % args.save_step == 0 and step_idx != 0:
# validation
if args.validation_file:
transformer.eval()
total_sum_cost = 0
total_token_num = 0
for input_data in val_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
lbl_word, lbl_weight) = input_data
logits = transformer(src_word, src_pos,
src_slf_attn_bias, trg_word,
trg_pos, trg_slf_attn_bias,
trg_src_attn_bias)
sum_cost, avg_cost, token_num = criterion(
logits, lbl_word, lbl_weight)
total_sum_cost += sum_cost.numpy()
total_token_num += token_num.numpy()
total_avg_cost = total_sum_cost / total_token_num
logging.info("validation, step_idx: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
transformer.train()
if args.save_model and (
trainer_count == 1
or fluid.dygraph.parallel.Env().dev_id == 0):
if args.save_model:
model_dir = os.path.join(args.save_model,
"step_" + str(step_idx))
if not os.path.exists(model_dir):
......@@ -181,6 +228,7 @@ def do_train(args):
optimizer.state_dict(),
os.path.join(model_dir, "transformer"))
avg_batch_time = time.time()
batch_id += 1
step_idx += 1
......
......@@ -19,6 +19,8 @@ inference_model_dir: "infer_model"
random_seed: None
# The pattern to match training data files.
training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match validation data files.
validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
# The pattern to match test data files.
predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册