未验证 提交 4196e7dd 编写于 作者: P pkpk 提交者: GitHub

Merge pull request #19 from guoshengCS/hapi-transformer

Add Transformer implementation and text.py
此差异已折叠。
## Transformer
以下是本例的简要目录结构及说明:
```text
.
├── images # README 文档中的图片
├── utils # 工具包
├── gen_data.sh # 数据生成脚本
├── predict.py # 预测脚本
├── reader.py # 数据读取接口
├── README.md # 文档
├── train.py # 训练脚本
├── model.py # 模型定义文件
└── transformer.yaml # 配置文件
```
## 模型简介
机器翻译(machine translation, MT)是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程,输入为源语言句子,输出为相应的目标语言的句子。
本项目是机器翻译领域主流模型 Transformer 的 PaddlePaddle 实现, 包含模型训练,预测以及使用自定义数据等内容。用户可以基于发布的内容搭建自己的翻译模型。
## 快速开始
### 安装说明
1. paddle安装
本项目依赖于 PaddlePaddle 1.7及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 下载代码
克隆代码库到本地
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/dygraph/transformer
```
3. 环境依赖
请参考PaddlePaddle[安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.6/beginners_guide/install/index_cn.html)部分的内容
### 数据准备
公开数据集:WMT 翻译大赛是机器翻译领域最具权威的国际评测大赛,其中英德翻译任务提供了一个中等规模的数据集,这个数据集是较多论文中使用的数据集,也是 Transformer 论文中用到的一个数据集。我们也将[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)作为示例提供。运行 `gen_data.sh` 脚本进行 WMT'16 EN-DE 数据集的下载和预处理(时间较长,建议后台运行)。数据处理过程主要包括 Tokenize 和 [BPE 编码(byte-pair encoding)](https://arxiv.org/pdf/1508.07909)。运行成功后,将会生成文件夹 `gen_data`,其目录结构如下:
```text
.
├── wmt16_ende_data # WMT16 英德翻译数据
├── wmt16_ende_data_bpe # BPE 编码的 WMT16 英德翻译数据
├── mosesdecoder # Moses 机器翻译工具集,包含了 Tokenize、BLEU 评估等脚本
└── subword-nmt # BPE 编码的代码
```
另外我们也整理提供了一份处理好的 WMT'16 EN-DE 数据以供[下载](https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz)使用,其中包含词典(`vocab_all.bpe.32000`文件)、训练所需的 BPE 数据(`train.tok.clean.bpe.32000.en-de`文件)、预测所需的 BPE 数据(`newstest2016.tok.bpe.32000.en-de`等文件)和相应的评估预测结果所需的 tokenize 数据(`newstest2016.tok.de`等文件)。
自定义数据:如果需要使用自定义数据,本项目程序中可直接支持的数据格式为制表符 \t 分隔的源语言和目标语言句子对,句子中的 token 之间使用空格分隔。提供以上格式的数据文件(可以分多个part,数据读取支持文件通配符)和相应的词典文件即可直接运行。
### 单机训练
### 单机单卡
以提供的英德翻译数据为例,可以执行以下命令进行模型训练:
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096
```
以上命令中传入了训练轮数(`epoch`)和训练数据文件路径(注意请正确设置,支持通配符)等参数,更多参数的使用以及支持的模型超参数可以参见 `transformer.yaml` 配置文件,其中默认提供了 Transformer base model 的配置,如需调整可以在配置文件中更改或通过命令行传入(命令行传入内容将覆盖配置文件中的设置)。可以通过以下命令来训练 Transformer 论文中的 big model:
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
另外,如果在执行训练时若提供了 `save_model`(默认为 trained_models),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将保存当前训练的到相应目录(会保存分别记录了模型参数和优化器状态的 `transformer.pdparams``transformer.pdopt` 两个文件),每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出:
```txt
[2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s
[2019-08-02 15:31:19,824 INFO train.py:262] step_idx: 150200, epoch: 32, batch: 1464, avg loss: 2.955965, normalized loss: 1.580225, ppl: 19.220257, speed: 3.55 step/s
[2019-08-02 15:31:48,151 INFO train.py:262] step_idx: 150300, epoch: 32, batch: 1564, avg loss: 2.951180, normalized loss: 1.575439, ppl: 19.128502, speed: 3.53 step/s
[2019-08-02 15:32:16,401 INFO train.py:262] step_idx: 150400, epoch: 32, batch: 1664, avg loss: 3.027281, normalized loss: 1.651540, ppl: 20.641024, speed: 3.54 step/s
[2019-08-02 15:32:44,764 INFO train.py:262] step_idx: 150500, epoch: 32, batch: 1764, avg loss: 3.069125, normalized loss: 1.693385, ppl: 21.523066, speed: 3.53 step/s
[2019-08-02 15:33:13,199 INFO train.py:262] step_idx: 150600, epoch: 32, batch: 1864, avg loss: 2.869379, normalized loss: 1.493639, ppl: 17.626074, speed: 3.52 step/s
[2019-08-02 15:33:41,601 INFO train.py:262] step_idx: 150700, epoch: 32, batch: 1964, avg loss: 2.980905, normalized loss: 1.605164, ppl: 19.705633, speed: 3.52 step/s
[2019-08-02 15:34:10,079 INFO train.py:262] step_idx: 150800, epoch: 32, batch: 2064, avg loss: 3.047716, normalized loss: 1.671976, ppl: 21.067181, speed: 3.51 step/s
[2019-08-02 15:34:38,598 INFO train.py:262] step_idx: 150900, epoch: 32, batch: 2164, avg loss: 2.956475, normalized loss: 1.580735, ppl: 19.230072, speed: 3.51 step/s
```
也可以使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度较慢。
#### 单机多卡
Paddle动态图支持多进程多卡进行模型训练,启动训练的方式如下:
```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 100 \
--use_cuda True \
--save_step 10000
```
此时,程序会将每个进程的输出log导入到`./mylog`路径下,只有第一个工作进程会保存模型。
```
.
├── mylog
│   ├── workerlog.0
│   ├── workerlog.1
│   ├── workerlog.2
│   ├── workerlog.3
│   ├── workerlog.4
│   ├── workerlog.5
│   ├── workerlog.6
│   └── workerlog.7
```
### 模型推断
以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt
```
`predict_file` 指定的文件中文本的翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `transformer.yaml` 文件中查阅注释说明并进行更改设置。注意若在执行预测时设置了模型超参数,应与模型训练时的设置一致,如若训练时使用 big model 的参数设置,则预测时对应类似如下命令:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
### 模型评估
预测结果中每行输出是对应行输入的得分最高的翻译,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估。评估过程具体如下(BLEU 是翻译任务常用的自动评估方法指标):
```sh
# 还原 predict.txt 中的预测结果为 tokenize 后的数据
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
# 若无 BLEU 评估工具,需先进行下载
# git clone https://github.com/moses-smt/mosesdecoder.git
# 以英德翻译 newstest2014 测试数据为例
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2014.tok.de < predict.tok.txt
```
可以看到类似如下的结果:
```
BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
```
使用本项目中提供的内容,英德翻译 base model 和 big model 八卡训练 100K 个 iteration 后测试有大约如下的 BLEU 值:
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| Base | 26.35 | 29.07 | 33.30 |
| Big | 27.07 | 30.09 | 34.38 |
### 预训练模型
我们这里提供了对应有以上 BLEU 值的 [base model](https://transformer-res.bj.bcebos.com/base_model_dygraph.tar.gz)[big model](https://transformer-res.bj.bcebos.com/big_model_dygraph.tar.gz) 的模型参数提供下载使用(注意,模型使用了提供下载的数据进行训练和测试)。
## 进阶使用
### 背景介绍
Transformer 是论文 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 中提出的用以完成机器翻译(machine translation, MT)等序列到序列(sequence to sequence, Seq2Seq)学习任务的一种全新网络结构,其完全使用注意力(Attention)机制来实现序列到序列的建模[1]。
相较于此前 Seq2Seq 模型中广泛使用的循环神经网络(Recurrent Neural Network, RNN),使用(Self)Attention 进行输入序列到输出序列的变换主要具有以下优势:
- 计算复杂度小
- 特征维度为 d 、长度为 n 的序列,在 RNN 中计算复杂度为 `O(n * d * d)` (n 个时间步,每个时间步计算 d 维的矩阵向量乘法),在 Self-Attention 中计算复杂度为 `O(n * n * d)` (n 个时间步两两计算 d 维的向量点积或其他相关度函数),n 通常要小于 d 。
- 计算并行度高
- RNN 中当前时间步的计算要依赖前一个时间步的计算结果;Self-Attention 中各时间步的计算只依赖输入不依赖之前时间步输出,各时间步可以完全并行。
- 容易学习长程依赖(long-range dependencies)
- RNN 中相距为 n 的两个位置间的关联需要 n 步才能建立;Self-Attention 中任何两个位置都直接相连;路径越短信号传播越容易。
Transformer 中引入使用的基于 Self-Attention 的序列建模模块结构,已被广泛应用在 Bert [2]等语义表示模型中,取得了显著效果。
### 模型概览
Transformer 同样使用了 Seq2Seq 模型中典型的编码器-解码器(Encoder-Decoder)的框架结构,整体网络结构如图1所示。
<p align="center">
<img src="images/transformer_network.png" height=400 hspace='10'/> <br />
图 1. Transformer 网络结构图
</p>
可以看到,和以往 Seq2Seq 模型不同,Transformer 的 Encoder 和 Decoder 中不再使用 RNN 的结构。
### 模型特点
Transformer 中的 Encoder 由若干相同的 layer 堆叠组成,每个 layer 主要由多头注意力(Multi-Head Attention)和全连接的前馈(Feed-Forward)网络这两个 sub-layer 构成。
- Multi-Head Attention 在这里用于实现 Self-Attention,相比于简单的 Attention 机制,其将输入进行多路线性变换后分别计算 Attention 的结果,并将所有结果拼接后再次进行线性变换作为输出。参见图2,其中 Attention 使用的是点积(Dot-Product),并在点积后进行了 scale 的处理以避免因点积结果过大进入 softmax 的饱和区域。
- Feed-Forward 网络会对序列中的每个位置进行相同的计算(Position-wise),其采用的是两次线性变换中间加以 ReLU 激活的结构。
此外,每个 sub-layer 后还施以 Residual Connection [3]和 Layer Normalization [4]来促进梯度传播和模型收敛。
<p align="center">
<img src="images/multi_head_attention.png" height=300 hspace='10'/> <br />
图 2. Multi-Head Attention
</p>
Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 layer ,在组成 Decoder 的 layer 中还多了一个 Multi-Head Attention 的 sub-layer 来实现对 Encoder 输出的 Attention,这个 Encoder-Decoder Attention 在其他 Seq2Seq 模型中也是存在的。
## FAQ
**Q:** 预测结果中样本数少于输入的样本数是什么原因
**A:** 若样本中最大长度超过 `transformer.yaml``max_length` 的默认设置,请注意运行时增大 `--max_length` 的设置,否则超长样本将被过滤。
**Q:** 预测时最大长度超过了训练时的最大长度怎么办
**A:** 由于训练时 `max_length` 的设置决定了保存模型 position encoding 的大小,若预测时长度超过 `max_length`,请调大该值,会重新生成更大的 position encoding 表。
## 参考文献
1. Vaswani A, Shazeer N, Parmar N, et al. [Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)[C]//Advances in Neural Information Processing Systems. 2017: 6000-6010.
2. Devlin J, Chang M W, Lee K, et al. [Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805)[J]. arXiv preprint arXiv:1810.04805, 2018.
3. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
4. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
5. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
## 作者
- [guochengCS](https://github.com/guoshengCS)
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from paddle.fluid.layers.utils import flatten
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from model import Input, set_device
from reader import prepare_infer_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import InferTransformer, position_encoding_init
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def do_predict(args):
device = set_device("gpu" if args.use_cuda else "cpu")
fluid.enable_dygraph(device) if args.eager_run else None
inputs = [
Input(
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
# define data
dataset = Seq2SeqDataset(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = Seq2SeqDataset.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=False,
batch_size=args.batch_size,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None
if fluid.in_dygraph_mode() else [x.forward() for x in inputs],
collate_fn=partial(
prepare_infer_input, src_pad_idx=args.eos_idx, n_head=args.n_head),
num_workers=0,
return_list=True)
# define model
transformer = InferTransformer(
args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
transformer.prepare(inputs=inputs)
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
transformer.load(os.path.join(args.init_from_params, "transformer"))
# TODO: use model.predict when support variant length
f = open(args.output_file, "wb")
for data in data_loader():
finished_seq = transformer.test(inputs=flatten(data))[0]
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_predict(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import six
import os
import tarfile
import itertools
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Put all padded data needed by training into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]
return data_inputs
def prepare_infer_input(insts, src_pad_idx, n_head):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
]
return data_inputs
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones(
(inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + [self._end]
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, parallel_sentence):
return [
self._converters[i](parallel_sentence[i])
for i in range(len(self._converters))
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, max_len, min_len):
self.i = i
self.min_len = min_len
self.max_len = max_len
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class Seq2SeqDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
tar_fname=None,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
only_src=False):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname)
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
converters = ComposedConverter(converters)
self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else []
self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
self._src_seq_ids.append(src_trg_ids[0])
lens = [len(src_trg_ids[0])]
if not self._only_src:
self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname):
fpaths = glob.glob(fpattern)
assert len(fpaths) > 0, "no matching file to the provided data path"
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f:
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(b"\n")
else:
word_dict[line.strip(b"\n")] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]
) if not self._only_src else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class Seq2SeqBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
seed=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos,
key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self._batch_size
) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# when producing batches according to sequence number, to confirm
# neighbor batches which would be feed and run parallel have similar
# length (thus similar computational cost) after shuffle, we as take
# them as a whole when shuffling and split here
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = list(itertools.chain.from_iterable(batches))
# for multi-device
for batch_id, batch in enumerate(batches):
if batch_id % self._nranks == self._local_rank:
batch_indices = [info.i for info in batch]
yield batch_indices
if self._local_rank > len(batches) % self._nranks:
yield batch_indices
def __len__(self):
return 100
python -u train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
--validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 1 \
--use_cuda True \
--random_seed 1000 \
--save_step 10 \
--eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit
echo `date`
python -u predict.py \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 64 \
--init_from_params base_model_dygraph/step_100000/ \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
echo `date`
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.io import DataLoader
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from model import Input, set_device
from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
class LoggerCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
super(LoggerCallback, self).__init__(log_freq, verbose)
# TODO: wrap these override function to simplify
self.loss_normalizer = loss_normalizer
def on_train_begin(self, logs=None):
super(LoggerCallback, self).on_train_begin(logs)
self.train_metrics += ["normalized loss", "ppl"]
def on_train_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_train_batch_end(step, logs)
def on_eval_begin(self, logs=None):
super(LoggerCallback, self).on_eval_begin(logs)
self.eval_metrics += ["normalized loss", "ppl"]
def on_eval_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_eval_batch_end(step, logs)
def do_train(args):
device = set_device("gpu" if args.use_cuda else "cpu")
fluid.enable_dygraph(device) if args.eager_run else None
# set seed for CE
random_seed = eval(str(args.random_seed))
if random_seed is not None:
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
# define inputs
inputs = [
Input(
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, None], "int64", name="trg_word"),
Input(
[None, None], "int64", name="trg_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
labels = [
Input(
[None, 1], "int64", name="label"),
Input(
[None, 1], "float32", name="weight"),
]
# def dataloader
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None if fluid.in_dygraph_mode() else
[x.forward() for x in inputs + labels],
collate_fn=partial(
prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.bos_idx, args.eos_idx)
transformer.prepare(
fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(args.d_model,
args.warmup_steps),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
transformer.load(
os.path.join(args.init_from_checkpoint, "transformer"))
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
transformer.load(
os.path.join(args.init_from_pretrain_model, "transformer"),
reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train
transformer.fit(train_data=train_loader,
eval_data=eval_loader,
epochs=1,
eval_freq=1,
save_freq=1,
verbose=2,
callbacks=[
LoggerCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_train(args)
此差异已折叠。
# used for continuous evaluation
enable_ce: False
eager_run: False
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# path of trained parameter, to make prediction
init_from_params: "trained_params/step_100000/"
# the directory for saving model
save_model: "trained_models"
# the directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match validation data files.
validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
# The pattern to match test data files.
predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# whether to use cuda
use_cuda: True
# args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "pool"
shuffle: True
shuffle_batch: True
batch_size: 4096
# Hyparams for training:
# the number of epoches for training
epoch: 30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate: 2.0
beta1: 0.9
beta2: 0.997
eps: 1e-9
# the parameters for learning rate scheduling.
warmup_steps: 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps: 0.1
# Hyparams for generation:
# the parameters for beam search.
beam_size: 5
max_out_len: 256
# the number of decoded sentences to output.
n_best: 1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size: 10000
# size of target word dictionay
trg_vocab_size: 10000
# index for <bos> token
bos_idx: 0
# index for <eos> token
eos_idx: 1
# index for <unk> token
unk_idx: 2
# max length of sequences deciding the size of position encoding table.
max_length: 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# the dimension that keys are projected to for dot-product attention.
d_key: 64
# the dimension that values are projected to for dot-product attention.
d_value: 64
# number of head used in multi-head attention.
n_head: 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer: 6
# dropout rates of different modules.
prepostprocess_dropout: 0.1
attention_dropout: 0.1
relu_dropout: 0.1
# to process before each sub-layer
preprocess_cmd: "n" # layer normalization
# to process after each sub-layer
postprocess_cmd: "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing: True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import argparse
import json
import yaml
import six
import logging
logging_only_message = "%(message)s"
logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class JsonConfig(object):
"""
A high-level api for handling json configure file.
"""
def __init__(self, config_path):
self._config_dict = self._parse(config_path)
def _parse(self, config_path):
try:
with open(config_path) as json_file:
config_dict = json.load(json_file)
except:
raise IOError("Error in parsing bert model config file '%s'" %
config_path)
else:
return config_dict
def __getitem__(self, key):
return self._config_dict[key]
def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
"""
A high-level api for handling argument configs.
"""
def __init__(self):
parser = argparse.ArgumentParser()
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5,
"Learning rate used to train with warmup.")
train_g.add_arg(
"lr_scheduler",
str,
"linear_warmup_decay",
"scheduler of learning rate.",
choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01,
"Weight decay rate for L2 regularizer.")
train_g.add_arg(
"warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for."
)
train_g.add_arg("save_steps", int, 1000,
"The steps interval to save checkpoints.")
train_g.add_arg("use_fp16", bool, False,
"Whether to use fp16 mixed precision training.")
train_g.add_arg(
"loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
)
train_g.add_arg("pred_dir", str, None,
"Path to save the prediction results")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10,
"The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True,
"If set, use GPU for training.")
run_type_g.add_arg(
"use_fast_executor", bool, False,
"If set, use fast parallel executor (in experiment).")
run_type_g.add_arg(
"num_iteration_per_drop_scope", int, 1,
"Ihe iteration intervals to clean up temporary variables.")
run_type_g.add_arg("do_train", bool, True,
"Whether to perform training.")
run_type_g.add_arg("do_predict", bool, True,
"Whether to perform prediction.")
custom_g = ArgumentGroup(parser, "customize", "customized options.")
self.custom_g = custom_g
self.parser = parser
def add_arg(self, name, dtype, default, descrip):
self.custom_g.add_arg(name, dtype, default, descrip)
def build_conf(self):
return self.parser.parse_args()
def str2bool(v):
# because argparse does not support to parse "true, False" as python
# boolean directly
return v.lower() in ("true", "t", "1")
def print_arguments(args, log=None):
if not log:
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
else:
log.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
class PDConfig(object):
"""
A high-level API for managing configuration files in PaddlePaddle.
Can jointly work with command-line-arugment, json files and yaml files.
"""
def __init__(self, json_file="", yaml_file="", fuse_args=True):
"""
Init funciton for PDConfig.
json_file: the path to the json configure file.
yaml_file: the path to the yaml configure file.
fuse_args: if fuse the json/yaml configs with argparse.
"""
assert isinstance(json_file, str)
assert isinstance(yaml_file, str)
if json_file != "" and yaml_file != "":
raise Warning(
"json_file and yaml_file can not co-exist for now. please only use one configure file type."
)
return
self.args = None
self.arg_config = {}
self.json_config = {}
self.yaml_config = {}
parser = argparse.ArgumentParser()
self.default_g = ArgumentGroup(parser, "default", "default options.")
self.yaml_g = ArgumentGroup(parser, "yaml", "options from yaml.")
self.json_g = ArgumentGroup(parser, "json", "options from json.")
self.com_g = ArgumentGroup(parser, "custom", "customized options.")
self.default_g.add_arg("do_train", bool, False,
"Whether to perform training.")
self.default_g.add_arg("do_predict", bool, False,
"Whether to perform predicting.")
self.default_g.add_arg("do_eval", bool, False,
"Whether to perform evaluating.")
self.default_g.add_arg("do_save_inference_model", bool, False,
"Whether to perform model saving for inference.")
# NOTE: args for profiler
self.default_g.add_arg("is_profiler", int, 0, "the switch of profiler tools. (used for benchmark)")
self.default_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)")
self.default_g.add_arg("max_iter", int, 0, "the max train batch num.(used for benchmark)")
self.parser = parser
if json_file != "":
self.load_json(json_file, fuse_args=fuse_args)
if yaml_file:
self.load_yaml(yaml_file, fuse_args=fuse_args)
def load_json(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the json file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.json_config = json.loads(fin.read())
fin.close()
if fuse_args:
for name in self.json_config:
if isinstance(self.json_config[name], list):
self.json_g.add_arg(
name,
type(self.json_config[name][0]),
self.json_config[name],
"This is from %s" % file_path,
nargs=len(self.json_config[name]))
continue
if not isinstance(self.json_config[name], int) \
and not isinstance(self.json_config[name], float) \
and not isinstance(self.json_config[name], str) \
and not isinstance(self.json_config[name], bool):
continue
self.json_g.add_arg(name,
type(self.json_config[name]),
self.json_config[name],
"This is from %s" % file_path)
def load_yaml(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the yaml file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
fin.close()
if fuse_args:
for name in self.yaml_config:
if isinstance(self.yaml_config[name], list):
self.yaml_g.add_arg(
name,
type(self.yaml_config[name][0]),
self.yaml_config[name],
"This is from %s" % file_path,
nargs=len(self.yaml_config[name]))
continue
if not isinstance(self.yaml_config[name], int) \
and not isinstance(self.yaml_config[name], float) \
and not isinstance(self.yaml_config[name], str) \
and not isinstance(self.yaml_config[name], bool):
continue
self.yaml_g.add_arg(name,
type(self.yaml_config[name]),
self.yaml_config[name],
"This is from %s" % file_path)
def build(self):
self.args = self.parser.parse_args()
self.arg_config = vars(self.args)
def __add__(self, new_arg):
assert isinstance(new_arg, list) or isinstance(new_arg, tuple)
assert len(new_arg) >= 3
assert self.args is None
name = new_arg[0]
dtype = new_arg[1]
dvalue = new_arg[2]
desc = new_arg[3] if len(
new_arg) == 4 else "Description is not provided."
self.com_g.add_arg(name, dtype, dvalue, desc)
return self
def __getattr__(self, name):
if name in self.arg_config:
return self.arg_config[name]
if name in self.json_config:
return self.json_config[name]
if name in self.yaml_config:
return self.yaml_config[name]
raise Warning("The argument %s is not defined." % name)
def Print(self):
print("-" * 70)
for name in self.arg_config:
print("%s:\t\t\t\t%s" % (str(name), str(self.arg_config[name])))
for name in self.json_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.json_config[name])))
for name in self.yaml_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.yaml_config[name])))
print("-" * 70)
if __name__ == "__main__":
"""
pd_config = PDConfig(json_file = "./test/bert_config.json")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
"""
pd_config = PDConfig(yaml_file="./test/bert_config.yaml")
pd_config += ("my_age", int, 18, "I am forever 18.")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
print(pd_config.my_age)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册