提交 a46365b6 编写于 作者: 1024的传说's avatar 1024的传说 提交者: pkpk

mv neural_machine_translation to PaddleMT (#3732)

上级 f911f34a
#!/bin/bash
DATA_PATH=./dataset/wmt16
train(){
python -u main.py \
--do_train True \
--src_vocab_fpath $DATA_PATH/en_10000.dict \
--trg_vocab_fpath $DATA_PATH/de_10000.dict \
--special_token '<s>' '<e>' '<unk>' \
--training_file $DATA_PATH/wmt16/train \
--use_token_batch True \
--batch_size 2048 \
--sort_type pool \
--pool_size 10000 \
--print_step 1 \
--weight_sharing False \
--epoch 20 \
--enable_ce True \
--random_seed 1000 \
--save_checkpoint "" \
--save_param ""
}
cudaid=${transformer:=0} # use 0-th card as default
export CUDA_VISIBLE_DEVICES=$cudaid
train | python _ce.py
cudaid=${transformer_m:=0,1,2,3} # use 0,1,2,3 card as default
export CUDA_VISIBLE_DEVICES=$cudaid
train | python _ce.py
\ No newline at end of file
## Transformer
以下是本例的简要目录结构及说明:
```text
.
├── images # README 文档中的图片
├── utils # 工具包
├── desc.py # 输入描述文件
├── gen_data.sh # 数据生成脚本
├── inference_model.py # 保存 inference_model 的脚本
├── main.py # 主程序入口
├── predict.py # 预测脚本
├── reader.py # 数据读取接口
├── README.md # 文档
├── train.py # 训练脚本
├── transformer.py # 模型定义文件
└── transformer.yaml # 配置文件
```
## 模型简介
机器翻译(machine translation, MT)是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程,输入为源语言句子,输出为相应的目标语言的句子。
本项目是机器翻译领域主流模型 Transformer 的 PaddlePaddle 实现, 包含模型训练,预测以及使用自定义数据等内容。用户可以基于发布的内容搭建自己的翻译模型。
同时推荐用户参考[ IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/122281)
## 快速开始
### 安装说明
1. paddle安装
本项目依赖于 PaddlePaddle 1.6及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 下载代码
克隆代码库到本地
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/neural_machine_translation/transformer
```
3. 环境依赖
请参考PaddlePaddle[安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html)部分的内容
### 数据准备
公开数据集:WMT 翻译大赛是机器翻译领域最具权威的国际评测大赛,其中英德翻译任务提供了一个中等规模的数据集,这个数据集是较多论文中使用的数据集,也是 Transformer 论文中用到的一个数据集。我们也将[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)作为示例提供。运行 `gen_data.sh` 脚本进行 WMT'16 EN-DE 数据集的下载和预处理(时间较长,建议后台运行)。数据处理过程主要包括 Tokenize 和 [BPE 编码(byte-pair encoding)](https://arxiv.org/pdf/1508.07909)。运行成功后,将会生成文件夹 `gen_data`,其目录结构如下:
```text
.
├── wmt16_ende_data # WMT16 英德翻译数据
├── wmt16_ende_data_bpe # BPE 编码的 WMT16 英德翻译数据
├── mosesdecoder # Moses 机器翻译工具集,包含了 Tokenize、BLEU 评估等脚本
└── subword-nmt # BPE 编码的代码
```
另外我们也整理提供了一份处理好的 WMT'16 EN-DE 数据以供[下载](https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz)使用,其中包含词典(`vocab_all.bpe.32000`文件)、训练所需的 BPE 数据(`train.tok.clean.bpe.32000.en-de`文件)、预测所需的 BPE 数据(`newstest2016.tok.bpe.32000.en-de`等文件)和相应的评估预测结果所需的 tokenize 数据(`newstest2016.tok.de`等文件)。
自定义数据:如果需要使用自定义数据,本项目程序中可直接支持的数据格式为制表符 \t 分隔的源语言和目标语言句子对,句子中的 token 之间使用空格分隔。提供以上格式的数据文件(可以分多个part,数据读取支持文件通配符)和相应的词典文件即可直接运行。
### 单机训练
以提供的英德翻译数据为例,可以执行以下命令进行模型训练:
```sh
# open garbage collection to save memory
export FLAGS_eager_delete_tensor_gb=0.0
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u main.py \
--do_train True \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--batch_size 4096
```
以上命令中传入了执行训练(`do_train`)、训练轮数(`epoch`)和训练数据文件路径(注意请正确设置,支持通配符)等参数,更多参数的使用以及支持的模型超参数可以参见 `transformer.yaml` 配置文件,其中默认提供了 Transformer base model 的配置,如需调整可以在配置文件中更改或通过命令行传入(命令行传入内容将覆盖配置文件中的设置)。可以通过以下命令来训练 Transformer 论文中的 big model:
```sh
# open garbage collection to save memory
export FLAGS_eager_delete_tensor_gb=0.0
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u main.py \
--do_train True \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--batch_size 4096 \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度相对较慢。在执行训练时若提供了 `save_param``save_checkpoint`(默认为 trained_params 和 trained_ckpts),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将分别保存当前训练的参数值和 checkpoint 到相应目录,每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出:
```txt
[2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s
[2019-08-02 15:31:19,824 INFO train.py:262] step_idx: 150200, epoch: 32, batch: 1464, avg loss: 2.955965, normalized loss: 1.580225, ppl: 19.220257, speed: 3.55 step/s
[2019-08-02 15:31:48,151 INFO train.py:262] step_idx: 150300, epoch: 32, batch: 1564, avg loss: 2.951180, normalized loss: 1.575439, ppl: 19.128502, speed: 3.53 step/s
[2019-08-02 15:32:16,401 INFO train.py:262] step_idx: 150400, epoch: 32, batch: 1664, avg loss: 3.027281, normalized loss: 1.651540, ppl: 20.641024, speed: 3.54 step/s
[2019-08-02 15:32:44,764 INFO train.py:262] step_idx: 150500, epoch: 32, batch: 1764, avg loss: 3.069125, normalized loss: 1.693385, ppl: 21.523066, speed: 3.53 step/s
[2019-08-02 15:33:13,199 INFO train.py:262] step_idx: 150600, epoch: 32, batch: 1864, avg loss: 2.869379, normalized loss: 1.493639, ppl: 17.626074, speed: 3.52 step/s
[2019-08-02 15:33:41,601 INFO train.py:262] step_idx: 150700, epoch: 32, batch: 1964, avg loss: 2.980905, normalized loss: 1.605164, ppl: 19.705633, speed: 3.52 step/s
[2019-08-02 15:34:10,079 INFO train.py:262] step_idx: 150800, epoch: 32, batch: 2064, avg loss: 3.047716, normalized loss: 1.671976, ppl: 21.067181, speed: 3.51 step/s
[2019-08-02 15:34:38,598 INFO train.py:262] step_idx: 150900, epoch: 32, batch: 2164, avg loss: 2.956475, normalized loss: 1.580735, ppl: 19.230072, speed: 3.51 step/s
```
### 模型推断
以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
```sh
# open garbage collection to save memory
export FLAGS_eager_delete_tensor_gb=0.0
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u main.py \
--do_predict True \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt
```
`predict_file` 指定的文件中文本的翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `transformer.yaml` 文件中查阅注释说明并进行更改设置。注意若在执行预测时设置了模型超参数,应与模型训练时的设置一致,如若训练时使用 big model 的参数设置,则预测时对应类似如下命令:
```sh
# open garbage collection to save memory
export FLAGS_eager_delete_tensor_gb=0.0
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u main.py \
--do_predict True \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
### 模型评估
预测结果中每行输出是对应行输入的得分最高的翻译,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估。评估过程具体如下(BLEU 是翻译任务常用的自动评估方法指标):
```sh
# 还原 predict.txt 中的预测结果为 tokenize 后的数据
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
# 若无 BLEU 评估工具,需先进行下载
# git clone https://github.com/moses-smt/mosesdecoder.git
# 以英德翻译 newstest2014 测试数据为例
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2014.tok.de < predict.tok.txt
```
可以看到类似如下的结果:
```
BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
```
使用本项目中提供的内容,英德翻译 base model 和 big model 八卡训练 100K 个 iteration 后测试有大约如下的 BLEU 值:
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| Base | 26.35 | 29.07 | 33.30 |
| Big | 27.07 | 30.09 | 34.38 |
### 预训练模型
我们这里提供了对应有以上 BLEU 值的 [base model](https://transformer-res.bj.bcebos.com/base_model_params.tar.gz)[big model](https://transformer-res.bj.bcebos.com/big_model_params.tar.gz) 的模型参数提供下载使用(注意,模型使用了提供下载的数据进行训练和测试)。
## 进阶使用
### 背景介绍
Transformer 是论文 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 中提出的用以完成机器翻译(machine translation, MT)等序列到序列(sequence to sequence, Seq2Seq)学习任务的一种全新网络结构,其完全使用注意力(Attention)机制来实现序列到序列的建模[1]。
相较于此前 Seq2Seq 模型中广泛使用的循环神经网络(Recurrent Neural Network, RNN),使用(Self)Attention 进行输入序列到输出序列的变换主要具有以下优势:
- 计算复杂度小
- 特征维度为 d 、长度为 n 的序列,在 RNN 中计算复杂度为 `O(n * d * d)` (n 个时间步,每个时间步计算 d 维的矩阵向量乘法),在 Self-Attention 中计算复杂度为 `O(n * n * d)` (n 个时间步两两计算 d 维的向量点积或其他相关度函数),n 通常要小于 d 。
- 计算并行度高
- RNN 中当前时间步的计算要依赖前一个时间步的计算结果;Self-Attention 中各时间步的计算只依赖输入不依赖之前时间步输出,各时间步可以完全并行。
- 容易学习长程依赖(long-range dependencies)
- RNN 中相距为 n 的两个位置间的关联需要 n 步才能建立;Self-Attention 中任何两个位置都直接相连;路径越短信号传播越容易。
Transformer 中引入使用的基于 Self-Attention 的序列建模模块结构,已被广泛应用在 Bert [2]等语义表示模型中,取得了显著效果。
### 模型概览
Transformer 同样使用了 Seq2Seq 模型中典型的编码器-解码器(Encoder-Decoder)的框架结构,整体网络结构如图1所示。
<p align="center">
<img src="images/transformer_network.png" height=400 hspace='10'/> <br />
图 1. Transformer 网络结构图
</p>
可以看到,和以往 Seq2Seq 模型不同,Transformer 的 Encoder 和 Decoder 中不再使用 RNN 的结构。
### 模型特点
Transformer 中的 Encoder 由若干相同的 layer 堆叠组成,每个 layer 主要由多头注意力(Multi-Head Attention)和全连接的前馈(Feed-Forward)网络这两个 sub-layer 构成。
- Multi-Head Attention 在这里用于实现 Self-Attention,相比于简单的 Attention 机制,其将输入进行多路线性变换后分别计算 Attention 的结果,并将所有结果拼接后再次进行线性变换作为输出。参见图2,其中 Attention 使用的是点积(Dot-Product),并在点积后进行了 scale 的处理以避免因点积结果过大进入 softmax 的饱和区域。
- Feed-Forward 网络会对序列中的每个位置进行相同的计算(Position-wise),其采用的是两次线性变换中间加以 ReLU 激活的结构。
此外,每个 sub-layer 后还施以 Residual Connection [3]和 Layer Normalization [4]来促进梯度传播和模型收敛。
<p align="center">
<img src="images/multi_head_attention.png" height=300 hspace='10'/> <br />
图 2. Multi-Head Attention
</p>
Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 layer ,在组成 Decoder 的 layer 中还多了一个 Multi-Head Attention 的 sub-layer 来实现对 Encoder 输出的 Attention,这个 Encoder-Decoder Attention 在其他 Seq2Seq 模型中也是存在的。
## FAQ
**Q:** 预测结果中样本数少于输入的样本数是什么原因
**A:** 若样本中最大长度超过 `transformer.yaml``max_length` 的默认设置,请注意运行时增大 `--max_length` 的设置,否则超长样本将被过滤。
**Q:** 预测时最大长度超过了训练时的最大长度怎么办
**A:** 由于训练时 `max_length` 的设置决定了保存模型 position encoding 的大小,若预测时长度超过 `max_length`,请调大该值,会重新生成更大的 position encoding 表。
## 参考文献
1. Vaswani A, Shazeer N, Parmar N, et al. [Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)[C]//Advances in Neural Information Processing Systems. 2017: 6000-6010.
2. Devlin J, Chang M W, Lee K, et al. [Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805)[J]. arXiv preprint arXiv:1810.04805, 2018.
3. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
4. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
5. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
## 版本更新
2019/08/16 进行了规范化,更新了 Paddle 接口的使用
## 作者
- [guochengCS](https://github.com/guoshengCS)
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
####this file is only used for continuous evaluation test!
import os
import sys
sys.path.insert(0, os.environ['ceroot'])
from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE kpi.py should shared in models in some way!!!!
train_cost_card1_kpi = CostKpi('train_cost_card1', 0.002, 0, actived=True)
# test_cost_card1_kpi = CostKpi('test_cost_card1', 0.008, 0, actived=True)
train_duration_card1_kpi = DurationKpi(
'train_duration_card1', 0.006, 0, actived=True)
train_cost_card4_kpi = CostKpi('train_cost_card4', 0.001, 0, actived=True)
# test_cost_card4_kpi = CostKpi('test_cost_card4', 0.001, 0, actived=True)
train_duration_card4_kpi = DurationKpi(
'train_duration_card4', 0.02, 0, actived=True)
tracking_kpis = [
train_cost_card1_kpi,
# test_cost_card1_kpi,
train_duration_card1_kpi,
train_cost_card4_kpi,
# test_cost_card4_kpi,
train_duration_card4_kpi,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_cost\t1.0
test_cost\t1.0
train_cost\t1.0
train_cost\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
print("-----%s" % fs)
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
print("*****")
print(log)
print("****")
log_to_ce(log)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = None
# The placeholder for squence length in compile time.
seq_len = None
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch]
"src_word": [(batch_size, seq_len), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(None, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(None, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"init_idx",
"trg_src_attn_bias", )
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt16_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt16_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://www.statmt.org/europarl/v7/de-en.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz'
'news-commentary-v11.de-en.en' 'news-commentary-v11.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
DEV_TEST_DATA=(
'http://data.statmt.org/wmt16/translation-task/dev.tgz'
'newstest201[45]-deen-ref.en.sgm' 'newstest201[45]-deen-src.de.sgm'
'http://data.statmt.org/wmt16/translation-task/test.tgz'
'newstest2016-deen-ref.en.sgm' 'newstest2016-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: ata_url data_tgz data_file
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2016
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l $l -threads 8 < $f > $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 80
fi
done
# Clone subword-nmt and generate BPE data
if [ ! -d ${OUTPUT_DIR}/subword-nmt ]; then
git clone https://github.com/rsennrich/subword-nmt.git ${OUTPUT_DIR}/subword-nmt
fi
# Generate BPE data and vocabulary
for num_operations in 32000; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
#include palm for easier nlp coding
from palm.toolkit.input_field import InputField
from palm.toolkit.configure import PDConfig
# include task-specific libs
import desc
import reader
from transformer import create_net
def init_from_pretrain_model(args, exe, program):
assert isinstance(args.init_from_pretrain_model, str)
if not os.path.exists(args.init_from_pretrain_model):
raise Warning("The pretrained params do not exist.")
return False
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(
os.path.join(args.init_from_pretrain_model, var.name))
fluid.io.load_vars(
exe,
args.init_from_pretrain_model,
main_program=program,
predicate=existed_params)
print("finish initing model from pretrained params from %s" %
(args.init_from_pretrain_model))
return True
def init_from_params(args, exe, program):
assert isinstance(args.init_from_params, str)
if not os.path.exists(args.init_from_params):
raise Warning("the params path does not exist.")
return False
fluid.io.load_params(
executor=exe,
dirname=args.init_from_params,
main_program=program,
filename="params.pdparams")
print("finish init model from params from %s" % (args.init_from_params))
return True
def do_save_inference_model(args):
if args.use_cuda:
dev_count = fluid.core.get_cuda_device_count()
place = fluid.CUDAPlace(0)
else:
dev_count = int(os.environ.get('CPU_NUM', 1))
place = fluid.CPUPlace()
test_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()
with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard():
# define input and reader
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
input_field.build(build_pyreader=True)
# define the network
predictions = create_net(
is_training=False, model_input=input_field, args=args)
out_ids, out_scores = predictions
# This is used here to set dropout to the test mode.
test_prog = test_prog.clone(for_test=True)
# prepare predicting
## define the executor and program for training
exe = fluid.Executor(place)
exe.run(startup_prog)
assert (args.init_from_params) or (args.init_from_pretrain_model)
if args.init_from_params:
init_from_params(args, exe, test_prog)
elif args.init_from_pretrain_model:
init_from_pretrain_model(args, exe, test_prog)
# saving inference model
fluid.io.save_inference_model(
args.inference_model_dir,
feeded_var_names=input_field_names,
target_vars=[out_ids, out_scores],
executor=exe,
main_program=test_prog,
model_filename="model.pdmodel",
params_filename="params.pdparams")
print("save inference model at %s" % (args.inference_model_dir))
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
do_save_inference_model(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import numpy as np
import paddle
import paddle.fluid as fluid
#include palm for easier nlp coding
from palm.toolkit.configure import PDConfig
from train import do_train
from predict import do_predict
from inference_model import do_save_inference_model
if __name__ == "__main__":
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(
stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
logging.getLogger().setLevel(logging.INFO)
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
if args.do_train:
do_train(args)
if args.do_predict:
do_predict(args)
if args.do_save_inference_model:
do_save_inference_model(args)
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import desc
import reader
from transformer import create_net, position_encoding_init
def init_from_pretrain_model(args, exe, program):
assert isinstance(args.init_from_pretrain_model, str)
if not os.path.exists(args.init_from_pretrain_model):
raise Warning("The pretrained params do not exist.")
return False
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(
os.path.join(args.init_from_pretrain_model, var.name))
fluid.io.load_vars(
exe,
args.init_from_pretrain_model,
main_program=program,
predicate=existed_params)
print("finish initing model from pretrained params from %s" %
(args.init_from_pretrain_model))
return True
def init_from_params(args, exe, program):
assert isinstance(args.init_from_params, str)
if not os.path.exists(args.init_from_params):
raise Warning("the params path does not exist.")
return False
fluid.io.load_params(
executor=exe,
dirname=args.init_from_params,
main_program=program,
filename="params.pdparams")
print("finish init model from params from %s" % (args.init_from_params))
return True
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def do_predict(args):
if args.use_cuda:
dev_count = fluid.core.get_cuda_device_count()
place = fluid.CUDAPlace(0)
else:
dev_count = int(os.environ.get('CPU_NUM', 1))
place = fluid.CPUPlace()
# define the data generator
processor = reader.DataProcessor(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
device_count=dev_count,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="predict", place=place)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
trg_idx2word = reader.DataProcessor.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
test_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()
with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard():
# define input and reader
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
input_field.build(build_pyreader=True)
# define the network
predictions = create_net(
is_training=False, model_input=input_field, args=args)
out_ids, out_scores = predictions
# This is used here to set dropout to the test mode.
test_prog = test_prog.clone(for_test=True)
# prepare predicting
## define the executor and program for training
exe = fluid.Executor(place)
exe.run(startup_prog)
assert (args.init_from_params) or (args.init_from_pretrain_model)
if args.init_from_params:
init_from_params(args, exe, test_prog)
elif args.init_from_pretrain_model:
init_from_pretrain_model(args, exe, test_prog)
# to avoid a longer length than training, reset the size of position encoding to max_length
for pos_enc_param_name in desc.pos_enc_param_names:
pos_enc_param = fluid.global_scope().find_var(
pos_enc_param_name).get_tensor()
pos_enc_param.set(
position_encoding_init(args.max_length + 1, args.d_model), place)
exe_strategy = fluid.ExecutionStrategy()
# to clear tensor array after each iteration
exe_strategy.num_iteration_per_drop_scope = 1
compiled_test_prog = fluid.CompiledProgram(test_prog).with_data_parallel(
exec_strategy=exe_strategy, places=place)
f = open(args.output_file, "wb")
# start predicting
## decorate the pyreader with batch_generator
input_field.loader.set_batch_generator(batch_generator)
input_field.loader.start()
while True:
try:
seq_ids, seq_scores = exe.run(
compiled_test_prog,
fetch_list=[out_ids.name, out_scores.name],
return_numpy=False)
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)]
scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)]
for i in range(len(seq_ids.lod()[0]) -
1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(b" ".join([
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end], args.bos_idx,
args.eos_idx)
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
f.write(hyps[i][-1] + b"\n")
if len(hyps[i]) >= args.n_best:
break
except fluid.core.EOFException:
break
f.close()
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_predict(args)
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import utils.dist_utils as dist_utils
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import desc
import reader
from transformer import create_net, position_encoding_init
if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
# num_trainers is used for multi-process gpu training
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
def init_from_pretrain_model(args, exe, program):
assert isinstance(args.init_from_pretrain_model, str)
if not os.path.exists(args.init_from_pretrain_model):
raise Warning("The pretrained params do not exist.")
return False
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(
os.path.join(args.init_from_pretrain_model, var.name))
fluid.io.load_vars(
exe,
args.init_from_pretrain_model,
main_program=program,
predicate=existed_params)
print("finish initing model from pretrained params from %s" %
(args.init_from_pretrain_model))
return True
def init_from_checkpoint(args, exe, program):
assert isinstance(args.init_from_checkpoint, str)
if not os.path.exists(args.init_from_checkpoint):
raise Warning("the checkpoint path does not exist.")
return False
fluid.io.load_persistables(
executor=exe,
dirname=args.init_from_checkpoint,
main_program=program,
filename="checkpoint.pdckpt")
print("finish initing model from checkpoint from %s" %
(args.init_from_checkpoint))
return True
def save_checkpoint(args, exe, program, dirname):
assert isinstance(args.save_model_path, str)
checkpoint_dir = os.path.join(args.save_model_path, args.save_checkpoint)
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)
fluid.io.save_persistables(
exe,
os.path.join(checkpoint_dir, dirname),
main_program=program,
filename="checkpoint.pdparams")
print("save checkpoint at %s" % (os.path.join(checkpoint_dir, dirname)))
return True
def save_param(args, exe, program, dirname):
assert isinstance(args.save_model_path, str)
param_dir = os.path.join(args.save_model_path, args.save_param)
if not os.path.exists(param_dir):
os.mkdir(param_dir)
fluid.io.save_params(
exe,
os.path.join(param_dir, dirname),
main_program=program,
filename="params.pdparams")
print("save parameters at %s" % (os.path.join(param_dir, dirname)))
return True
def do_train(args):
if args.use_cuda:
if num_trainers > 1: # for multi-process gpu training
dev_count = 1
else:
dev_count = fluid.core.get_cuda_device_count()
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
else:
dev_count = int(os.environ.get('CPU_NUM', 1))
place = fluid.CPUPlace()
# define the data generator
processor = reader.DataProcessor(
fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=dev_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="train")
if num_trainers > 1: # for multi-process gpu training
batch_generator = fluid.contrib.reader.distributed_batch_reader(
batch_generator)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
train_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()
random_seed = eval(str(args.random_seed))
if random_seed is not None:
train_prog.random_seed = random_seed
startup_prog.random_seed = random_seed
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
# define input and reader
input_field_names = desc.encoder_data_input_fields + \
desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
input_field.build(build_pyreader=True)
# define the network
sum_cost, avg_cost, token_num = create_net(
is_training=True, model_input=input_field, args=args)
# define the optimizer
with fluid.default_main_program()._lr_schedule_guard():
learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
args.d_model, args.warmup_steps) * args.learning_rate
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps))
optimizer.minimize(avg_cost)
# prepare training
## decorate the pyreader with batch_generator
input_field.loader.set_batch_generator(batch_generator)
## define the executor and program for training
exe = fluid.Executor(place)
exe.run(startup_prog)
# init position_encoding
for pos_enc_param_name in desc.pos_enc_param_names:
pos_enc_param = fluid.global_scope().find_var(
pos_enc_param_name).get_tensor()
pos_enc_param.set(
position_encoding_init(args.max_length + 1, args.d_model), place)
assert (args.init_from_checkpoint == "") or (
args.init_from_pretrain_model == "")
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
init_from_checkpoint(args, exe, train_prog)
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
init_from_pretrain_model(args, exe, train_prog)
build_strategy = fluid.compiler.BuildStrategy()
build_strategy.enable_inplace = True
exec_strategy = fluid.ExecutionStrategy()
if num_trainers > 1:
dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
exec_strategy.num_threads = 1
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# start training
step_idx = 0
for pass_id in range(args.epoch):
pass_start_time = time.time()
input_field.loader.start()
batch_id = 0
while True:
try:
outs = exe.run(compiled_train_prog,
fetch_list=[sum_cost.name, token_num.name]
if step_idx % args.print_step == 0 else [])
if step_idx % args.print_step == 0:
sum_cost_val, token_num_val = np.array(outs[0]), np.array(
outs[1])
# sum the cost from multi-devices
total_sum_cost = sum_cost_val.sum()
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0:
if args.save_checkpoint:
save_checkpoint(args, exe, train_prog,
"step_" + str(step_idx))
if args.save_param:
save_param(args, exe, train_prog,
"step_" + str(step_idx))
batch_id += 1
step_idx += 1
except fluid.core.EOFException:
input_field.loader.reset()
break
time_consumed = time.time() - pass_start_time
if args.save_checkpoint:
save_checkpoint(args, exe, train_prog, "step_final")
if args.save_param:
save_param(args, exe, train_prog, "step_final")
if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_train(args)
此差异已折叠。
# used for continuous evaluation
enable_ce: False
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# path of trained parameter, to make prediction
init_from_params: "trained_params/step_100000"
save_model_path: ""
# the directory for saving checkpoints.
save_checkpoint: "trained_ckpts"
# the directory for saving trained parameters.
save_param: "trained_params"
# the directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match test data files.
predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# whether to use cuda
use_cuda: True
# args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "pool"
shuffle: True
shuffle_batch: True
batch_size: 4096
# Hyparams for training:
# the number of epoches for training
epoch: 30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate: 2.0
beta1: 0.9
beta2: 0.997
eps: 1e-9
# the parameters for learning rate scheduling.
warmup_steps: 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps: 0.1
# Hyparams for generation:
# the parameters for beam search.
beam_size: 5
max_out_len: 256
# the number of decoded sentences to output.
n_best: 1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size: 10000
# size of target word dictionay
trg_vocab_size: 10000
# index for <bos> token
bos_idx: 0
# index for <eos> token
eos_idx: 1
# index for <unk> token
unk_idx: 2
# max length of sequences deciding the size of position encoding table.
max_length: 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# the dimension that keys are projected to for dot-product attention.
d_key: 64
# the dimension that values are projected to for dot-product attention.
d_value: 64
# number of head used in multi-head attention.
n_head: 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer: 6
# dropout rates of different modules.
prepostprocess_dropout: 0.1
attention_dropout: 0.1
relu_dropout: 0.1
# to process before each sub-layer
preprocess_cmd: "n" # layer normalization
# to process after each sub-layer
postprocess_cmd: "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing: True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import argparse
import json
import yaml
import six
import logging
logging_only_message = "%(message)s"
logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class JsonConfig(object):
"""
A high-level api for handling json configure file.
"""
def __init__(self, config_path):
self._config_dict = self._parse(config_path)
def _parse(self, config_path):
try:
with open(config_path) as json_file:
config_dict = json.load(json_file)
except:
raise IOError("Error in parsing bert model config file '%s'" %
config_path)
else:
return config_dict
def __getitem__(self, key):
return self._config_dict[key]
def print_config(self):
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
"""
A high-level api for handling argument configs.
"""
def __init__(self):
parser = argparse.ArgumentParser()
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5,
"Learning rate used to train with warmup.")
train_g.add_arg(
"lr_scheduler",
str,
"linear_warmup_decay",
"scheduler of learning rate.",
choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01,
"Weight decay rate for L2 regularizer.")
train_g.add_arg(
"warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for."
)
train_g.add_arg("save_steps", int, 1000,
"The steps interval to save checkpoints.")
train_g.add_arg("use_fp16", bool, False,
"Whether to use fp16 mixed precision training.")
train_g.add_arg(
"loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
)
train_g.add_arg("pred_dir", str, None,
"Path to save the prediction results")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10,
"The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True,
"If set, use GPU for training.")
run_type_g.add_arg(
"use_fast_executor", bool, False,
"If set, use fast parallel executor (in experiment).")
run_type_g.add_arg(
"num_iteration_per_drop_scope", int, 1,
"Ihe iteration intervals to clean up temporary variables.")
run_type_g.add_arg("do_train", bool, True,
"Whether to perform training.")
run_type_g.add_arg("do_predict", bool, True,
"Whether to perform prediction.")
custom_g = ArgumentGroup(parser, "customize", "customized options.")
self.custom_g = custom_g
self.parser = parser
def add_arg(self, name, dtype, default, descrip):
self.custom_g.add_arg(name, dtype, default, descrip)
def build_conf(self):
return self.parser.parse_args()
def str2bool(v):
# because argparse does not support to parse "true, False" as python
# boolean directly
return v.lower() in ("true", "t", "1")
def print_arguments(args, log=None):
if not log:
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
else:
log.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
log.info('%s: %s' % (arg, value))
log.info('------------------------------------------------')
class PDConfig(object):
"""
A high-level API for managing configuration files in PaddlePaddle.
Can jointly work with command-line-arugment, json files and yaml files.
"""
def __init__(self, json_file="", yaml_file="", fuse_args=True):
"""
Init funciton for PDConfig.
json_file: the path to the json configure file.
yaml_file: the path to the yaml configure file.
fuse_args: if fuse the json/yaml configs with argparse.
"""
assert isinstance(json_file, str)
assert isinstance(yaml_file, str)
if json_file != "" and yaml_file != "":
raise Warning(
"json_file and yaml_file can not co-exist for now. please only use one configure file type."
)
return
self.args = None
self.arg_config = {}
self.json_config = {}
self.yaml_config = {}
parser = argparse.ArgumentParser()
self.default_g = ArgumentGroup(parser, "default", "default options.")
self.yaml_g = ArgumentGroup(parser, "yaml", "options from yaml.")
self.json_g = ArgumentGroup(parser, "json", "options from json.")
self.com_g = ArgumentGroup(parser, "custom", "customized options.")
self.default_g.add_arg("do_train", bool, False,
"Whether to perform training.")
self.default_g.add_arg("do_predict", bool, False,
"Whether to perform predicting.")
self.default_g.add_arg("do_eval", bool, False,
"Whether to perform evaluating.")
self.default_g.add_arg("do_save_inference_model", bool, False,
"Whether to perform model saving for inference.")
self.parser = parser
if json_file != "":
self.load_json(json_file, fuse_args=fuse_args)
if yaml_file:
self.load_yaml(yaml_file, fuse_args=fuse_args)
def load_json(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the json file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.json_config = json.loads(fin.read())
fin.close()
if fuse_args:
for name in self.json_config:
if isinstance(self.json_config[name], list):
self.json_g.add_arg(
name,
type(self.json_config[name][0]),
self.json_config[name],
"This is from %s" % file_path,
nargs=len(self.json_config[name]))
continue
if not isinstance(self.json_config[name], int) \
and not isinstance(self.json_config[name], float) \
and not isinstance(self.json_config[name], str) \
and not isinstance(self.json_config[name], bool):
continue
self.json_g.add_arg(name,
type(self.json_config[name]),
self.json_config[name],
"This is from %s" % file_path)
def load_yaml(self, file_path, fuse_args=True):
if not os.path.exists(file_path):
raise Warning("the yaml file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
fin.close()
if fuse_args:
for name in self.yaml_config:
if isinstance(self.yaml_config[name], list):
self.yaml_g.add_arg(
name,
type(self.yaml_config[name][0]),
self.yaml_config[name],
"This is from %s" % file_path,
nargs=len(self.yaml_config[name]))
continue
if not isinstance(self.yaml_config[name], int) \
and not isinstance(self.yaml_config[name], float) \
and not isinstance(self.yaml_config[name], str) \
and not isinstance(self.yaml_config[name], bool):
continue
self.yaml_g.add_arg(name,
type(self.yaml_config[name]),
self.yaml_config[name],
"This is from %s" % file_path)
def build(self):
self.args = self.parser.parse_args()
self.arg_config = vars(self.args)
def __add__(self, new_arg):
assert isinstance(new_arg, list) or isinstance(new_arg, tuple)
assert len(new_arg) >= 3
assert self.args is None
name = new_arg[0]
dtype = new_arg[1]
dvalue = new_arg[2]
desc = new_arg[3] if len(
new_arg) == 4 else "Description is not provided."
self.com_g.add_arg(name, dtype, dvalue, desc)
return self
def __getattr__(self, name):
if name in self.arg_config:
return self.arg_config[name]
if name in self.json_config:
return self.json_config[name]
if name in self.yaml_config:
return self.yaml_config[name]
raise Warning("The argument %s is not defined." % name)
def Print(self):
print("-" * 70)
for name in self.arg_config:
print("%s:\t\t\t\t%s" % (str(name), str(self.arg_config[name])))
for name in self.json_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.json_config[name])))
for name in self.yaml_config:
if name not in self.arg_config:
print("%s:\t\t\t\t%s" %
(str(name), str(self.yaml_config[name])))
print("-" * 70)
if __name__ == "__main__":
"""
pd_config = PDConfig(json_file = "./test/bert_config.json")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
"""
pd_config = PDConfig(yaml_file="./test/bert_config.yaml")
pd_config += ("my_age", int, 18, "I am forever 18.")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
print(pd_config.my_age)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(trainer_id, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
startup_program=startup_prog,
program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
# prepare for multi-process
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return
print("PADDLE_TRAINERS_NUM", num_trainers)
print("PADDLE_TRAINER_ID", trainer_id)
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model,
# and each process use one GPU card.
startup_prog = fluid.Program()
nccl2_prepare(trainer_id, startup_prog, train_prog)
# the startup_prog are run two times, but it doesn't matter.
exe.run(startup_prog)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
from __future__ import print_function
import os
import six
import ast
import copy
import numpy as np
import paddle.fluid as fluid
class Placeholder(object):
def __init__(self):
self.shapes = []
self.dtypes = []
self.lod_levels = []
self.names = []
def __init__(self, input_shapes):
self.shapes = []
self.dtypes = []
self.lod_levels = []
self.names = []
for new_holder in input_shapes:
shape = new_holder[0]
dtype = new_holder[1]
lod_level = new_holder[2] if len(new_holder) >= 3 else 0
name = new_holder[3] if len(new_holder) >= 4 else ""
self.append_placeholder(
shape, dtype, lod_level=lod_level, name=name)
def append_placeholder(self, shape, dtype, lod_level=0, name=""):
self.shapes.append(shape)
self.dtypes.append(dtype)
self.lod_levels.append(lod_level)
self.names.append(name)
def build(self, capacity, reader_name, use_double_buffer=False):
pyreader = fluid.layers.py_reader(
capacity=capacity,
shapes=self.shapes,
dtypes=self.dtypes,
lod_levels=self.lod_levels,
name=reader_name,
use_double_buffer=use_double_buffer)
return [pyreader, fluid.layers.read_file(pyreader)]
def __add__(self, new_holder):
assert isinstance(new_holder, tuple) or isinstance(new_holder, list)
assert len(new_holder) >= 2
shape = new_holder[0]
dtype = new_holder[1]
lod_level = new_holder[2] if len(new_holder) >= 3 else 0
name = new_holder[3] if len(new_holder) >= 4 else ""
self.append_placeholder(shape, dtype, lod_level=lod_level, name=name)
class InputField(object):
"""
A high-level API for handling inputs in PaddlePaddle.
"""
def __init__(self, input_slots=[]):
self.shapes = []
self.dtypes = []
self.names = []
self.lod_levels = []
self.input_slots = {}
self.feed_list_str = []
self.feed_list = []
self.loader = None
if input_slots:
for input_slot in input_slots:
self += input_slot
def __add__(self, input_slot):
if isinstance(input_slot, list) or isinstance(input_slot, tuple):
name = input_slot[0]
shape = input_slot[1]
dtype = input_slot[2]
lod_level = input_slot[3] if len(input_slot) == 4 else 0
if isinstance(input_slot, dict):
name = input_slot["name"]
shape = input_slot["shape"]
dtype = input_slot["dtype"]
lod_level = input_slot[
"lod_level"] if "lod_level" in input_slot else 0
self.shapes.append(shape)
self.dtypes.append(dtype)
self.names.append(name)
self.lod_levels.append(lod_level)
self.feed_list_str.append(name)
return self
def __getattr__(self, name):
if name not in self.input_slots:
raise Warning("the attr %s has not been defined yet." % name)
return None
return self.input_slots[name]
def build(self, build_pyreader=False, capacity=100, iterable=False):
for _name, _shape, _dtype, _lod_level in zip(
self.names, self.shapes, self.dtypes, self.lod_levels):
self.input_slots[_name] = fluid.data(
name=_name, shape=_shape, dtype=_dtype, lod_level=_lod_level)
for name in self.feed_list_str:
self.feed_list.append(self.input_slots[name])
self.loader = fluid.io.DataLoader.from_generator(
feed_list=self.feed_list,
capacity=capacity,
iterable=(not build_pyreader),
use_double_buffer=True)
if __name__ == "__main__":
mnist_input_slots = [{
"name": "image",
"shape": (-1, 32, 32, 1),
"dtype": "int32"
}, {
"name": "label",
"shape": [-1, 1],
"dtype": "int64"
}]
input_field = InputField(mnist_input_slots)
input_field += {
"name": "large_image",
"shape": (-1, 64, 64, 1),
"dtype": "int32"
}
input_field += {
"name": "large_color_image",
"shape": (-1, 64, 64, 3),
"dtype": "int32"
}
input_field.build()
print(input_field.feed_list)
print(input_field.image)
print(input_field.large_color_image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册