提交 67323767 编写于 作者: A AndyELiu 提交者: Yibing Liu

submit code for joint embedding paper (#2896)

上级 af1197ed
## 简介
### 任务说明
机器翻译的输入一般是源语言的句子。但在很多实际系统中,比如语音识别系统的输出或者基于拼音的文字输入,源语言句子一般包含很多同音字错误, 这会导致翻译出现很多意想不到的错误。由于可以同时获得发音信息,我们提出了一种在输入端加入发音信息,进而在模型的嵌入层
融合文字信息和发音信息的翻译方法,大大提高了翻译模型对同音字错误的抵抗能力。
文章地址:https://arxiv.org/abs/1810.06729
### 效果说明
我们使用LDC Chinese-to-English数据集训练。中文词典用的是[DaCiDian](https://github.com/aishell-foundation/DaCiDian)。 在newstest2006上进行评测,效果如下所示:
| beta=0 | beta=0.50 | beta=0.85 | beta=0.95 |
|-|-|-|-|
| 47.96 | 48.71 | 48.85 | 48.46 |
beta代表发音信息的权重。这表明,即使将绝大部分权重放在发音信息上,翻译的效果依然很好。与此同时,翻译系统对同音字错误的抵抗力大大提高。
## 安装说明
1. paddle安装
本项目依赖于 PaddlePaddle Fluid 1.3.1 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 环境依赖
请参考PaddlePaddle[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html)部分的内容
## 如何训练
1. 数据格式
数据格式和[Paddle机器翻译](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer)的格式一致。为了获得输入句子的发音信息,需要额外提供源语言的发音基本单元和发音的词典。
A) 发音基本单元文件
中文的发音基本单元是拼音,将所有的拼音放在一个文件,类似:
<unk>
bo
li
。。。
B)发音词典
根据DaCiDian,对bpe后的源语言中的token赋予一个或者几个发音,类似:
▁玻利维亚 bo li wei ya
▁举行 ju xing
▁总统 zong tong
▁与 yu
巴斯 ba si
▁这个 zhei ge|zhe ge
。。。
2. 训练模型
数据准备完成后,可以使用 `train.py` 脚本进行训练。例子如下:
```sh
python train.py \
--src_vocab_fpath nist_data/vocab_all.28000 \
--trg_vocab_fpath nist_data/vocab_all.28000 \
--train_file_pattern nist_data/nist_train.txt \
--phoneme_vocab_fpath nist_data/zh_pinyins.txt \
--lexicon_fpath nist_data/zh_lexicon.txt \
--batch_size 2048 \
--use_token_batch True \
--sort_type pool \
--pool_size 200000 \
--use_py_reader False \
--use_mem_opt False \
--enable_ce False \
--fetch_steps 1 \
pass_num 100 \
learning_rate 2.0 \
warmup_steps 8000 \
beta2 0.997 \
d_model 512 \
d_inner_hid 2048 \
n_head 8 \
weight_sharing True \
max_length 256 \
save_freq 10000 \
beta 0.85 \
model_dir pinyin_models_beta085 \
ckpt_dir pinyin_ckpts_beta085
```
上述命令中设置了源语言词典文件路径(`src_vocab_fpath`)、目标语言词典文件路径(`trg_vocab_fpath`)、训练数据文件(`train_file_pattern`,支持通配符), 发音单元文件路径(`phoneme_vocab_fpath`), 发音词典路径(`lexicon_fpath`)等数据相关的参数和构造 batch 方式(`use_token_batch` 指定了数据按照 token 数目或者 sequence 数目组成 batch)等 reader 相关的参数。有关这些参数更详细的信息可以通过执行以下命令查看:
```sh
python train.py --help
```
更多模型训练相关的参数则在 `config.py` 中的 `ModelHyperParams``TrainTaskConfig` 内定义;`ModelHyperParams` 定义了 embedding 维度等模型超参数,`TrainTaskConfig` 定义了 warmup 步数等训练需要的参数。这些参数默认使用了 Transformer 论文中 base model 的配置,如需调整可以在该脚本中进行修改。另外这些参数同样可在执行训练脚本的命令行中设置,传入的配置会合并并覆盖 `config.py` 中的配置.
注意,如训练时更改了模型配置,使用 `infer.py` 预测时需要使用对应相同的模型配置;另外,训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用指定的 GPU。
## 如何预测
使用以上提供的数据和模型,可以按照以下代码进行预测,翻译结果将打印到标准输出:
```sh
python infer.py \
--src_vocab_fpath nist_data/vocab_all.28000 \
--trg_vocab_fpath nist_data/vocab_all.28000 \
--test_file_pattern nist_data/nist_test.txt \
--phoneme_vocab_fpath nist_data/zh_pinyins.txt \
--lexicon_fpath nist_data/zh_lexicon.txt \
--batch_size 32 \
model_path pinyin_models_beta085/iter_200000.infer.model \
beam_size 5 \
max_out_len 255 \
beta 0.85
```
class TrainTaskConfig(object):
# support both CPU and GPU now.
use_gpu = True
# the epoch number to train.
pass_num = 30
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate = 2.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
# the directory for saving trained models.
model_dir = "trained_models"
# the directory for saving checkpoints.
ckpt_dir = "trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path = None
# the parameter to initialize the learning rate scheduler.
# It should be provided if use checkpoints, since the checkpoint doesn't
# include the training step counter currently.
start_step = 0
# the frequency to save trained models.
save_freq = 10000
class InferTaskConfig(object):
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_out_len = 256
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = True
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
class ModelHyperParams(object):
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionay
trg_vocab_size = 10000
# size of phone dictionary
phone_vocab_size = 1000
# ratio of phoneme embeddings
beta = 0.0
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# index for <unk> in phonemes
phone_pad_idx = 0
# max length of sequences deciding the size of position encoding table.
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# random seed used in dropout for CE.
dropout_seed = None
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = True
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for squence length in compile time.
seq_len = 256
# The placeholder for phoneme sequence length in comiple time.
phone_len = 16
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
"src_phone": [(batch_size, seq_len, phone_len, 1), "int64"],
"src_phone_mask": [(batch_size, seq_len, phone_len), "int64"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table",
)
phone_emb_param_name = "phone_emb_table"
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table",
)
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"src_phone",
"src_phone_mask",
)
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output",
)
label_data_input_fields = (
"lbl_word",
"lbl_weight",
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"init_idx",
"trg_src_attn_bias",
)
# Set seed for CE
dropout_seed = None
import argparse
import ast
import multiprocessing
import numpy as np
import os
import sys
from functools import partial
import paddle
import paddle.fluid as fluid
import reader
from config import *
from desc import *
from model import fast_decode as fast_decoder
from train import pad_batch_data, pad_phoneme_data, prepare_data_generator
def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--src_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of source language.")
parser.add_argument(
"--trg_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of target language.")
parser.add_argument(
"--phoneme_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of phonemes.")
parser.add_argument(
"--lexicon_fpath",
type=str,
required=True,
help="The path of lexicon of source language.")
parser.add_argument(
"--test_file_pattern",
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--batch_size",
type=int,
default=50,
help="The number of examples in one run for sequence generation.")
parser.add_argument(
"--pool_size",
type=int,
default=10000,
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use memory optimization.")
parser.add_argument(
"--use_py_reader",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--use_parallel_exe",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use ParallelExecutor.")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
# Append args related to dict
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath)
dict_args = [
"src_vocab_size",
str(len(src_dict)), "trg_vocab_size",
str(len(trg_dict)), "phone_vocab_size",
str(len(phone_dict)), "bos_idx",
str(src_dict[args.special_token[0]]), "eos_idx",
str(src_dict[args.special_token[1]]), "unk_idx",
str(src_dict[args.special_token[2]])
]
merge_cfg_from_list(args.opts + dict_args,
[InferTaskConfig, ModelHyperParams])
return args
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx,
bos_idx, n_head, d_model, place):
"""
Put all padded data needed by beam search decoder into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
src_phone, src_phone_mask, max_phone_len = pad_phoneme_data(
[inst[1] for inst in insts], phone_pad_idx, src_max_len)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1, 1)
def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor()
data_tensor.set(data, place)
if lod is not None:
data_tensor.set_lod(lod)
return data_tensor
# beamsearch_op must use tensors with lod
init_score = to_lodtensor(
np.zeros_like(trg_word, dtype="float32").reshape(-1, 1), place,
[range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place,
[range(trg_word.shape[0] + 1)] * 2)
init_idx = np.asarray(range(len(insts)), dtype="int32")
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask,
trg_word, init_score, init_idx, trg_src_attn_bias
]))
return data_input_dict
def prepare_feed_dict_list(data_generator, count, place):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def py_reader_provider_wrapper(data_reader, place):
"""
Data provider needed by fluid.layers.py_reader.
"""
def py_reader_provider():
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
for batch_id, data in enumerate(data_reader()):
data_input_dict = prepare_batch_input(
data, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
yield [data_input_dict[item] for item in data_input_names]
return py_reader_provider
def fast_infer(args):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
out_ids, out_scores, pyreader = fast_decoder(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.phone_vocab_size,
ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
InferTaskConfig.beam_size,
InferTaskConfig.max_out_len,
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
beta=ModelHyperParams.beta,
use_py_reader=args.use_py_reader)
# This is used here to set dropout to the test mode.
infer_program = fluid.default_main_program().clone(for_test=True)
if args.use_mem_opt:
fluid.memory_optimize(infer_program)
if InferTaskConfig.use_gpu:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_vars(
exe,
InferTaskConfig.model_path,
vars=[
var for var in infer_program.list_vars()
if isinstance(var, fluid.framework.Parameter)
])
exec_strategy = fluid.ExecutionStrategy()
# For faster executor
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy()
infer_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=infer_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
# data reader settings for inference
args.train_file_pattern = args.test_file_pattern
args.use_token_batch = False
args.sort_type = reader.SortType.NONE
args.shuffle = False
args.shuffle_batch = False
test_data = prepare_data_generator(
args,
is_test=False,
count=dev_count,
pyreader=pyreader,
py_reader_provider_wrapper=py_reader_provider_wrapper,
place=place)
if args.use_py_reader:
pyreader.start()
data_generator = None
else:
data_generator = test_data()
trg_idx2word = reader.DataReader.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
while True:
try:
feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
place)
if args.use_parallel_exe:
seq_ids, seq_scores = infer_exe.run(
fetch_list=[out_ids.name, out_scores.name],
feed=feed_dict_list,
return_numpy=False)
else:
seq_ids, seq_scores = exe.run(
program=infer_program,
fetch_list=[out_ids.name, out_scores.name],
feed=feed_dict_list[0]
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=True)
seq_ids_list, seq_scores_list = [
seq_ids
], [seq_scores] if isinstance(
seq_ids, paddle.fluid.LoDTensor) else (seq_ids, seq_scores)
for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)]
scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)]
for i in range(len(seq_ids.lod()[0]) -
1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx] for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
if len(hyps[i]) >= InferTaskConfig.n_best:
break
except (StopIteration, fluid.core.EOFException):
# The data pass is over.
if args.use_py_reader:
pyreader.reset()
break
if __name__ == "__main__":
args = parse_args()
fast_infer(args)
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from desc import *
def wrap_layer_with_block(layer, block_idx):
"""
Make layer define support indicating block, by which we can add layers
to other blocks within current block. This will make it easy to define
cache among while loop.
"""
class BlockGuard(object):
"""
BlockGuard class.
BlockGuard class is used to switch to the given block in a program by
using the Python `with` keyword.
"""
def __init__(self, block_idx=None, main_program=None):
self.main_program = fluid.default_main_program(
) if main_program is None else main_program
self.old_block_idx = self.main_program.current_block().idx
self.new_block_idx = block_idx
def __enter__(self):
self.main_program.current_block_idx = self.new_block_idx
def __exit__(self, exc_type, exc_val, exc_tb):
self.main_program.current_block_idx = self.old_block_idx
if exc_type is not None:
return False # re-raise exception
return True
def layer_wrapper(*args, **kwargs):
with BlockGuard(block_idx):
return layer(*args, **kwargs)
return layer_wrapper
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (
np.log(float(1e4) / float(1)) / (num_timescales - 1))
inv_timescales = np.exp(
np.arange(num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(
inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
gather_idx=None,
static_kv=False):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(
input=queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
fc_layer = wrap_layer_with_block(
layers.fc,
fluid.default_main_program().current_block().
parent_idx) if cache is not None and static_kv else layers.fc
k = fc_layer(
input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = fc_layer(
input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Reshape input tensors at the last dimension to split multi-heads
and then transpose. Specifically, transform the input tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] to the output tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped_q = layers.reshape(
x=queries, shape=[0, 0, n_head, d_key], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
reshape_layer = wrap_layer_with_block(
layers.reshape,
fluid.default_main_program().current_block().
parent_idx) if cache is not None and static_kv else layers.reshape
transpose_layer = wrap_layer_with_block(
layers.transpose,
fluid.default_main_program().current_block().parent_idx
) if cache is not None and static_kv else layers.transpose
reshaped_k = reshape_layer(
x=keys, shape=[0, 0, n_head, d_key], inplace=True)
k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3])
reshaped_v = reshape_layer(
x=values, shape=[0, 0, n_head, d_value], inplace=True)
v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3])
if cache is not None: # only for faster inference
if static_kv: # For encoder-decoder attention in inference
cache_k, cache_v = cache["static_k"], cache["static_v"]
# To init the static_k and static_v in cache.
# Maybe we can use condition_op(if_else) to do these at the first
# step in while loop to replace these, however it might be less
# efficient.
static_cache_init = wrap_layer_with_block(
layers.assign,
fluid.default_main_program().current_block().parent_idx)
static_cache_init(k, cache_k)
static_cache_init(v, cache_v)
else: # For decoder self-attention in inference
cache_k, cache_v = cache["k"], cache["v"]
# gather cell states corresponding to selected parent
select_k = layers.gather(cache_k, index=gather_idx)
select_v = layers.gather(cache_v, index=gather_idx)
if not static_kv:
# For self attention in inference, use cache and concat time steps.
select_k = layers.concat([select_k, k], axis=2)
select_v = layers.concat([select_v, v], axis=2)
# update cell states(caches) cached in global block
layers.assign(select_k, cache_k)
layers.assign(select_v, cache_v)
return q, select_k, select_v
return q, k, v
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(
input=out, size=d_model, bias_attr=False, num_flatten_dims=2)
return proj_out
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(
input=x, size=d_inner_hid, num_flatten_dims=2, act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def prepare_encoder(src_word,
src_pos,
src_vocab_size,
src_phone,
src_phone_mask,
phone_vocab_size,
src_emb_dim,
src_max_len,
beta=0.0,
dropout_rate=0.,
bos_idx=0,
phone_pad_idx=-1,
word_emb_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
# shape [batch_size, max_seq_len, max_phone_len, dim]
src_phone_emb = layers.embedding(
src_phone,
size=[phone_vocab_size, src_emb_dim],
padding_idx=phone_pad_idx, # set embedding of phone_pad_idx to 0
param_attr=fluid.ParamAttr(
name=phone_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
sum_phone_emb = layers.reduce_sum(src_phone_emb, dim=2)
float_mask = layers.cast(src_phone_mask, dtype='float32')
sum_mask = layers.reduce_sum(float_mask, dim=2) + 1e-9
mean_phone_emb = layers.elementwise_div(sum_phone_emb, sum_mask, axis=0)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_names[0], trainable=False))
src_pos_enc.stop_gradient = True
enc_input = (
1 - beta) * src_word_emb + beta * mean_phone_emb + src_pos_enc
return layers.dropout(
enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
is_test=False) if dropout_rate else enc_input
def prepare_decoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
bos_idx=0,
word_emb_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_names[1], trainable=False))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
is_test=False) if dropout_rate else enc_input
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output = multi_head_attention(
pre_process_layer(enc_input, preprocess_cmd,
prepostprocess_dropout), None, None, attn_bias,
d_key, d_value, d_model, n_head, attention_dropout)
attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
d_inner_hid, d_model, relu_dropout)
return post_process_layer(attn_output, ffd_output, postprocess_cmd,
prepostprocess_dropout)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
)
enc_input = enc_output
enc_output = pre_process_layer(enc_output, preprocess_cmd,
prepostprocess_dropout)
return enc_output
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None,
gather_idx=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
None,
None,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx)
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
postprocess_cmd,
prepostprocess_dropout,
)
enc_attn_output = multi_head_attention(
pre_process_layer(slf_attn_output, preprocess_cmd,
prepostprocess_dropout),
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx,
static_kv=True)
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
postprocess_cmd,
prepostprocess_dropout,
)
ffd_output = positionwise_feed_forward(
pre_process_layer(enc_attn_output, preprocess_cmd,
prepostprocess_dropout),
d_inner_hid,
d_model,
relu_dropout,
)
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
)
return dec_output
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None,
gather_idx=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i],
gather_idx=gather_idx)
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
return dec_output
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def make_all_py_reader_inputs(input_fields, is_test=False):
reader = layers.py_reader(
capacity=20,
name="test_reader" if is_test else "train_reader",
shapes=[input_descs[input_field][0] for input_field in input_fields],
dtypes=[input_descs[input_field][1] for input_field in input_fields],
lod_levels=[
input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0
for input_field in input_fields
])
return layers.read_file(reader), reader
def transformer(src_vocab_size,
trg_vocab_size,
phone_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
label_smooth_eps,
beta,
bos_idx=0,
use_py_reader=False,
is_test=False):
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names,
is_test)
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
label = all_inputs[-2]
weights = all_inputs[-1]
enc_output = wrap_encoder(
src_vocab_size,
phone_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beta,
enc_inputs,
)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs,
enc_output,
)
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
if label_smooth_eps:
label = layers.label_smooth(
label=layers.one_hot(input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None
def wrap_encoder(src_vocab_size,
phone_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beta,
enc_inputs=None,
bos_idx=0):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
enc_inputs = make_all_inputs(encoder_data_input_fields)
src_word = enc_inputs[0]
src_pos = enc_inputs[1]
src_slf_attn_bias = enc_inputs[2]
src_phone = enc_inputs[3]
src_phone_mask = enc_inputs[4]
enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
src_phone,
src_phone_mask,
phone_vocab_size,
d_model,
max_length,
beta,
prepostprocess_dropout,
bos_idx=bos_idx,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
)
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=None,
enc_output=None,
caches=None,
gather_idx=None,
bos_idx=0):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
dec_inputs = make_all_inputs(decoder_data_input_fields)
trg_word = dec_inputs[0]
trg_pos = dec_inputs[1]
trg_slf_attn_bias = dec_inputs[2]
trg_src_attn_bias = dec_inputs[3]
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx,
word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1])
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches,
gather_idx=gather_idx)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
if weight_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
word_emb_param_names[0]),
transpose_y=True)
else:
predict = layers.fc(
input=dec_output, size=trg_vocab_size, bias_attr=False)
if dec_inputs is None:
# Return probs for independent decoder program.
predict = layers.softmax(predict)
return predict
def fast_decode(src_vocab_size,
trg_vocab_size,
phone_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beam_size,
max_out_len,
bos_idx,
eos_idx,
beta=0.0,
use_py_reader=False):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names)
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
enc_output = wrap_encoder(
src_vocab_size,
phone_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beta,
enc_inputs,
bos_idx=bos_idx)
start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs
def beam_search():
max_len = layers.fill_constant(
shape=[1],
dtype=start_tokens.dtype,
value=max_out_len,
force_cpu=True)
step_idx = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True)
cond = layers.less_than(
x=step_idx, y=max_len) # default force_cpu=True
while_op = layers.While(cond)
# array states will be stored for each step.
ids = layers.array_write(
layers.reshape(start_tokens, (-1, 1)), step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states will be overwrited at each step.
# caches contains states of history steps in decoder self-attention
# and static encoder output projections in encoder-decoder attention
# to reduce redundant computation.
caches = [
{
"k": # for self attention
layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, n_head, 0, d_key],
dtype=enc_output.dtype,
value=0),
"v": # for self attention
layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, n_head, 0, d_value],
dtype=enc_output.dtype,
value=0),
"static_k": # for encoder-decoder attention
layers.create_tensor(dtype=enc_output.dtype),
"static_v": # for encoder-decoder attention
layers.create_tensor(dtype=enc_output.dtype)
} for i in range(n_layer)
]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
# Since beam_search_op dosen't enforce pre_ids' shape, we can do
# inplace reshape here which actually change the shape of pre_ids.
pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
pre_scores = layers.array_read(array=scores, i=step_idx)
# gather cell states corresponding to selected parent
pre_src_attn_bias = layers.gather(
trg_src_attn_bias, index=parent_idx)
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_src_attn_bias, # cann't use lod tensor here
value=1,
shape=[-1, 1, 1],
dtype=pre_ids.dtype),
y=step_idx,
axis=0)
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=enc_output,
caches=caches,
gather_idx=parent_idx,
bos_idx=bos_idx)
# intra-beam topK
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add(
x=layers.log(topk_scores), y=pre_scores, axis=0)
# beam_search op uses lod to differentiate branches.
accu_scores = layers.lod_reset(accu_scores, pre_ids)
# topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction)
selected_ids, selected_scores, gather_idx = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_scores,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx,
return_parent_idx=True)
layers.increment(x=step_idx, value=1.0, in_place=True)
# cell states(caches) have been updated in wrap_decoder,
# only need to update beam search states here.
layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(gather_idx, parent_idx)
layers.assign(pre_src_attn_bias, trg_src_attn_bias)
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=beam_size, end_id=eos_idx)
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
return finished_ids, finished_scores, reader if use_py_reader else None
import glob
import six
import os
import random
import tarfile
import numpy as np
class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class SrcConverter(object):
def __init__(self, vocab, end, unk, delimiter, lexicon):
self._vocab = vocab
self._end = end
self._unk = unk
self._delimiter = delimiter
self._lexicon = lexicon
def __call__(self, sentence):
src_seqs = []
src_ph_seqs = []
unk_phs = self._lexicon['<unk>']
for w in sentence.split(self._delimiter):
src_seqs.append(self._vocab.get(w, self._unk))
ph_groups = self._lexicon.get(w, unk_phs)
src_ph_seqs.append(random.choice(ph_groups))
src_seqs.append(self._end)
src_ph_seqs.append(unk_phs[0])
return src_seqs, src_ph_seqs
class TgtConverter(object):
def __init__(self, vocab, beg, end, unk, delimiter):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
def __call__(self, sentence):
return [self._beg] + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + [self._end]
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, parallel_sentence):
return [
self._converters[i](parallel_sentence[i])
for i in range(len(self._converters))
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, max_len, min_len):
self.i = i
self.min_len = min_len
self.max_len = max_len
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class DataReader(object):
"""
The data reader loads all data from files and produces batches of data
in the way corresponding to settings.
An example of returning a generator producing data batches whose data
is shuffled in each pass and sorted in each pool:
```
train_data = DataReader(
src_vocab_fpath='data/src_vocab_file',
trg_vocab_fpath='data/trg_vocab_file',
fpattern='data/part-*',
use_token_batch=True,
batch_size=2000,
pool_size=10000,
sort_type=SortType.POOL,
shuffle=True,
shuffle_batch=True,
start_mark='<s>',
end_mark='<e>',
unk_mark='<unk>',
clip_last_batch=False).batch_generator
```
:param src_vocab_fpath: The path of vocabulary file of source language.
:type src_vocab_fpath: basestring
:param trg_vocab_fpath: The path of vocabulary file of target language.
:type trg_vocab_fpath: basestring
:param fpattern: The pattern to match data files.
:type fpattern: basestring
:param batch_size: The number of sequences contained in a mini-batch.
or the maximum number of tokens (include paddings) contained in a
mini-batch.
:type batch_size: int
:param pool_size: The size of pool buffer.
:type pool_size: int
:param sort_type: The grain to sort by length: 'global' for all
instances; 'pool' for instances in pool; 'none' for no sort.
:type sort_type: basestring
:param clip_last_batch: Whether to clip the last uncompleted batch.
:type clip_last_batch: bool
:param tar_fname: The data file in tar if fpattern matches a tar file.
:type tar_fname: basestring
:param min_length: The minimum length used to filt sequences.
:type min_length: int
:param max_length: The maximum length used to filt sequences.
:type max_length: int
:param shuffle: Whether to shuffle all instances.
:type shuffle: bool
:param shuffle_batch: Whether to shuffle the generated batches.
:type shuffle_batch: bool
:param use_token_batch: Whether to produce batch data according to
token number.
:type use_token_batch: bool
:param field_delimiter: The delimiter used to split source and target in
each line of data file.
:type field_delimiter: basestring
:param token_delimiter: The delimiter used to split tokens in source or
target sentences.
:type token_delimiter: basestring
:param start_mark: The token representing for the beginning of
sentences in dictionary.
:type start_mark: basestring
:param end_mark: The token representing for the end of sentences
in dictionary.
:type end_mark: basestring
:param unk_mark: The token representing for unknown word in dictionary.
:type unk_mark: basestring
:param seed: The seed for random.
:type seed: int
"""
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
phoneme_vocab_fpath,
lexicon_fpath,
batch_size,
pool_size,
sort_type=SortType.GLOBAL,
clip_last_batch=True,
tar_fname=None,
min_length=0,
max_length=100,
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
seed=0):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True
if trg_vocab_fpath is not None:
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._only_src = False
self._phoneme_vocab = self.load_dict(phoneme_vocab_fpath)
self._lexicon = self.load_lexicon(lexicon_fpath, self._phoneme_vocab)
self._pool_size = pool_size
self._batch_size = batch_size
self._use_token_batch = use_token_batch
self._sort_type = sort_type
self._clip_last_batch = clip_last_batch
self._shuffle = shuffle
self._shuffle_batch = shuffle_batch
self._min_length = min_length
self._max_length = max_length
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
unk_mark)
self._random = np.random
self._random.seed(seed)
def load_lexicon(self, lexicon_path, phoneme_vocab):
lexicon = {}
with open(lexicon_path) as fp:
for line in fp:
tokens = line.strip().split()
word = tokens[0]
all_phone_str = ' '.join(tokens[1:])
phone_strs = all_phone_str.split('|')
phone_groups = []
for phone_str in phone_strs:
cur_phone_seq = [
phoneme_vocab[x] for x in phone_str.split()
]
phone_groups.append(cur_phone_seq)
lexicon[word] = phone_groups
lexicon['<unk>'] = [[phoneme_vocab['<unk>']]]
return lexicon
def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
unk_mark):
converters = [
SrcConverter(
vocab=self._src_vocab,
end=self._src_vocab[end_mark],
unk=self._src_vocab[unk_mark],
delimiter=self._token_delimiter,
lexicon=self._lexicon)
]
if not self._only_src:
converters.append(
TgtConverter(
vocab=self._trg_vocab,
beg=self._trg_vocab[start_mark],
end=self._trg_vocab[end_mark],
unk=self._trg_vocab[unk_mark],
delimiter=self._token_delimiter))
converters = ComposedConverter(converters)
self._src_seq_ids = []
self._src_phone_ids = []
self._trg_seq_ids = None if self._only_src else []
self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
self._src_seq_ids.append(src_trg_ids[0][0])
self._src_phone_ids.append(src_trg_ids[0][1])
lens = [len(src_trg_ids[0][0])]
if not self._only_src:
self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname):
fpaths = glob.glob(fpattern)
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "r")
for line in f.extractfile(tar_fname):
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f:
if six.PY3:
line = line.decode()
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if six.PY3:
line = line.decode()
if reverse:
word_dict[idx] = line.strip("\n")
else:
word_dict[line.strip("\n")] = idx
return word_dict
def batch_generator(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._sample_infos, key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self._sample_infos
self._random.shuffle(infos)
else:
infos = self._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
# concat batch
batches = []
batch_creator = TokenBatchCreator(
self._batch_size
) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
for batch in batches:
batch_ids = [info.i for info in batch]
if self._only_src:
yield [[(self._src_seq_ids[idx], self._src_phone_ids[idx])]
for idx in batch_ids]
else:
yield [(self._src_seq_ids[idx], self._src_phone_ids[idx],
self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]) for idx in batch_ids]
import argparse
import ast
import copy
import logging
import multiprocessing
import os
import six
import sys
import time
import numpy as np
import paddle.fluid as fluid
import reader
from config import *
from desc import *
from model import transformer, position_encoding_init
def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--src_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of source language.")
parser.add_argument(
"--trg_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of target language.")
parser.add_argument(
"--phoneme_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of phonemes.")
parser.add_argument(
"--lexicon_fpath",
type=str,
required=True,
help="The path of lexicon of source language.")
parser.add_argument(
"--train_file_pattern",
type=str,
required=True,
help="The pattern to match training data files.")
parser.add_argument(
"--val_file_pattern",
type=str,
help="The pattern to match validation data files.")
parser.add_argument(
"--use_token_batch",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to "
"produce batch data according to token number.")
parser.add_argument(
"--batch_size",
type=int,
default=4096,
help="The number of sequences contained in a mini-batch, or the maximum "
"number of tokens (include paddings) contained in a mini-batch. Note "
"that this represents the number on single device and the actual batch "
"size for multi-devices will multiply the device number.")
parser.add_argument(
"--pool_size",
type=int,
default=200000,
help="The buffer size to pool data.")
parser.add_argument(
"--sort_type",
default="pool",
choices=("global", "pool", "none"),
help="The grain to sort by length: global for all instances; pool for "
"instances in pool; none for no sort.")
parser.add_argument(
"--shuffle",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to shuffle instances in each pass.")
parser.add_argument(
"--shuffle_batch",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
parser.add_argument(
'--local',
type=ast.literal_eval,
default=True,
help='Whether to run as local mode.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
'--update_method',
choices=("pserver", "nccl2"),
default="pserver",
help='Update method.')
parser.add_argument(
'--sync', type=ast.literal_eval, default=True, help="sync mode.")
parser.add_argument(
"--enable_ce",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to run the task "
"for continuous evaluation.")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use memory optimization.")
parser.add_argument(
"--use_py_reader",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--fetch_steps",
type=int,
default=100,
help="The frequency to fetch and print output.")
args = parser.parse_args()
# Append args related to dict
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath)
dict_args = [
"src_vocab_size",
str(len(src_dict)), "trg_vocab_size",
str(len(trg_dict)), "phone_vocab_size",
str(len(phone_dict)), "bos_idx",
str(src_dict[args.special_token[0]]), "eos_idx",
str(src_dict[args.special_token[1]]), "unk_idx",
str(src_dict[args.special_token[2]])
]
merge_cfg_from_list(args.opts + dict_args,
[TrainTaskConfig, ModelHyperParams])
return args
def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
current_endpoint):
assert (trainer_id >= 0 and len(worker_endpoints) > 1
and current_endpoint in worker_endpoints)
eps = copy.deepcopy(worker_endpoints)
eps.remove(current_endpoint)
nccl_id_var = startup_prog.global_block().create_var(
name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
startup_prog.global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": nccl_id_var},
attrs={
"endpoint": current_endpoint,
"endpoint_list": eps,
"trainer_id": trainer_id
})
return nccl_id_var
def pad_phoneme_data(phoneme_seqs, pad_idx, max_seq_len):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
ph_seq_lens = []
for ps in phoneme_seqs:
cur_seq_lens = [len(x) for x in ps]
ph_seq_lens.append(max(cur_seq_lens))
max_ph_seq_len = max(ph_seq_lens)
batch_size = len(phoneme_seqs)
phoneme_data = pad_idx * np.ones(
(batch_size, max_seq_len, max_ph_seq_len), dtype=np.int64)
phoneme_mask = np.zeros((batch_size, max_seq_len, max_ph_seq_len),
dtype=np.int64)
for i in range(batch_size):
cur_ph_seq = phoneme_seqs[i]
for j, cur_word_phs in enumerate(cur_ph_seq):
word_phs_len = len(cur_word_phs)
phoneme_data[i, j, :word_phs_len] = cur_word_phs
phoneme_mask[i, j, :word_phs_len] = 1
phoneme_data = np.reshape(phoneme_data, [batch_size, max_seq_len, -1, 1])
return phoneme_data, phoneme_mask, max_ph_seq_len
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len,
max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array(
[[0] * len(inst) + [-1e9] * (max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx,
trg_pad_idx, n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
src_phone, src_phone_mask, max_phone_len = pad_phoneme_data(
[inst[1] for inst in insts], phone_pad_idx, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[2] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[3] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask,
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
lbl_weight
]))
return data_input_dict, np.asarray([num_token], dtype="float32")
def prepare_data_generator(args,
is_test,
count,
pyreader,
py_reader_provider_wrapper,
place=None):
"""
Data generator wrapper for DataReader. If use py_reader, set the data
provider for py_reader
"""
data_reader = reader.DataReader(
phoneme_vocab_fpath=args.phoneme_vocab_fpath,
lexicon_fpath=args.lexicon_fpath,
fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else count),
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
# count start and end tokens out
max_length=ModelHyperParams.max_length - 2,
clip_last_batch=False).batch_generator
def stack(data_reader, count, clip_last=True):
def __impl__():
res = []
for item in data_reader():
res.append(item)
if len(res) == count:
yield res
res = []
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(count)
]
return __impl__
def split(data_reader, count):
def __impl__():
for item in data_reader():
inst_num_per_part = len(item) // count
for i in range(count):
yield item[inst_num_per_part * i:inst_num_per_part *
(i + 1)]
return __impl__
if not args.use_token_batch:
# to make data on each device have similar token number
data_reader = split(data_reader, count)
if args.use_py_reader:
pyreader.decorate_tensor_provider(
py_reader_provider_wrapper(data_reader, place))
data_reader = None
else: # Data generator for multi-devices
data_reader = stack(data_reader, count)
return data_reader
def prepare_feed_dict_list(data_generator, init_flag, count):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.phone_pad_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
feed_dict_list.append(data_input_dict)
if init_flag:
for idx in range(count):
pos_enc_tables = dict()
for pos_enc_param_name in pos_enc_param_names:
pos_enc_tables[pos_enc_param_name] = position_encoding_init(
ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
if len(feed_dict_list) <= idx:
feed_dict_list.append(pos_enc_tables)
else:
feed_dict_list[idx] = dict(
list(pos_enc_tables.items()) +
list(feed_dict_list[idx].items()))
return feed_dict_list if len(feed_dict_list) == count else None
def py_reader_provider_wrapper(data_reader, place):
"""
Data provider needed by fluid.layers.py_reader.
"""
def py_reader_provider():
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
for batch_id, data in enumerate(data_reader()):
data_input_dict, num_token = prepare_batch_input(
data, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.phone_pad_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
yield [data_input_dict[item] for item in data_input_names]
return py_reader_provider
def test_context(exe, train_exe, dev_count):
# Context to do validation.
test_prog = fluid.Program()
startup_prog = fluid.Program()
if args.enable_ce:
test_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard():
sum_cost, avg_cost, predict, token_num, pyreader = transformer(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
beta=ModelHyperParams.beta,
is_test=True)
test_prog = test_prog.clone(for_test=True)
test_data = prepare_data_generator(
args,
is_test=True,
count=dev_count,
pyreader=pyreader,
py_reader_provider_wrapper=py_reader_provider_wrapper)
exe.run(startup_prog) # to init pyreader for testing
if TrainTaskConfig.ckpt_path:
fluid.io.load_persistables(
exe, TrainTaskConfig.ckpt_path, main_program=test_prog)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
test_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=test_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy,
share_vars_from=train_exe)
def test(exe=test_exe, pyreader=pyreader):
test_total_cost = 0
test_total_token = 0
if args.use_py_reader:
pyreader.start()
data_generator = None
else:
data_generator = test_data()
while True:
try:
feed_dict_list = prepare_feed_dict_list(
data_generator, False, dev_count)
outs = test_exe.run(
fetch_list=[sum_cost.name, token_num.name],
feed=feed_dict_list)
except (StopIteration, fluid.core.EOFException):
# The current pass is over.
if args.use_py_reader:
pyreader.reset()
break
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
test_total_cost += sum_cost_val.sum()
test_total_token += token_num_val.sum()
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
return test
def train_loop(exe,
train_prog,
startup_prog,
dev_count,
sum_cost,
avg_cost,
token_num,
predict,
pyreader,
nccl2_num_trainers=1,
nccl2_trainer_id=0):
# Initialize the parameters.
if TrainTaskConfig.ckpt_path:
exe.run(startup_prog) # to init pyreader for training
logging.info("load checkpoint from {}".format(
TrainTaskConfig.ckpt_path))
fluid.io.load_persistables(
exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
else:
logging.info("init fluid.framework.default_startup_program")
exe.run(startup_prog)
logging.info("begin reader")
train_data = prepare_data_generator(
args,
is_test=False,
count=dev_count,
pyreader=pyreader,
py_reader_provider_wrapper=py_reader_provider_wrapper)
# For faster executor
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_iteration_per_drop_scope = int(args.fetch_steps)
build_strategy = fluid.BuildStrategy()
# Since the token number differs among devices, customize gradient scale to
# use token average cost among multi-devices. and the gradient scale is
# `1 / token_number` for average cost.
# build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
logging.info("begin executor")
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
loss_name=avg_cost.name,
main_program=train_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
if args.val_file_pattern is not None:
test = test_context(exe, train_exe, dev_count)
# the best cross-entropy value with label smoothing
loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
(1. - TrainTaskConfig.label_smooth_eps)) +
TrainTaskConfig.label_smooth_eps *
np.log(TrainTaskConfig.label_smooth_eps /
(ModelHyperParams.trg_vocab_size - 1) + 1e-20))
step_idx = 0
init_flag = True
logging.info("begin train")
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
if args.use_py_reader:
pyreader.start()
data_generator = None
else:
data_generator = train_data()
batch_id = 0
while True:
try:
feed_dict_list = prepare_feed_dict_list(
data_generator, init_flag, dev_count)
outs = train_exe.run(
fetch_list=[sum_cost.name, token_num.name]
if step_idx % args.fetch_steps == 0 else [],
feed=feed_dict_list)
if step_idx % args.fetch_steps == 0:
sum_cost_val, token_num_val = np.array(outs[0]), np.array(
outs[1])
# sum the cost from multi-devices
total_sum_cost = sum_cost_val.sum()
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s"
% (step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)
]), args.fetch_steps /
(time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
"latest.checkpoint"), train_prog)
fluid.io.save_params(
exe,
os.path.join(TrainTaskConfig.model_dir,
"iter_" + str(step_idx) + ".infer.model"),
train_prog)
init_flag = False
batch_id += 1
step_idx += 1
except (StopIteration, fluid.core.EOFException):
# The current pass is over.
if args.use_py_reader:
pyreader.reset()
break
time_consumed = time.time() - pass_start_time
# Validate and save the persistable.
if args.val_file_pattern is not None:
val_avg_cost, val_ppl = test()
logging.info(
"epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
" consumed %fs" % (pass_id, val_avg_cost, val_avg_cost -
loss_normalizer, val_ppl, time_consumed))
else:
logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
if not args.enable_ce:
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"),
train_prog)
if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
if args.val_file_pattern is not None:
print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
def train(args):
# priority: ENV > args > config
is_local = os.getenv("PADDLE_IS_LOCAL", "1")
if is_local == '0':
args.local = False
logging.info(args)
if args.device == 'CPU':
TrainTaskConfig.use_gpu = False
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
exe = fluid.Executor(place)
train_prog = fluid.Program()
startup_prog = fluid.Program()
if args.enable_ce:
train_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
sum_cost, avg_cost, predict, token_num, pyreader = transformer(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.phone_vocab_size,
ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
TrainTaskConfig.label_smooth_eps,
ModelHyperParams.beta,
ModelHyperParams.bos_idx,
use_py_reader=args.use_py_reader,
is_test=False)
optimizer = None
if args.sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
logging.info("before adam")
with fluid.default_main_program()._lr_schedule_guard():
learning_rate = lr_decay * TrainTaskConfig.learning_rate
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
else:
optimizer = fluid.optimizer.SGD(0.003)
optimizer.minimize(avg_cost)
if args.use_mem_opt:
fluid.memory_optimize(train_prog)
if args.local:
logging.info("local start_up:")
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader)
else:
if args.update_method == "nccl2":
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
port = os.getenv("PADDLE_PORT")
worker_ips = os.getenv("PADDLE_TRAINERS")
worker_endpoints = []
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
trainers_num = len(worker_endpoints)
current_endpoint = os.getenv("POD_IP") + ":" + port
if trainer_id == 0:
logging.info("train_id == 0, sleep 60s")
time.sleep(60)
logging.info("trainers_num:{}".format(trainers_num))
logging.info("worker_endpoints:{}".format(worker_endpoints))
logging.info("current_endpoint:{}".format(current_endpoint))
append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
current_endpoint)
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader, trainers_num,
trainer_id)
return
port = os.getenv("PADDLE_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVERS") # ip,ip...
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
logging.info("pserver_endpoints:{}".format(pserver_endpoints))
logging.info("current_endpoint:{}".format(current_endpoint))
logging.info("trainer_id:{}".format(trainer_id))
logging.info("pserver_ips:{}".format(pserver_ips))
logging.info("port:{}".format(port))
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
program=train_prog,
startup_program=startup_prog)
if training_role == "PSERVER":
logging.info("distributed: pserver started")
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT")
if not current_endpoint:
logging.critical("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
logging.info("distributed: trainer started")
trainer_prog = t.get_trainer_program()
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader)
else:
logging.critical(
"environment var TRAINER_ROLE should be TRAINER os PSERVER")
exit(1)
if __name__ == "__main__":
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(
stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
logging.getLogger().setLevel(logging.INFO)
args = parse_args()
train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册