提交 2a0991e1 编写于 作者: G guosheng

Transformer with hapi

上级 71075698
## Transformer
以下是本例的简要目录结构及说明:
```text
.
├── images # README 文档中的图片
├── utils # 工具包
├── gen_data.sh # 数据生成脚本
├── predict.py # 预测脚本
├── reader.py # 数据读取接口
├── README.md # 文档
├── train.py # 训练脚本
├── model.py # 模型定义文件
└── transformer.yaml # 配置文件
```
## 模型简介
机器翻译(machine translation, MT)是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程,输入为源语言句子,输出为相应的目标语言的句子。
本项目是机器翻译领域主流模型 Transformer 的 PaddlePaddle 实现, 包含模型训练,预测以及使用自定义数据等内容。用户可以基于发布的内容搭建自己的翻译模型。
## 快速开始
### 安装说明
1. paddle安装
本项目依赖于 PaddlePaddle 1.7及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 下载代码
克隆代码库到本地
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/dygraph/transformer
```
3. 环境依赖
请参考PaddlePaddle[安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.6/beginners_guide/install/index_cn.html)部分的内容
### 数据准备
公开数据集:WMT 翻译大赛是机器翻译领域最具权威的国际评测大赛,其中英德翻译任务提供了一个中等规模的数据集,这个数据集是较多论文中使用的数据集,也是 Transformer 论文中用到的一个数据集。我们也将[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)作为示例提供。运行 `gen_data.sh` 脚本进行 WMT'16 EN-DE 数据集的下载和预处理(时间较长,建议后台运行)。数据处理过程主要包括 Tokenize 和 [BPE 编码(byte-pair encoding)](https://arxiv.org/pdf/1508.07909)。运行成功后,将会生成文件夹 `gen_data`,其目录结构如下:
```text
.
├── wmt16_ende_data # WMT16 英德翻译数据
├── wmt16_ende_data_bpe # BPE 编码的 WMT16 英德翻译数据
├── mosesdecoder # Moses 机器翻译工具集,包含了 Tokenize、BLEU 评估等脚本
└── subword-nmt # BPE 编码的代码
```
另外我们也整理提供了一份处理好的 WMT'16 EN-DE 数据以供[下载](https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz)使用,其中包含词典(`vocab_all.bpe.32000`文件)、训练所需的 BPE 数据(`train.tok.clean.bpe.32000.en-de`文件)、预测所需的 BPE 数据(`newstest2016.tok.bpe.32000.en-de`等文件)和相应的评估预测结果所需的 tokenize 数据(`newstest2016.tok.de`等文件)。
自定义数据:如果需要使用自定义数据,本项目程序中可直接支持的数据格式为制表符 \t 分隔的源语言和目标语言句子对,句子中的 token 之间使用空格分隔。提供以上格式的数据文件(可以分多个part,数据读取支持文件通配符)和相应的词典文件即可直接运行。
### 单机训练
### 单机单卡
以提供的英德翻译数据为例,可以执行以下命令进行模型训练:
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096
```
以上命令中传入了训练轮数(`epoch`)和训练数据文件路径(注意请正确设置,支持通配符)等参数,更多参数的使用以及支持的模型超参数可以参见 `transformer.yaml` 配置文件,其中默认提供了 Transformer base model 的配置,如需调整可以在配置文件中更改或通过命令行传入(命令行传入内容将覆盖配置文件中的设置)。可以通过以下命令来训练 Transformer 论文中的 big model:
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
另外,如果在执行训练时若提供了 `save_model`(默认为 trained_models),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将保存当前训练的到相应目录(会保存分别记录了模型参数和优化器状态的 `transformer.pdparams``transformer.pdopt` 两个文件),每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出:
```txt
[2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s
[2019-08-02 15:31:19,824 INFO train.py:262] step_idx: 150200, epoch: 32, batch: 1464, avg loss: 2.955965, normalized loss: 1.580225, ppl: 19.220257, speed: 3.55 step/s
[2019-08-02 15:31:48,151 INFO train.py:262] step_idx: 150300, epoch: 32, batch: 1564, avg loss: 2.951180, normalized loss: 1.575439, ppl: 19.128502, speed: 3.53 step/s
[2019-08-02 15:32:16,401 INFO train.py:262] step_idx: 150400, epoch: 32, batch: 1664, avg loss: 3.027281, normalized loss: 1.651540, ppl: 20.641024, speed: 3.54 step/s
[2019-08-02 15:32:44,764 INFO train.py:262] step_idx: 150500, epoch: 32, batch: 1764, avg loss: 3.069125, normalized loss: 1.693385, ppl: 21.523066, speed: 3.53 step/s
[2019-08-02 15:33:13,199 INFO train.py:262] step_idx: 150600, epoch: 32, batch: 1864, avg loss: 2.869379, normalized loss: 1.493639, ppl: 17.626074, speed: 3.52 step/s
[2019-08-02 15:33:41,601 INFO train.py:262] step_idx: 150700, epoch: 32, batch: 1964, avg loss: 2.980905, normalized loss: 1.605164, ppl: 19.705633, speed: 3.52 step/s
[2019-08-02 15:34:10,079 INFO train.py:262] step_idx: 150800, epoch: 32, batch: 2064, avg loss: 3.047716, normalized loss: 1.671976, ppl: 21.067181, speed: 3.51 step/s
[2019-08-02 15:34:38,598 INFO train.py:262] step_idx: 150900, epoch: 32, batch: 2164, avg loss: 2.956475, normalized loss: 1.580735, ppl: 19.230072, speed: 3.51 step/s
```
也可以使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度较慢。
#### 单机多卡
Paddle动态图支持多进程多卡进行模型训练,启动训练的方式如下:
```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 100 \
--use_cuda True \
--save_step 10000
```
此时,程序会将每个进程的输出log导入到`./mylog`路径下,只有第一个工作进程会保存模型。
```
.
├── mylog
│   ├── workerlog.0
│   ├── workerlog.1
│   ├── workerlog.2
│   ├── workerlog.3
│   ├── workerlog.4
│   ├── workerlog.5
│   ├── workerlog.6
│   └── workerlog.7
```
### 模型推断
以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt
```
`predict_file` 指定的文件中文本的翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `transformer.yaml` 文件中查阅注释说明并进行更改设置。注意若在执行预测时设置了模型超参数,应与模型训练时的设置一致,如若训练时使用 big model 的参数设置,则预测时对应类似如下命令:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
### 模型评估
预测结果中每行输出是对应行输入的得分最高的翻译,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估。评估过程具体如下(BLEU 是翻译任务常用的自动评估方法指标):
```sh
# 还原 predict.txt 中的预测结果为 tokenize 后的数据
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
# 若无 BLEU 评估工具,需先进行下载
# git clone https://github.com/moses-smt/mosesdecoder.git
# 以英德翻译 newstest2014 测试数据为例
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2014.tok.de < predict.tok.txt
```
可以看到类似如下的结果:
```
BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
```
使用本项目中提供的内容,英德翻译 base model 和 big model 八卡训练 100K 个 iteration 后测试有大约如下的 BLEU 值:
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| Base | 26.35 | 29.07 | 33.30 |
| Big | 27.07 | 30.09 | 34.38 |
### 预训练模型
我们这里提供了对应有以上 BLEU 值的 [base model](https://transformer-res.bj.bcebos.com/base_model_dygraph.tar.gz)[big model](https://transformer-res.bj.bcebos.com/big_model_dygraph.tar.gz) 的模型参数提供下载使用(注意,模型使用了提供下载的数据进行训练和测试)。
## 进阶使用
### 背景介绍
Transformer 是论文 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 中提出的用以完成机器翻译(machine translation, MT)等序列到序列(sequence to sequence, Seq2Seq)学习任务的一种全新网络结构,其完全使用注意力(Attention)机制来实现序列到序列的建模[1]。
相较于此前 Seq2Seq 模型中广泛使用的循环神经网络(Recurrent Neural Network, RNN),使用(Self)Attention 进行输入序列到输出序列的变换主要具有以下优势:
- 计算复杂度小
- 特征维度为 d 、长度为 n 的序列,在 RNN 中计算复杂度为 `O(n * d * d)` (n 个时间步,每个时间步计算 d 维的矩阵向量乘法),在 Self-Attention 中计算复杂度为 `O(n * n * d)` (n 个时间步两两计算 d 维的向量点积或其他相关度函数),n 通常要小于 d 。
- 计算并行度高
- RNN 中当前时间步的计算要依赖前一个时间步的计算结果;Self-Attention 中各时间步的计算只依赖输入不依赖之前时间步输出,各时间步可以完全并行。
- 容易学习长程依赖(long-range dependencies)
- RNN 中相距为 n 的两个位置间的关联需要 n 步才能建立;Self-Attention 中任何两个位置都直接相连;路径越短信号传播越容易。
Transformer 中引入使用的基于 Self-Attention 的序列建模模块结构,已被广泛应用在 Bert [2]等语义表示模型中,取得了显著效果。
### 模型概览
Transformer 同样使用了 Seq2Seq 模型中典型的编码器-解码器(Encoder-Decoder)的框架结构,整体网络结构如图1所示。
<p align="center">
<img src="images/transformer_network.png" height=400 hspace='10'/> <br />
图 1. Transformer 网络结构图
</p>
可以看到,和以往 Seq2Seq 模型不同,Transformer 的 Encoder 和 Decoder 中不再使用 RNN 的结构。
### 模型特点
Transformer 中的 Encoder 由若干相同的 layer 堆叠组成,每个 layer 主要由多头注意力(Multi-Head Attention)和全连接的前馈(Feed-Forward)网络这两个 sub-layer 构成。
- Multi-Head Attention 在这里用于实现 Self-Attention,相比于简单的 Attention 机制,其将输入进行多路线性变换后分别计算 Attention 的结果,并将所有结果拼接后再次进行线性变换作为输出。参见图2,其中 Attention 使用的是点积(Dot-Product),并在点积后进行了 scale 的处理以避免因点积结果过大进入 softmax 的饱和区域。
- Feed-Forward 网络会对序列中的每个位置进行相同的计算(Position-wise),其采用的是两次线性变换中间加以 ReLU 激活的结构。
此外,每个 sub-layer 后还施以 Residual Connection [3]和 Layer Normalization [4]来促进梯度传播和模型收敛。
<p align="center">
<img src="images/multi_head_attention.png" height=300 hspace='10'/> <br />
图 2. Multi-Head Attention
</p>
Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 layer ,在组成 Decoder 的 layer 中还多了一个 Multi-Head Attention 的 sub-layer 来实现对 Encoder 输出的 Attention,这个 Encoder-Decoder Attention 在其他 Seq2Seq 模型中也是存在的。
## FAQ
**Q:** 预测结果中样本数少于输入的样本数是什么原因
**A:** 若样本中最大长度超过 `transformer.yaml``max_length` 的默认设置,请注意运行时增大 `--max_length` 的设置,否则超长样本将被过滤。
**Q:** 预测时最大长度超过了训练时的最大长度怎么办
**A:** 由于训练时 `max_length` 的设置决定了保存模型 position encoding 的大小,若预测时长度超过 `max_length`,请调大该值,会重新生成更大的 position encoding 表。
## 参考文献
1. Vaswani A, Shazeer N, Parmar N, et al. [Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)[C]//Advances in Neural Information Processing Systems. 2017: 6000-6010.
2. Devlin J, Chang M W, Lee K, et al. [Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805)[J]. arXiv preprint arXiv:1810.04805, 2018.
3. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
4. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
5. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
## 作者
- [guochengCS](https://github.com/guoshengCS)
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import reader
from model import Transformer, position_encoding_init
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def do_predict(args):
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
# define the data generator
processor = reader.DataProcessor(fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
device_count=1,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="predict", place=place)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
trg_idx2word = reader.DataProcessor.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
with fluid.dygraph.guard(place):
# define data loader
test_loader = fluid.io.DataLoader.from_generator(capacity=10)
test_loader.set_batch_generator(batch_generator, places=place)
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout,
args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
args.postprocess_cmd, args.weight_sharing, args.bos_idx,
args.eos_idx)
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
model_dict, _ = fluid.load_dygraph(
os.path.join(args.init_from_params, "transformer"))
# to avoid a longer length than training, reset the size of position
# encoding to max_length
model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
args.max_length + 1, args.d_model)
model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
args.max_length + 1, args.d_model)
transformer.load_dict(model_dict)
# set evaluate mode
transformer.eval()
f = open(args.output_file, "wb")
for input_data in test_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_src_attn_bias) = input_data
finished_seq, finished_scores = transformer.beam_search(
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=args.bos_idx,
eos_id=args.eos_idx,
beam_size=args.beam_size,
max_len=args.max_out_len)
finished_seq = finished_seq.numpy()
finished_scores = finished_scores.numpy()
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_predict(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import six
import os
import tarfile
import numpy as np
import paddle.fluid as fluid
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Put all padded data needed by training into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]
return data_inputs
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
]
return data_inputs
class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + [self._end]
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, parallel_sentence):
return [
self._converters[i](parallel_sentence[i])
for i in range(len(self._converters))
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, max_len, min_len):
self.i = i
self.min_len = min_len
self.max_len = max_len
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class DataProcessor(object):
"""
The data reader loads all data from files and produces batches of data
in the way corresponding to settings.
An example of returning a generator producing data batches whose data
is shuffled in each pass and sorted in each pool:
```
train_data = DataProcessor(
src_vocab_fpath='data/src_vocab_file',
trg_vocab_fpath='data/trg_vocab_file',
fpattern='data/part-*',
use_token_batch=True,
batch_size=2000,
device_count=8,
n_head=8,
pool_size=10000,
sort_type=SortType.POOL,
shuffle=True,
shuffle_batch=True,
start_mark='<s>',
end_mark='<e>',
unk_mark='<unk>',
clip_last_batch=False).data_generator(phase='train')
```
:param src_vocab_fpath: The path of vocabulary file of source language.
:type src_vocab_fpath: basestring
:param trg_vocab_fpath: The path of vocabulary file of target language.
:type trg_vocab_fpath: basestring
:param fpattern: The pattern to match data files.
:type fpattern: basestring
:param batch_size: The number of sequences contained in a mini-batch.
or the maximum number of tokens (include paddings) contained in a
mini-batch.
:type batch_size: int
:param pool_size: The size of pool buffer.
:type device_count: int
:param device_count: The number of devices. The actual batch size is
determined by both batch_size and device_count.
:type n_head: int
:param n_head: The number of head used in multi-head attention. Actually,
this is not a reader related argument, but is used for input data.
:type pool_size: int
:param sort_type: The grain to sort by length: 'global' for all
instances; 'pool' for instances in pool; 'none' for no sort.
:type sort_type: basestring
:param clip_last_batch: Whether to clip the last uncompleted batch.
:type clip_last_batch: bool
:param tar_fname: The data file in tar if fpattern matches a tar file.
:type tar_fname: basestring
:param min_length: The minimum length used to filt sequences.
:type min_length: int
:param max_length: The maximum length used to filt sequences.
:type max_length: int
:param shuffle: Whether to shuffle all instances.
:type shuffle: bool
:param shuffle_batch: Whether to shuffle the generated batches.
:type shuffle_batch: bool
:param use_token_batch: Whether to produce batch data according to
token number.
:type use_token_batch: bool
:param field_delimiter: The delimiter used to split source and target in
each line of data file.
:type field_delimiter: basestring
:param token_delimiter: The delimiter used to split tokens in source or
target sentences.
:type token_delimiter: basestring
:param start_mark: The token representing for the beginning of
sentences in dictionary.
:type start_mark: basestring
:param end_mark: The token representing for the end of sentences
in dictionary.
:type end_mark: basestring
:param unk_mark: The token representing for unknown word in dictionary.
:type unk_mark: basestring
:param only_src: Whether each line is a source and target sentence
pair or only has the source sentence.
:type only_src: bool
:param seed: The seed for random.
:type seed: int
"""
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
batch_size,
device_count,
n_head,
pool_size,
sort_type=SortType.GLOBAL,
clip_last_batch=False,
tar_fname=None,
min_length=0,
max_length=100,
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
only_src=False,
seed=0):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._pool_size = pool_size
self._batch_size = batch_size
self._device_count = device_count
self._n_head = n_head
self._use_token_batch = use_token_batch
self._sort_type = sort_type
self._clip_last_batch = clip_last_batch
self._shuffle = shuffle
self._shuffle_batch = shuffle_batch
self._min_length = min_length
self._max_length = max_length
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname)
self._random = np.random
self._random.seed(seed)
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
converters = ComposedConverter(converters)
self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else []
self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
self._src_seq_ids.append(src_trg_ids[0])
lens = [len(src_trg_ids[0])]
if not self._only_src:
self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname):
fpaths = glob.glob(fpattern)
assert len(fpaths) > 0, "no matching file to the provided data path"
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f:
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(b"\n")
else:
word_dict[line.strip(b"\n")] = idx
return word_dict
def batch_generator(self, batch_size, use_token_batch):
def __impl__():
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._sample_infos, key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self._sample_infos
self._random.shuffle(infos)
else:
infos = self._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
# concat batch
batches = []
batch_creator = TokenBatchCreator(
batch_size) if use_token_batch else SentenceBatchCreator(
batch_size)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
for batch in batches:
batch_ids = [info.i for info in batch]
if self._only_src:
yield [[self._src_seq_ids[idx]] for idx in batch_ids]
else:
yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]) for idx in batch_ids]
return __impl__
@staticmethod
def stack(data_reader, count, clip_last=True):
def __impl__():
res = []
for item in data_reader():
res.append(item)
if len(res) == count:
yield res
res = []
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(count)
]
return __impl__
@staticmethod
def split(data_reader, count):
def __impl__():
for item in data_reader():
inst_num_per_part = len(item) // count
for i in range(count):
yield item[inst_num_per_part * i:inst_num_per_part *
(i + 1)]
return __impl__
def data_generator(self, phase, place=None):
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
src_pad_idx = trg_pad_idx = self._eos_idx
bos_idx = self._bos_idx
n_head = self._n_head
data_reader = self.batch_generator(
self._batch_size *
(1 if self._use_token_batch else self._device_count),
self._use_token_batch)
if not self._use_token_batch:
# to make data on each device have similar token number
data_reader = self.split(data_reader, self._device_count)
def __for_train__():
for data in data_reader():
data_inputs = prepare_train_input(data, src_pad_idx,
trg_pad_idx, n_head)
yield data_inputs
def __for_predict__():
for data in data_reader():
data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
n_head, place)
yield data_inputs
return __for_train__ if phase == "train" else __for_predict__
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import contextlib
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import reader
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
def do_train(args):
device_ids = list(range(args.num_devices))
@contextlib.contextmanager
def null_guard():
yield
guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator
processor = reader.DataProcessor(fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=args.num_devices,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="train")
if args.validation_file:
val_processor = reader.DataProcessor(
fpattern=args.validation_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=args.num_devices,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
val_batch_generator = val_processor.data_generator(phase="train")
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
with guard:
# set seed for CE
random_seed = eval(str(args.random_seed))
if random_seed is not None:
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
# define data loader
train_loader = batch_generator
if args.validation_file:
val_loader = val_batch_generator
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout,
args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
args.postprocess_cmd, args.weight_sharing, args.bos_idx,
args.eos_idx)
transformer.prepare(
fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(
args.d_model, args.warmup_steps), # args.learning_rate),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps))
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
transformer.load(
os.path.join(args.init_from_checkpoint, "transformer"))
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
transformer.load(
os.path.join(args.init_from_pretrain_model, "transformer"))
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) +
args.label_smooth_eps * np.log(args.label_smooth_eps /
(args.trg_vocab_size - 1) + 1e-20))
step_idx = 0
# train loop
for pass_id in range(args.epoch):
pass_start_time = time.time()
batch_id = 0
for input_data in train_loader():
outputs, losses = transformer.train(input_data[:-2],
input_data[-2:],
device='gpu',
device_ids=device_ids)
if step_idx % args.print_step == 0:
total_avg_cost = np.sum(losses)
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0:
# validation: how to accumulate with Model loss
if args.validation_file:
total_avg_cost = 0
for idx, input_data in enumerate(val_loader()):
outputs, losses = transformer.eval(
input_data[:-2],
input_data[-2:],
device='gpu',
device_ids=device_ids)
total_avg_cost += np.sum(losses)
total_avg_cost /= idx + 1
logging.info("validation, step_idx: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
transformer.save(
os.path.join(args.save_model,
"step_" + str(step_idx),
"transformer"))
batch_id += 1
step_idx += 1
time_consumed = time.time() - pass_start_time
if args.save_model:
transformer.save(
os.path.join(args.save_model, "step_final", "transformer"))
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_train(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
from model import Model, shape_hints, CrossEntropy, Loss
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(
np.arange(num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(
inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
class NoamDecay(LearningRateDecay):
"""
learning rate scheduler
"""
def __init__(self,
d_model,
warmup_steps,
static_lr=2.0,
begin=1,
step=1,
dtype='float32'):
super(NoamDecay, self).__init__(begin, step, dtype)
self.d_model = d_model
self.warmup_steps = warmup_steps
self.static_lr = static_lr
def step(self):
a = self.create_lr_var(self.step_num**-0.5)
b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
lr_value = (self.d_model**-0.5) * layers.elementwise_min(
a, b) * self.static_lr
return lr_value
class PrePostProcessLayer(Layer):
"""
PrePostProcessLayer
"""
def __init__(self, process_cmd, d_model, dropout_rate):
super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd
self.functors = []
for cmd in self.process_cmd:
if cmd == "a": # add residual connection
self.functors.append(lambda x, y: x + y if y else x)
elif cmd == "n": # add layer normalization
self.functors.append(
self.add_sublayer(
"layer_norm_%d" %
len(self.sublayers(include_sublayers=False)),
LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))))
elif cmd == "d": # add dropout
self.functors.append(lambda x: layers.dropout(
x, dropout_prob=dropout_rate, is_test=False)
if dropout_rate else x)
def forward(self, x, residual=None):
for i, cmd in enumerate(self.process_cmd):
if cmd == "a":
x = self.functors[i](x, residual)
else:
x = self.functors[i](x)
return x
class MultiHeadAttention(Layer):
"""
Multi-Head Attention
"""
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.k_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.v_fc = Linear(input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
self.proj_fc = Linear(input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
def _prepare_qkv(self, queries, keys, values, cache=None):
if keys is None: # self-attention
keys, values = queries, queries
static_kv = False
else: # cross-attention
static_kv = True
q = self.q_fc(queries)
q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
q = layers.transpose(x=q, perm=[0, 2, 1, 3])
if cache is not None and static_kv and "static_k" in cache:
# for encoder-decoder attention in inference and has cached
k = cache["static_k"]
v = cache["static_v"]
else:
k = self.k_fc(keys)
v = self.v_fc(values)
k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
k = layers.transpose(x=k, perm=[0, 2, 1, 3])
v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
if cache is not None:
if static_kv and not "static_k" in cache:
# for encoder-decoder attention in inference and has not cached
cache["static_k"], cache["static_v"] = k, v
elif not static_kv:
# for decoder self-attention in inference
cache_k, cache_v = cache["k"], cache["v"]
k = layers.concat([cache_k, k], axis=2)
v = layers.concat([cache_v, v], axis=2)
cache["k"], cache["v"] = k, v
return q, k, v
def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v
q, k, v = self._prepare_qkv(queries, keys, values, cache)
# scale dot product attention
product = layers.matmul(x=q,
y=k,
transpose_y=True,
alpha=self.d_model**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if self.dropout_rate:
weights = layers.dropout(weights,
dropout_prob=self.dropout_rate,
is_test=False)
out = layers.matmul(weights, v)
# combine heads
out = layers.transpose(out, perm=[0, 2, 1, 3])
out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.proj_fc(out)
return out
class FFN(Layer):
"""
Feed-Forward Network
"""
def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
self.fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu")
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
def forward(self, x):
hidden = self.fc1(x)
if self.dropout_rate:
hidden = layers.dropout(hidden,
dropout_prob=self.dropout_rate,
is_test=False)
out = self.fc2(hidden)
return out
class EncoderLayer(Layer):
"""
EncoderLayer
"""
def __init__(self,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
attn_output = self.self_attn(self.preprocesser1(enc_input), None, None,
attn_bias)
attn_output = self.postprocesser1(attn_output, enc_input)
ffn_output = self.ffn(self.preprocesser2(attn_output))
ffn_output = self.postprocesser2(ffn_output, attn_output)
return ffn_output
class Encoder(Layer):
"""
encoder
"""
def __init__(self,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(Encoder, self).__init__()
self.encoder_layers = list()
for i in range(n_layer):
self.encoder_layers.append(
self.add_sublayer(
"layer_%d" % i,
EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)))
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
for encoder_layer in self.encoder_layers:
enc_output = encoder_layer(enc_input, attn_bias)
enc_input = enc_output
return self.processer(enc_output)
class Embedder(Layer):
"""
Word Embedding + Position Encoding
"""
def __init__(self, vocab_size, emb_dim, bos_idx=0):
super(Embedder, self).__init__()
self.word_embedder = Embedding(
size=[vocab_size, emb_dim],
padding_idx=bos_idx,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Normal(0., emb_dim**-0.5)))
def forward(self, word):
word_emb = self.word_embedder(word)
return word_emb
class WrapEncoder(Layer):
"""
embedder + encoder
"""
def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, word_embedder):
super(WrapEncoder, self).__init__()
self.emb_dropout = prepostprocess_dropout
self.emb_dim = d_model
self.word_embedder = word_embedder
self.pos_encoder = Embedding(
size=[max_length, self.emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
position_encoding_init(max_length, self.emb_dim)),
trainable=False))
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd)
def forward(self, src_word, src_pos, src_slf_attn_bias):
word_emb = self.word_embedder(src_word)
word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
pos_enc = self.pos_encoder(src_pos)
pos_enc.stop_gradient = True
emb = word_emb + pos_enc
enc_input = layers.dropout(emb,
dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
enc_output = self.encoder(enc_input, src_slf_attn_bias)
return enc_output
class DecoderLayer(Layer):
"""
decoder
"""
def __init__(self,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(DecoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
def forward(self,
dec_input,
enc_output,
self_attn_bias,
cross_attn_bias,
cache=None):
self_attn_output = self.self_attn(self.preprocesser1(dec_input), None,
None, self_attn_bias, cache)
self_attn_output = self.postprocesser1(self_attn_output, dec_input)
cross_attn_output = self.cross_attn(
self.preprocesser2(self_attn_output), enc_output, enc_output,
cross_attn_bias, cache)
cross_attn_output = self.postprocesser2(cross_attn_output,
self_attn_output)
ffn_output = self.ffn(self.preprocesser3(cross_attn_output))
ffn_output = self.postprocesser3(ffn_output, cross_attn_output)
return ffn_output
class Decoder(Layer):
"""
decoder
"""
def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd):
super(Decoder, self).__init__()
self.decoder_layers = list()
for i in range(n_layer):
self.decoder_layers.append(
self.add_sublayer(
"layer_%d" % i,
DecoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)))
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
def forward(self,
dec_input,
enc_output,
self_attn_bias,
cross_attn_bias,
caches=None):
for i, decoder_layer in enumerate(self.decoder_layers):
dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
cross_attn_bias,
None if caches is None else caches[i])
dec_input = dec_output
return self.processer(dec_output)
class WrapDecoder(Layer):
"""
embedder + decoder
"""
def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, share_input_output_embed, word_embedder):
super(WrapDecoder, self).__init__()
self.emb_dropout = prepostprocess_dropout
self.emb_dim = d_model
self.word_embedder = word_embedder
self.pos_encoder = Embedding(
size=[max_length, self.emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
position_encoding_init(max_length, self.emb_dim)),
trainable=False))
self.decoder = Decoder(n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd)
if share_input_output_embed:
self.linear = lambda x: layers.matmul(x=x,
y=self.word_embedder.
word_embedder.weight,
transpose_y=True)
else:
self.linear = Linear(input_dim=d_model,
output_dim=trg_vocab_size,
bias_attr=False)
def forward(self,
trg_word,
trg_pos,
trg_slf_attn_bias,
trg_src_attn_bias,
enc_output,
caches=None):
word_emb = self.word_embedder(trg_word)
word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
pos_enc = self.pos_encoder(trg_pos)
pos_enc.stop_gradient = True
emb = word_emb + pos_enc
dec_input = layers.dropout(emb,
dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
trg_src_attn_bias, caches)
dec_output = layers.reshape(
dec_output,
shape=[-1, dec_output.shape[-1]],
)
logits = self.linear(dec_output)
return logits
# class CrossEntropyCriterion(object):
# def __init__(self, label_smooth_eps):
# self.label_smooth_eps = label_smooth_eps
# def __call__(self, predict, label, weights):
# if self.label_smooth_eps:
# label_out = layers.label_smooth(label=layers.one_hot(
# input=label, depth=predict.shape[-1]),
# epsilon=self.label_smooth_eps)
# cost = layers.softmax_with_cross_entropy(
# logits=predict,
# label=label_out,
# soft_label=True if self.label_smooth_eps else False)
# weighted_cost = cost * weights
# sum_cost = layers.reduce_sum(weighted_cost)
# token_num = layers.reduce_sum(weights)
# token_num.stop_gradient = True
# avg_cost = sum_cost / token_num
# return sum_cost, avg_cost, token_num
class CrossEntropyCriterion(Loss):
def __init__(self, label_smooth_eps):
super(CrossEntropyCriterion, self).__init__()
self.label_smooth_eps = label_smooth_eps
def forward(self, outputs, labels):
predict = outputs[0]
label, weights = labels
if self.label_smooth_eps:
label = layers.label_smooth(label=layers.one_hot(
input=label, depth=predict.shape[-1]),
epsilon=self.label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if self.label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return avg_cost
def infer_shape(self, _):
return [[None, 1], [None, 1]]
def infer_dtype(self, _):
return ["int64", "float32"]
class Transformer(Model):
"""
model
"""
def __init__(self,
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
bos_id=0,
eos_id=1):
super(Transformer, self).__init__()
src_word_embedder = Embedder(vocab_size=src_vocab_size,
emb_dim=d_model,
bos_idx=bos_id)
self.encoder = WrapEncoder(src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd, src_word_embedder)
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
trg_word_embedder = src_word_embedder
else:
trg_word_embedder = Embedder(vocab_size=trg_vocab_size,
emb_dim=d_model,
bos_idx=bos_id)
self.decoder = WrapDecoder(trg_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing,
trg_word_embedder)
self.trg_vocab_size = trg_vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
@shape_hints(src_word=[None, None],
src_pos=[None, None],
src_slf_attn_bias=[None, 8, None, None],
trg_word=[None, None],
trg_pos=[None, None],
trg_slf_attn_bias=[None, 8, None, None],
trg_src_attn_bias=[None, 8, None, None])
def forward(self, src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias):
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
predict = self.decoder(trg_word, trg_pos, trg_slf_attn_bias,
trg_src_attn_bias, enc_output)
return predict
def beam_search_v2(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=None,
alpha=0.6):
"""
Beam search with the alive and finished two queues, both have a beam size
capicity separately. It includes `grow_topk` `grow_alive` `grow_finish` as
steps.
1. `grow_topk` selects the top `2*beam_size` candidates to avoid all getting
EOS.
2. `grow_alive` selects the top `beam_size` non-EOS candidates as the inputs
of next decoding step.
3. `grow_finish` compares the already finished candidates in the finished queue
and newly added finished candidates from `grow_topk`, and selects the top
`beam_size` finished candidates.
"""
def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:])
tile_dims = [1] * len(tensor.shape)
tile_dims[1] = beam_size
return layers.expand(tensor, tile_dims)
def merge_beam_dim(tensor):
return layers.reshape(tensor, [-1] + tensor.shape[2:])
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number
inf = float(1. * 1e7)
batch_size = enc_output.shape[0]
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
### initialize states of beam search ###
## init for the alive ##
initial_log_probs = to_variable(
np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
alive_seq = to_variable(
np.tile(np.array([[[bos_id]]], dtype="int64"),
(batch_size, beam_size, 1)))
## init for the finished ##
finished_scores = to_variable(
np.array([[-inf] * beam_size], dtype="float32"))
finished_scores = layers.expand(finished_scores, [batch_size, 1])
finished_seq = to_variable(
np.tile(np.array([[[bos_id]]], dtype="int64"),
(batch_size, beam_size, 1)))
finished_flags = layers.zeros_like(finished_scores)
### initialize inputs and states of transformer decoder ###
## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
trg_word = layers.reshape(alive_seq[:, :, -1],
[batch_size * beam_size, 1])
trg_src_attn_bias = merge_beam_dim(
expand_to_beam_size(trg_src_attn_bias, beam_size))
enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{
"k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
def update_states(caches, beam_idx, beam_size):
for cache in caches:
cache["k"] = gather_2d_by_gather(cache["k"], beam_idx,
beam_size, batch_size, False)
cache["v"] = gather_2d_by_gather(cache["v"], beam_idx,
beam_size, batch_size, False)
return caches
def gather_2d_by_gather(tensor_nd,
beam_idx,
beam_size,
batch_size,
need_flat=True):
batch_idx = layers.range(0, batch_size, 1,
dtype="int64") * beam_size
flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0),
[-1])
new_flat_tensor = layers.gather(flat_tensor, idx)
new_tensor_nd = layers.reshape(
new_flat_tensor,
shape=[batch_size, beam_idx.shape[1]] +
tensor_nd.shape[2:]) if need_flat else new_flat_tensor
return new_tensor_nd
def early_finish(alive_log_probs, finished_scores,
finished_in_finished):
max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
# The best possible score of the most likely alive sequence
lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
# Now to compute the lowest score of a finished sequence in finished
# If the sequence isn't finished, we multiply it's score by 0. since
# scores are all -ve, taking the min will give us the score of the lowest
# finished item.
lowest_score_of_fininshed_in_finished = layers.reduce_min(
finished_scores * finished_in_finished, 1)
# If none of the sequences have finished, then the min will be 0 and
# we have to replace it by -ve INF if it is. The score of any seq in alive
# will be much higher than -ve INF and the termination condition will not
# be met.
lowest_score_of_fininshed_in_finished += (
1. - layers.reduce_max(finished_in_finished, 1)) * -inf
bound_is_met = layers.reduce_all(
layers.greater_than(lowest_score_of_fininshed_in_finished,
lower_bound_alive_scores))
return bound_is_met
def grow_topk(i, logits, alive_seq, alive_log_probs, states):
logits = layers.reshape(logits, [batch_size, beam_size, -1])
candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
log_probs = layers.elementwise_add(candidate_log_probs,
alive_log_probs, 0)
length_penalty = np.power(5.0 + (i + 1.0) / 6.0, alpha)
curr_scores = log_probs / length_penalty
flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1])
topk_scores, topk_ids = layers.topk(flat_curr_scores,
k=beam_size * 2)
topk_log_probs = topk_scores * length_penalty
topk_beam_index = topk_ids // self.trg_vocab_size
topk_ids = topk_ids % self.trg_vocab_size
# use gather as gather_nd, TODO: use gather_nd
topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
beam_size, batch_size)
topk_seq = layers.concat(
[topk_seq,
layers.reshape(topk_ids, topk_ids.shape + [1])],
axis=2)
states = update_states(states, topk_beam_index, beam_size)
eos = layers.fill_constant(shape=topk_ids.shape,
dtype="int64",
value=eos_id)
topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32")
#topk_seq: [batch_size, 2*beam_size, i+1]
#topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
return topk_seq, topk_log_probs, topk_scores, topk_finished, states
def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
states):
curr_scores += curr_finished * -inf
_, topk_indexes = layers.topk(curr_scores, k=beam_size)
alive_seq = gather_2d_by_gather(curr_seq, topk_indexes,
beam_size * 2, batch_size)
alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
beam_size * 2, batch_size)
states = update_states(states, topk_indexes, beam_size * 2)
return alive_seq, alive_log_probs, states
def grow_finished(finished_seq, finished_scores, finished_flags,
curr_seq, curr_scores, curr_finished):
# finished scores
finished_seq = layers.concat([
finished_seq,
layers.fill_constant(shape=[batch_size, beam_size, 1],
dtype="int64",
value=eos_id)
],
axis=2)
# Set the scores of the unfinished seq in curr_seq to large negative
# values
curr_scores += (1. - curr_finished) * -inf
# concatenating the sequences and scores along beam axis
curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
curr_finished_scores = layers.concat([finished_scores, curr_scores],
axis=1)
curr_finished_flags = layers.concat([finished_flags, curr_finished],
axis=1)
_, topk_indexes = layers.topk(curr_finished_scores, k=beam_size)
finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
beam_size * 3, batch_size)
finished_scores = gather_2d_by_gather(curr_finished_scores,
topk_indexes, beam_size * 3,
batch_size)
finished_flags = gather_2d_by_gather(curr_finished_flags,
topk_indexes, beam_size * 3,
batch_size)
return finished_seq, finished_scores, finished_flags
for i in range(max_len):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
i, logits, alive_seq, alive_log_probs, caches)
alive_seq, alive_log_probs, states = grow_alive(
topk_seq, topk_scores, topk_log_probs, topk_finished, states)
finished_seq, finished_scores, finished_flags = grow_finished(
finished_seq, finished_scores, finished_flags, topk_seq,
topk_scores, topk_finished)
trg_word = layers.reshape(alive_seq[:, :, -1],
[batch_size * beam_size, 1])
if early_finish(alive_log_probs, finished_scores,
finished_flags).numpy():
break
return finished_seq, finished_scores
def beam_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=256):
if beam_size == 1:
return self._greedy_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
max_len=max_len)
else:
return self._beam_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
beam_size=beam_size,
max_len=max_len)
def _beam_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=256):
def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:])
tile_dims = [1] * len(tensor.shape)
tile_dims[1] = beam_size
return layers.expand(tensor, tile_dims)
def merge_batch_beams(tensor):
return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] +
tensor.shape[2:])
def split_batch_beams(tensor):
return layers.reshape(tensor,
shape=[-1, beam_size] +
list(tensor.shape[1:]))
def mask_probs(probs, finished, noend_mask_tensor):
# TODO: use where_op
finished = layers.cast(finished, dtype=probs.dtype)
probs = layers.elementwise_mul(layers.expand(
layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]),
noend_mask_tensor,
axis=-1) - layers.elementwise_mul(
probs, (finished - 1), axis=0)
return probs
def gather(x, indices, batch_pos):
topk_coordinates = layers.stack([batch_pos, indices], axis=2)
return layers.gather_nd(x, topk_coordinates)
def update_states(func, caches):
for cache in caches: # no need to update static_kv
cache["k"] = func(cache["k"])
cache["v"] = func(cache["v"])
return caches
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number
inf = float(1. * 1e7)
batch_size = enc_output.shape[0]
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
vocab_size_tensor = layers.fill_constant(shape=[1],
dtype="int64",
value=self.trg_vocab_size)
end_token_tensor = to_variable(
np.full([batch_size, beam_size], eos_id, dtype="int64"))
noend_array = [-inf] * self.trg_vocab_size
noend_array[eos_id] = 0
noend_mask_tensor = to_variable(np.array(noend_array,dtype="float32"))
batch_pos = layers.expand(
layers.unsqueeze(
to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]),
[1, beam_size])
predict_ids = []
parent_ids = []
### initialize states of beam search ###
log_probs = to_variable(
np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
dtype="float32"))
finished = to_variable(np.full([batch_size, beam_size], 0,
dtype="bool"))
### initialize inputs and states of transformer decoder ###
## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
dtype="int64",
value=bos_id)
trg_pos = layers.zeros_like(trg_word)
trg_src_attn_bias = merge_batch_beams(
expand_to_beam_size(trg_src_attn_bias, beam_size))
enc_output = merge_batch_beams(expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{
"k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
for i in range(max_len):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
caches = update_states( # can not be reshaped since the 0 size
lambda x: x if i == 0 else merge_batch_beams(x), caches)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
caches = update_states(split_batch_beams, caches)
step_log_probs = split_batch_beams(
layers.log(layers.softmax(logits)))
step_log_probs = mask_probs(step_log_probs, finished,
noend_mask_tensor)
log_probs = layers.elementwise_add(x=step_log_probs,
y=log_probs,
axis=0)
log_probs = layers.reshape(log_probs,
[-1, beam_size * self.trg_vocab_size])
scores = log_probs
topk_scores, topk_indices = layers.topk(input=scores, k=beam_size)
beam_indices = layers.elementwise_floordiv(
topk_indices, vocab_size_tensor)
token_indices = layers.elementwise_mod(
topk_indices, vocab_size_tensor)
# update states
caches = update_states(lambda x: gather(x, beam_indices, batch_pos),
caches)
log_probs = gather(log_probs, topk_indices, batch_pos)
finished = gather(finished, beam_indices, batch_pos)
finished = layers.logical_or(
finished, layers.equal(token_indices, end_token_tensor))
trg_word = layers.reshape(token_indices, [-1, 1])
predict_ids.append(token_indices)
parent_ids.append(beam_indices)
if layers.reduce_all(finished).numpy():
break
predict_ids = layers.stack(predict_ids, axis=0)
parent_ids = layers.stack(parent_ids, axis=0)
finished_seq = layers.transpose(
layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
finished_scores = topk_scores
return finished_seq, finished_scores
def _greedy_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
max_len=256):
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number
batch_size = enc_output.shape[0]
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
end_token_tensor = layers.fill_constant(shape=[batch_size, 1],
dtype="int64",
value=eos_id)
predict_ids = []
log_probs = layers.fill_constant(shape=[batch_size, 1],
dtype="float32",
value=0)
trg_word = layers.fill_constant(shape=[batch_size, 1],
dtype="int64",
value=bos_id)
finished = layers.fill_constant(shape=[batch_size, 1],
dtype="bool",
value=0)
## init states (caches) for transformer
caches = [{
"k":
layers.fill_constant(
shape=[batch_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
for i in range(max_len):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
step_log_probs = layers.log(layers.softmax(logits))
log_probs = layers.elementwise_add(x=step_log_probs,
y=log_probs,
axis=0)
scores = log_probs
topk_scores, topk_indices = layers.topk(input=scores, k=1)
finished = layers.logical_or(
finished, layers.equal(topk_indices, end_token_tensor))
trg_word = topk_indices
log_probs = topk_scores
predict_ids.append(topk_indices)
if layers.reduce_all(finished).numpy():
break
predict_ids = layers.stack(predict_ids, axis=0)
finished_seq = layers.transpose(predict_ids, [1, 2, 0])
finished_scores = topk_scores
return finished_seq, finished_scores
# used for continuous evaluation
enable_ce: False
eager_run: False
num_devices: 1
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# path of trained parameter, to make prediction
init_from_params: "trained_params/step_100000/"
# the directory for saving model
save_model: "trained_models"
# the directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match validation data files.
validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
# The pattern to match test data files.
predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# whether to use cuda
use_cuda: True
# args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "pool"
shuffle: True
shuffle_batch: True
batch_size: 4096
# Hyparams for training:
# the number of epoches for training
epoch: 30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate: 2.0
beta1: 0.9
beta2: 0.997
eps: 1e-9
# the parameters for learning rate scheduling.
warmup_steps: 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps: 0.1
# Hyparams for generation:
# the parameters for beam search.
beam_size: 5
max_out_len: 256
# the number of decoded sentences to output.
n_best: 1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size: 10000
# size of target word dictionay
trg_vocab_size: 10000
# index for <bos> token
bos_idx: 0
# index for <eos> token
eos_idx: 1
# index for <unk> token
unk_idx: 2
# max length of sequences deciding the size of position encoding table.
max_length: 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# the dimension that keys are projected to for dot-product attention.
d_key: 64
# the dimension that values are projected to for dot-product attention.
d_value: 64
# number of head used in multi-head attention.
n_head: 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer: 6
# dropout rates of different modules.
prepostprocess_dropout: 0.1
attention_dropout: 0.1
relu_dropout: 0.1
# to process before each sub-layer
preprocess_cmd: "n" # layer normalization
# to process after each sub-layer
postprocess_cmd: "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing: True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import argparse
import json
import yaml
import six
import logging
logging_only_message = "%(message)s"
logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class JsonConfig(object):
"""
A high-level api for handling json configure file.
"""
def __init__(self, config_path):
self._config_dict = self._parse(config_path)
def _parse(self, config_path):
try:
with open(config_path) as json_file:
config_dict = json.load(json_file)
except:
raise IOError("Error in parsing bert model config file '%s'" %
config_path)
else:
return config_dict
def __getitem__(self, key):
return self._config_dict[key]
def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
"""
A high-level api for handling argument configs.
"""
def __init__(self):
parser = argparse.ArgumentParser()
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5,
"Learning rate used to train with warmup.")
train_g.add_arg(
"lr_scheduler",
str,
"linear_warmup_decay",
"scheduler of learning rate.",
choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01,
"Weight decay rate for L2 regularizer.")
train_g.add_arg(
"warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for."
)
train_g.add_arg("save_steps", int, 1000,
"The steps interval to save checkpoints.")
train_g.add_arg("use_fp16", bool, False,
"Whether to use fp16 mixed precision training.")
train_g.add_arg(
"loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
)
train_g.add_arg("pred_dir", str, None,
"Path to save the prediction results")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10,
"The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True,
"If set, use GPU for training.")
run_type_g.add_arg(
"use_fast_executor", bool, False,
"If set, use fast parallel executor (in experiment).")
run_type_g.add_arg(
"num_iteration_per_drop_scope", int, 1,
"Ihe iteration intervals to clean up temporary variables.")
run_type_g.add_arg("do_train", bool, True,
"Whether to perform training.")
run_type_g.add_arg("do_predict", bool, True,
"Whether to perform prediction.")
custom_g = ArgumentGroup(parser, "customize", "customized options.")
self.custom_g = custom_g
self.parser = parser
def add_arg(self, name, dtype, default, descrip):
self.custom_g.add_arg(name, dtype, default, descrip)
def build_conf(self):
return self.parser.parse_args()
def str2bool(v):
# because argparse does not support to parse "true, False" as python
# boolean directly
return v.lower() in ("true", "t", "1")
def print_arguments(args, log=None):
if not log:
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
else:
log.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
class PDConfig(object):
"""
A high-level API for managing configuration files in PaddlePaddle.
Can jointly work with command-line-arugment, json files and yaml files.
"""
def __init__(self, json_file="", yaml_file="", fuse_args=True):
"""
Init funciton for PDConfig.
json_file: the path to the json configure file.
yaml_file: the path to the yaml configure file.
fuse_args: if fuse the json/yaml configs with argparse.
"""
assert isinstance(json_file, str)
assert isinstance(yaml_file, str)
if json_file != "" and yaml_file != "":
raise Warning(
"json_file and yaml_file can not co-exist for now. please only use one configure file type."
)
return
self.args = None
self.arg_config = {}
self.json_config = {}
self.yaml_config = {}
parser = argparse.ArgumentParser()
self.default_g = ArgumentGroup(parser, "default", "default options.")
self.yaml_g = ArgumentGroup(parser, "yaml", "options from yaml.")
self.json_g = ArgumentGroup(parser, "json", "options from json.")
self.com_g = ArgumentGroup(parser, "custom", "customized options.")
self.default_g.add_arg("do_train", bool, False,
"Whether to perform training.")
self.default_g.add_arg("do_predict", bool, False,
"Whether to perform predicting.")
self.default_g.add_arg("do_eval", bool, False,
"Whether to perform evaluating.")
self.default_g.add_arg("do_save_inference_model", bool, False,
"Whether to perform model saving for inference.")
# NOTE: args for profiler
self.default_g.add_arg("is_profiler", int, 0, "the switch of profiler tools. (used for benchmark)")
self.default_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)")
self.default_g.add_arg("max_iter", int, 0, "the max train batch num.(used for benchmark)")
self.parser = parser
if json_file != "":
self.load_json(json_file, fuse_args=fuse_args)
if yaml_file:
self.load_yaml(yaml_file, fuse_args=fuse_args)
def load_json(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the json file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.json_config = json.loads(fin.read())
fin.close()
if fuse_args:
for name in self.json_config:
if isinstance(self.json_config[name], list):
self.json_g.add_arg(
name,
type(self.json_config[name][0]),
self.json_config[name],
"This is from %s" % file_path,
nargs=len(self.json_config[name]))
continue
if not isinstance(self.json_config[name], int) \
and not isinstance(self.json_config[name], float) \
and not isinstance(self.json_config[name], str) \
and not isinstance(self.json_config[name], bool):
continue
self.json_g.add_arg(name,
type(self.json_config[name]),
self.json_config[name],
"This is from %s" % file_path)
def load_yaml(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the yaml file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
fin.close()
if fuse_args:
for name in self.yaml_config:
if isinstance(self.yaml_config[name], list):
self.yaml_g.add_arg(
name,
type(self.yaml_config[name][0]),
self.yaml_config[name],
"This is from %s" % file_path,
nargs=len(self.yaml_config[name]))
continue
if not isinstance(self.yaml_config[name], int) \
and not isinstance(self.yaml_config[name], float) \
and not isinstance(self.yaml_config[name], str) \
and not isinstance(self.yaml_config[name], bool):
continue
self.yaml_g.add_arg(name,
type(self.yaml_config[name]),
self.yaml_config[name],
"This is from %s" % file_path)
def build(self):
self.args = self.parser.parse_args()
self.arg_config = vars(self.args)
def __add__(self, new_arg):
assert isinstance(new_arg, list) or isinstance(new_arg, tuple)
assert len(new_arg) >= 3
assert self.args is None
name = new_arg[0]
dtype = new_arg[1]
dvalue = new_arg[2]
desc = new_arg[3] if len(
new_arg) == 4 else "Description is not provided."
self.com_g.add_arg(name, dtype, dvalue, desc)
return self
def __getattr__(self, name):
if name in self.arg_config:
return self.arg_config[name]
if name in self.json_config:
return self.json_config[name]
if name in self.yaml_config:
return self.yaml_config[name]
raise Warning("The argument %s is not defined." % name)
def Print(self):
print("-" * 70)
for name in self.arg_config:
print("%s:\t\t\t\t%s" % (str(name), str(self.arg_config[name])))
for name in self.json_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.json_config[name])))
for name in self.yaml_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.yaml_config[name])))
print("-" * 70)
if __name__ == "__main__":
"""
pd_config = PDConfig(json_file = "./test/bert_config.json")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
"""
pd_config = PDConfig(yaml_file="./test/bert_config.yaml")
pd_config += ("my_age", int, 18, "I am forever 18.")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
print(pd_config.my_age)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册