Unverified commit 30b892e3, authored by Meiyim, committed by GitHub

Merge pull request #335 from Meiyim/dev

Dev
......@@ -3,21 +3,25 @@ English | [简体中文](./README.zh.md)
## ERNIE 2.0: A Continual Pre-training Framework for Language Understanding
* [Pre-training Tasks](#pre-training-tasks)
* [Word-aware Tasks](#word-aware-tasks)
* [Knowledge Masking Task](#knowledge-masking-task)
* [Capitalization Prediction Task](#capitalization-prediction-task)
* [Token-Document Relation Prediction Task](#token-document-relation-prediction-task)
* [Structure-aware Tasks](#structure-aware-tasks)
* [Sentence Reordering Task](#sentence-reordering-task)
* [Sentence Distance Task](#sentence-distance-task)
* [Semantic-aware Tasks](#semantic-aware-tasks)
* [Discourse Relation Task](#discourse-relation-task)
* [IR Relevance Task](#ir-relevance-task)
* [ERNIE 1.0: <strong>E</strong>nhanced <strong>R</strong>epresentation through k<strong>N</strong>owledge <strong>I</strong>nt<strong>E</strong>gration](#ernie-10-enhanced-representation-through-knowledge-integration)
* [Compare the ERNIE 1.0 and ERNIE 2.0](#compare-the-ernie-10-and-ernie-20)
* [Results](#results)
* [Results on English Datasets](#results-on-english-datasets)
* [Results on Chinese Datasets](#results-on-chinese-datasets)
* [Release Notes](#release-notes)
* [Communication](#communication)
* [Usage](#usage)
![ernie2.0_paper](.metas/ernie2.0_paper.png)
......@@ -109,21 +113,6 @@ Integrating both phrase information and named entity information enables the mod
| **Structure-aware** | | ✅ Sentence Reordering | ✅ Sentence Reordering <br> ✅ Sentence Distance |
| **Semantic-aware** | ✅ Next Sentence Prediction | ✅ Discourse Relation | ✅ Discourse Relation <br> ✅ IR Relevance |
## Results
......@@ -626,6 +615,21 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u
BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
```
## Release Notes
- Aug 21, 2019: features update: fp16 fine-tuning, multiprocess fine-tuning.
- July 30, 2019: release ERNIE 2.0
- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
- Mar 18, 2019: update ERNIE_stable.tgz
- Mar 15, 2019: release ERNIE 1.0
## Communication
- [Github Issues](https://github.com/PaddlePaddle/ERNIE/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 760439550 (ERNIE discussion group).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Usage
* [Install PaddlePaddle](#install-paddlepaddle)
......@@ -645,7 +649,8 @@ BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equiv
* [Machine Reading Comprehension](#machine-reading-comprehension)
* [Pre-training with ERNIE 1.0](#pre-training-with-ernie-10)
* [Data Preprocessing](#data-preprocessing)
* [Pretrain ERNIE1.0](#pretrain-ernie10)
* [Distillation](#distillation)
* [FAQ](#faq)
* [FAQ1: How to get sentence/tokens embedding of ERNIE?](#faq1-how-to-get-sentencetokens-embedding-of-ernie)
* [FAQ2: How to predict on new data with Fine-tuning model?](#faq2-how-to-predict-on-new-data-with-fine-tuning-model)
......@@ -654,7 +659,7 @@ BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equiv
* [FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.](#faq5-can-not-find-library-libncclso-please-try-to-add-the-lib-path-to-ld_library_path)
### Install PaddlePaddle
This code base has been tested with Paddle Fluid 1.5.1 under Python 2.
......@@ -671,11 +676,15 @@ If you have been armed with certain level of deep learning knowledge, and it hap
For more information about PaddlePaddle, please refer to the [PaddlePaddle Github](https://github.com/PaddlePaddle/Paddle) or the [Official Website](https://www.paddlepaddle.org.cn/) for details.
Other dependencies of ERNIE are listed in `requirements.txt`; you can install them with
```script
pip install -r requirements.txt
```
### Pre-trained Models & Datasets
#### Models
| Model | Description |
| :------------------------------------------------- | :----------------------------------------------------------- |
......@@ -685,23 +694,23 @@ For more information about paddlepadde, Please refer to [PaddlePaddle Github](ht
| [ERNIE 2.0 Base for English](https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz) | with params, config and vocabs |
| [ERNIE 2.0 Large for English](https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz) | with params, config and vocabs |
#### Datasets
##### English Datasets
Download the [GLUE data](https://gluebenchmark.com/tasks) by running [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) and unpack it to some directory `${TASK_DATA_PATH}`.
After the dataset is downloaded, run `sh ./script/en_glue/preprocess/cvt.sh $TASK_DATA_PATH` to convert the data format for training. If everything goes well, a folder named `glue_data_processed` will be created with all the converted data in it.
##### Chinese Datasets
You can download the Chinese datasets from [here](https://ernie.bj.bcebos.com/task_data_zh.tgz)
#### Fine-tuning
##### Batchsize and GPU Settings
In our experiments we found that the batch size matters for different tasks. To help users reproduce results more easily, we list the batch size and GPU cards here:
......@@ -728,7 +737,7 @@ In our experiments, we found that the batch size is important for different task
\* *For MNLI and QNLI we used 32GB V100 cards; for the other tasks we used 22GB P40 cards.*
#### Multiprocessing and fp16 auto mix-precision finetune
Multiprocessing fine-tuning can be enabled simply with `finetune_launch.py` in your fine-tune script.
With multiprocessing fine-tuning, Paddle can fully utilize your CPU/GPU capacity to accelerate fine-tuning.
......@@ -738,9 +747,9 @@ fp16 finetuning can be simply enable by specifing `--use_fp16 true` in your trai
Dynamic loss scaling is used to avoid vanishing gradients.
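For illustration, a minimal sketch of such a launch; everything other than `--use_fp16` is elided here, and the task shell scripts in this repository show real invocations:

```script
# hedged sketch: launch run_classifier.py through the multiprocess launcher with fp16 on
python finetune_launch.py \
    run_classifier.py \
        --use_fp16 true \
        ... # the rest of your usual run_classifier.py arguments
```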
#### Classification
##### Single Sentence Classification Tasks
The code used to perform classification/regression fine-tuning is in `run_classifier.py`; we also provide shell scripts for each task, including the best hyperparameters.
......@@ -798,7 +807,7 @@ Similarly, for the Chinese task `ChnSentCorp`, after setting the environment var
##### Sentence Pair Classification Tasks
Take `RTE` as an example: the data should have 3 fields, `text_a text_b label`, in tsv format. Here are some example records:
```
......@@ -834,9 +843,9 @@ testing ./data/test.tsv, save to output/test_out.5.2019-07-23-15-25-06.tsv.4.781
#### Sequence Labeling
##### Named Entity Recognition
Take `MSRA-NER(SIGHAN2006)` as an example: the data should have 2 fields, `text_a label`, in tsv format. Here are some example records:
```
......@@ -853,7 +862,7 @@ Also, remember to set environmental variables like above, and run `sh script/zh_
[test evaluation] f1: 0.937390, precision: 0.925988, recall: 0.949077, elapsed time: 36.565929 s
```
#### Machine Reading Comprehension
Take `DRCD` as an example; first, convert the data into SQuAD format:
......@@ -896,9 +905,9 @@ Also, remember to set environmental variables like above, and run `sh script/zh_
```
### Pre-training with ERNIE 1.0
#### Data Preprocessing
We construct the training dataset based on [Baidu Baike](https://en.wikipedia.org/wiki/Baidu_Baike), [Baidu Knows(Baidu Zhidao)](https://en.wikipedia.org/wiki/Baidu_Knows) and [Baidu Tieba](https://en.wikipedia.org/wiki/Baidu_Tieba) for the Chinese version of ERNIE, and [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia:Database_download), [Reddit](https://en.wikipedia.org/wiki/Reddit) and [BookCorpus](https://github.com/soskek/bookcorpus) for the English version.
......@@ -912,7 +921,7 @@ Here are some train instances after processing (which can be found in [`data/dem
Each instance is composed of 5 fields joined by `;` on one line, representing `token_ids; sentence_type_ids; position_ids; seg_labels; next_sentence_label` respectively. In particular, in the field `seg_labels`, 0 marks the beginning of a word, 1 marks a non-beginning position within a word, -1 is a placeholder, and the other numbers mark `CLS` or `SEP`.
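For illustration, a minimal sketch of parsing one such line (the example line is made up; it only assumes each field is a space-separated list of integers):

```python
# hypothetical instance in the 5-field format described above
line = "1 2 3 4;0 0 0 0;0 1 2 3;-1 0 0 -1;0"
fields = [[int(x) for x in f.split()] for f in line.strip().split(';')]
token_ids, sentence_type_ids, position_ids, seg_labels, next_sentence_label = fields
```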
#### Pretrain ERNIE 1.0
The entry point for pre-training is [`script/zh_task/pretrain.sh`](./script/zh_task/pretrain.sh). Before running the training program, remember to add the CUDA, cuDNN and NCCL2 library paths to the environment variable LD_LIBRARY_PATH.
......@@ -932,10 +941,15 @@ epoch: 1, progress: 1/1, step: 50, loss: 10.360563, ppl: 16398.287109, next_sent
```
### Distillation
ERNIE provides a toolkit for data distillation to further accelerate your inference; see <a href="./distill/README.md">here</a> for details.
### FAQ
#### FAQ1: How to get sentence/tokens embedding of ERNIE?
Run `ernie_encoder.py` to obtain both the sentence embeddings and the token embeddings. The input data format should be the same as that mentioned in the chapter [Fine-tuning](#fine-tuning).
......@@ -960,7 +974,7 @@ when finished running this script, `cls_emb.npy` and `top_layer_emb.npy `will b
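To inspect the exported embeddings, the files can be loaded with numpy; the array shapes depend on your input, so this is just a sketch:

```python
import numpy as np

cls_emb = np.load('cls_emb.npy')              # one sentence embedding per example
top_layer_emb = np.load('top_layer_emb.npy')  # token-level embeddings from the top layer
print(cls_emb.shape, top_layer_emb.shape)
```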
#### FAQ2: How to predict on new data with Fine-tuning model?
Take classification tasks as an example; here is the script for batch prediction:
......@@ -984,18 +998,18 @@ Argument `init_checkpoint` is the path of the model, `predict_set` is the path
#### FAQ3: Is the argument batch_size for one GPU card or for all GPU cards?
For one GPU card.
#### FAQ4: Can not find library: libcudnn.so. Please try to add the lib path to LD_LIBRARY_PATH.
Export the path of cuDNN to LD_LIBRARY_PATH, e.g.: `export LD_LIBRARY_PATH=/home/work/cudnn/cudnn_v[your cudnn version]/cuda/lib64`
#### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
Download [NCCL2](https://developer.nvidia.com/nccl/nccl-download), and export the library path to LD_LIBRARY_PATH, e.g.: `export LD_LIBRARY_PATH=/home/work/nccl/lib`
......@@ -3,21 +3,25 @@
## ERNIE 2.0: A Continual Pre-training Framework for Language Understanding
* [Pre-Training Tasks](#pre-training-任务)
  * [Word-aware Tasks](#word-aware-tasks)
    * [Knowledge Masking Task](#knowledge-masking-task)
    * [Capitalization Prediction Task](#capitalization-prediction-task)
    * [Token-Document Relation Prediction Task](#token-document-relation-prediction-task)
  * [Structure-aware Tasks](#structure-aware-tasks)
    * [Sentence Reordering Task](#sentence-reordering-task)
    * [Sentence Distance Task](#sentence-distance-task)
  * [Semantic-aware Tasks](#semantic-aware-tasks)
    * [Discourse Relation Task](#discourse-relation-task)
    * [IR Relevance Task](#ir-relevance-task)
* [ERNIE 1.0: <strong>E</strong>nhanced <strong>R</strong>epresentation through k<strong>N</strong>owledge <strong>I</strong>nt<strong>E</strong>gration](#ernie-10-enhanced-representation-through-knowledge-integration)
* [Comparison of ERNIE 1.0 and ERNIE 2.0](#对比-ernie-10-和-ernie-20)
* [Results](#效果验证)
  * [Results on Chinese Datasets](#中文效果验证)
  * [Results on English Datasets](#英文效果验证)
* [Release Notes](#开源记录)
* [Communication](#技术交流)
* [Usage](#使用)
![ernie2.0_paper](.metas/ernie2.0_paper.png)
......@@ -105,26 +109,16 @@
| **Semantic-aware** | ✅ Next Sentence Prediction | ✅ Discourse Relation | ✅ Discourse Relation <br> ✅ IR Relevance |
## Results
### Results on Chinese Datasets
We evaluate the ERNIE 2.0 Chinese model on 9 tasks: the natural language inference task XNLI; the machine reading comprehension tasks DRCD, DuReader and CMRC2018; the named entity recognition task MSRA-NER (SIGHAN2006); the sentiment analysis task ChnSentiCorp; the semantic similarity tasks BQ Corpus and LCQMC; and the question answering task NLPCC2016-DBQA. The details and results of each task are described in the following sections.
#### Natural Language Inference Task
<table>
<tbody>
......@@ -189,7 +183,7 @@
XNLI is a natural language inference dataset jointly built by researchers from Facebook and New York University, covering data in 15 languages. We use its Chinese portion to evaluate the model's language understanding ability. [url: https://github.com/facebookresearch/XNLI]
```
#### Machine Reading Comprehension Task
<table>
<tbody>
......@@ -318,9 +312,7 @@ CMRC2018 是中文信息学会举办的评测,评测的任务是抽取类阅
DRCD is a traditional-Chinese reading comprehension dataset released by Delta Research Center; the goal is to extract a contiguous span from the passage as the answer. We converted it to simplified Chinese for our experiments. [url: https://github.com/DRCKnowledgeTeam/DRCD]
```
#### Named Entity Recognition Task
<table>
<tbody>
......@@ -377,9 +369,7 @@ DRCD 是台达研究院发布的繁体中文阅读理解数据集,目标是从
The MSRA-NER (SIGHAN2006) dataset was released by Microsoft Research Asia; the goal is to recognize entities with specific meanings in text, including person names, place names and organization names.
```
#### Sentiment Analysis Task
<table>
<tbody>
......@@ -436,9 +426,7 @@ MSRA-NER (SIGHAN2006) 数据集由微软亚研院发布,其目标是识别文
ChnSentiCorp is a Chinese sentiment analysis dataset containing online shopping reviews of hotels, laptops and books.
```
#### Question Answering Task
<table>
<tbody>
......@@ -512,9 +500,7 @@ ChnSentiCorp 是一个中文情感分析数据集,包含酒店、笔记本电
NLPCC2016-DBQA is a shared task organized in 2016 by NLPCC, the International Conference on Natural Language Processing and Chinese Computing; the goal is to select suitable documents from candidates as answers to a question. [url: http://tcci.ccf.org.cn/conference/2016/dldoc/evagline2.pdf]
```
#### Semantic Similarity
<table>
<tbody>
......@@ -597,14 +583,14 @@ BQ Corpus 是在自然语言处理国际顶会 EMNLP 2018 发布的语义匹配
### Results on English Datasets
The English results of ERNIE 2.0 are evaluated on GLUE. The official GLUE site is https://gluebenchmark.com/ ; the benchmark covers 10 datasets of different task types (with 11 test sets) and involves 5 kinds of metrics: Accuracy, F1-score, Spearman Corr., Pearson Corr. and Matthew Corr. The GLUE leaderboard uses the average score over the datasets as the overall score and ranks the algorithms accordingly.
#### GLUE - Dev Set Results
| <strong>Dataset</strong> | <strong>CoLA</strong> | <strong>SST-2</strong> | <strong>MRPC</strong> | <strong>STS-B</strong> | <strong>QQP</strong> | <strong>MNLI-m</strong> | <strong>QNLI</strong> | <strong>RTE</strong> |
| ----------- | ---- | ----- | ---- | ----- | ---- | ---- | ---- | ---- |
......@@ -617,7 +603,7 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
#### GLUE - Test Set Results
| <strong>Dataset</strong> | - | <strong>CoLA</strong> | <strong>SST-2</strong> | <strong>MRPC</strong> | <strong>STS-B</strong> | <strong>QQP</strong> | <strong>MNLI-m</strong> | <strong>MNLI-mm</strong> | <strong>QNLI</strong> | <strong>RTE</strong> | <strong>WNLI</strong> |<strong>AX</strong>|
| ----------- | ----- | ---- | ----- | ---- | ----- | ---- | ------ | ------- | ---- | ---- | ---- | ---- |
......@@ -631,6 +617,19 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
Since XLNet has not released single-model results on the GLUE test set, we compare single models against BERT only. The table above shows the single-model results of ERNIE 2.0 on the GLUE test set.
## Release Notes
- 2019-07-30: release ERNIE 2.0
- 2019-04-10: update ERNIE_stable-1.0.1.tar.gz; package and release the model parameters, the config ernie_config.json and vocab.txt
- 2019-03-18: update ERNIE_stable.tgz
- 2019-03-15: release ERNIE 1.0
## Communication
- [Github Issues](https://github.com/PaddlePaddle/ERNIE/issues): bug reports, feature requests, install issues, usage issues, etc.
- ERNIE QQ group: 760439550 (ERNIE discussion group).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Usage
* [Install PaddlePaddle](#paddlepaddle安装)
* [Models &amp; Data](#模型数据)
......@@ -650,6 +649,10 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
* [Pre-training (ERNIE 1.0)](#预训练-ernie-10)
  * [Data Preprocessing](#数据预处理)
  * [Start Training](#开始训练)
* [Distillation](#蒸馏)
* [Deployment](#上线)
  * [Generate inference_model](#生成inference_model)
  * [Online Prediction](#在线预测)
* [FAQ](#faq)
  * [FAQ1: How to get the embeddings of input sentences/tokens encoded by ERNIE?](#faq1-如何获取输入句子词经过-ernie-编码后的-embedding-表示)
  * [FAQ2: How to use a fine-tuned model for batch prediction on new data?](#faq2-如何利用-fine-tuning-得到的模型对新数据进行批量预测)
......@@ -672,6 +675,11 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
> - [Training neural networks](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/howto/training/index_cn.html): how to use Fluid for single-machine training, multi-machine training, and saving and loading model variables
> - [Model evaluation and debugging](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/howto/evaluation_and_debugging/index_cn.html): how to evaluate and debug models under Fluid
Other dependencies of ERNIE are listed in `requirements.txt`; install them with
```script
pip install -r requirements.txt
```
## Models & Data
......@@ -918,6 +926,36 @@ epoch: 1, progress: 1/1, step: 50, loss: 10.360563, ppl: 16398.287109, next_sent
To train on your own real data, adjust the parameters accordingly following the [`script/zh_task/pretrain.sh`](./script/zh_task/pretrain.sh) script.
## Distillation
ERNIE provides a data distillation toolkit for model compression and acceleration; see <a href="./distill/README.md">here</a> for the development workflow.
## Deployment
After fine-tuning, an inference\_model can be generated in just a few steps; PaddlePaddle can load the generated prediction model in a production environment and run inference efficiently.
### Generate inference\_model
Run the `classify_infer.py` or `predict_classifier.py` script with `--save_inference_model_path` specified to generate an inference_model at the given location.
If you fine-tune with `propeller`, `BestInferenceExporter` will pick the best model during fine-tuning according to the prediction metric and export it as an inference_model. For the `propeller`-based fine-tuning workflow, see `propeller_xnli_demo.ipynb`.
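For illustration, a hedged sketch of such a run; everything except `--save_inference_model_path` is elided and should follow your usual fine-tuning setup:

```script
python classify_infer.py \
    --save_inference_model_path ./inference_model \
    ... # the other arguments of your fine-tuning run
```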
### Online Prediction
You can then use the [PaddleInference C++ API](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/deploy/inference/native_infer.html) to link the model's forward prediction code into your production environment. Alternatively, you can use the Python prediction engine we provide to set up a simple service: put the `./propeller` folder of this repository on your `PYTHONPATH` and run the following command to start a propeller server:
```script
python -m propeller.tools.start_server -m /path/to/saved/model -p 8888
```
You can conveniently call the propeller server from a Python script:
```python
from propeller.service.client import InferenceClient
client = InferenceClient('tcp://localhost:8888')  # must match the port passed to start_server
result = client(sentence_id, position_id, token_type_id, input_mask)
```
The request arguments of `client` are numpy arrays, corresponding to the input tensors specified when save_inference_model was called. For an inference_model generated by `classify_infer.py` there are four request arguments: (sentence_id, position_id, token_type_id, input_mask). For an inference_model generated by `propeller`, the request arguments correspond to the element types of your `eval_dataset`.
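For illustration, a sketch of a full request against a `classify_infer.py` inference_model. The shapes and dtypes here are assumptions based on the training code in this repository (int64 id tensors of shape [batch, seq_len, 1] and a float32 mask) and may need adjusting for your model:

```python
import numpy as np
from propeller.service.client import InferenceClient

batch, seq_len = 1, 128
sentence_id = np.zeros([batch, seq_len, 1], dtype=np.int64)    # token ids, 0-padded
position_id = np.tile(np.arange(seq_len, dtype=np.int64).reshape([1, seq_len, 1]), [batch, 1, 1])
token_type_id = np.zeros([batch, seq_len, 1], dtype=np.int64)  # single-sentence input
input_mask = np.zeros([batch, seq_len, 1], dtype=np.float32)   # set 1.0 at real token positions

client = InferenceClient('tcp://localhost:8888')
result = client(sentence_id, position_id, token_type_id, input_mask)
print(result)
```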
## FAQ
### FAQ1: How to get the embeddings of input sentences/tokens encoded by ERNIE?
......@@ -981,3 +1019,4 @@ python -u predict_classifier.py \
### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
Download [NCCL](https://developer.nvidia.com/nccl/nccl-download) first, then add the NCCL library path to LD_LIBRARY_PATH, e.g. `export LD_LIBRARY_PATH=/home/work/nccl/lib`
* [ERNIE Slim Data Distillation](#ernie-slim-数据蒸馏)
  * [Three Steps of ERNIE Data Distillation](#ernie数据蒸馏三步)
  * [Data Augmentation](#数据增强)
* [Tutorial](#使用教程)
  * [Offline Distillation](#离线蒸馏)
  * [Online Distillation](#在线蒸馏)
* [Results](#效果验证)
  * [Case#1 The user provides unlabeled data](#case1)
  * [Case#2 The user provides no unlabeled data](#case2)
* [FAQ](#faq)
# ERNIE Slim Data Distillation
Behind ERNIE's powerful semantic understanding lies the equally powerful compute needed to train and serve a model of this scale. Many industrial applications have strict performance requirements; without effective compression, such a model cannot be deployed in practice.
<img src="http://agroup-bos.cdn.bcebos.com/ae16a29d6a334c74107cebcf56bc2419d385b364" title="ERNIE data distillation overview" width="900">
Therefore, as shown above, we built the **ERNIE Slim data distillation system** on top of [data distillation](https://arxiv.org/pdf/1712.04440.pdf). Its principle is to use data as a bridge to transfer the knowledge of the ERNIE model to a small model, trading a very small loss in accuracy for a speedup of up to a thousand times in prediction.
### Three Steps of ERNIE Data Distillation
- **Step 1**. Fine-tune ERNIE on the labeled input data to obtain the Teacher Model.
- **Step 2**. Use the ERNIE Service to predict on the following unsupervised data:
  1. large-scale unlabeled data provided by the user, which should come from the same source as the labeled data
  2. augmented data generated from the labeled data, using the strategies described in the next section
  3. a mixture of the unlabeled and augmented data in a certain ratio
- **Step 3.** Train the Student Model on the data from Step 2.
### Data Augmentation
We currently adopt three [data augmentation strategies](https://arxiv.org/pdf/1903.12136.pdf), which can be mixed in task-specific ratios (a sketch follows this list). The three strategies are:
1. Noising: replace words of the original sample with the "UNK" token with a certain probability (e.g. 0.1)
2. Same-POS replacement: replace each word of the original sample, with a certain probability (e.g. 0.1), with a random word of the same part of speech from the dataset
3. N-sampling: randomly pick a position in the original sample and cut out a fragment of length m as a new sample, where m is a random value between 0 and the original sample length
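A minimal sketch of the three strategies (the helper names and the `pos2words` table are hypothetical; a real implementation should follow the referenced paper):

```python
import random

def add_noise(tokens, p=0.1):
    # 1. noising: replace each token with the '[UNK]' vocab token with probability p
    return [t if random.random() > p else '[UNK]' for t in tokens]

def same_pos_replace(tokens, pos_tags, pos2words, p=0.1):
    # 2. same-POS replacement; pos2words maps a POS tag to the words of
    #    that POS collected from the dataset (hypothetical lookup table)
    return [random.choice(pos2words[pos]) if random.random() < p else t
            for t, pos in zip(tokens, pos_tags)]

def n_sampling(tokens):
    # 3. n-sampling: cut a random fragment of length m as a new sample
    m = random.randint(1, len(tokens))
    start = random.randint(0, len(tokens) - m)
    return tokens[start:start + m]
```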
# Tutorial
We used the 3 augmentation strategies above to build augmented data for chnsenticorp; the augmented data is 10 times the original training data (96000 lines) and can be downloaded from [here](https://ernie.bj.bcebos.com/distill_data.tar.gz). Put the downloaded `distill` folder into `${TASK_DATA_PATH}` and run the scripts below to start distillation.
### Offline Distillation
Offline distillation means first using the trained ERNIE model to predict labels for the unsupervised data and then letting the student model learn from those labels. Simply run
```script
sh ./distill/script/distill_chnsenticorp.sh
```
to start offline distillation.
The script performs the three steps described above: 1. fine-tune on the task data; 2. load the fine-tuned model and score the augmented data; 3. train the Student model. The script uses hard-label distillation: in step 2 it directly predicts the labels annotated by ERNIE.
The script involves two Python files: `./distill/finetune_chnsenticorp.py` handles fine-tuning and teacher prediction, and `distill/distill_chnsentocorp.py` handles student training. The pre-built augmented data is located at `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug`.
In step 2 of the script, the `--do_predict` flag switches to prediction mode:
```script
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 |python3 -u ./distill/finetune_chnsenticorp.py \
--do_predict \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
...
```
The script reads plain-text input from stdin and writes scores to stdout; the augmented unsupervised training corpus is annotated this way. The final annotations are stored in `prediction_output/part.0`, which has two columns: the first is the plain text and the second the predicted label.
Step 3 starts the training of the student model:
```script
python3 ./distill/distill_chnsentocorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--unsupervise_data_dir ./prediction_output/ \
--max_seqlen 128 \
...
```
The training flow is the same as in step 1: `--data_dir` specifies the supervised data and `--unsupervise_data_dir` the ERNIE-annotated data. The Student model is a simple BOW model defined in `distill/distill_chnsentocorp.py`; to distill a customized model, users only need to rewrite its model part.
If you already have unsupervised data, simply put it into `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug`.
### Online Distillation
In some scenarios the unsupervised data is so large that prediction becomes very time-consuming, or the distributions predicted by ERNIE are too large to store on disk in advance. For such scenarios we propose an **online distillation** scheme: fine-tune with `propeller` using `BestInferenceModelExporter`, and `propeller` will automatically save the model with the best metric in Paddle inference-model format and then start a prediction service. While training, the Student model queries this service in real time to obtain ERNIE's prediction scores. Simply run
```
sh ./distill/script/distill_chnsenticorp_with_propeller_server.sh
```
to complete the whole flow.
The flow has 3 steps: 1. fine-tune the ERNIE model; 2. start a `propeller` service with the best ERNIE model; 3. during student training, query the service for the teacher model's annotations.
This flow involves two Python files: `distill/finetune_chnsenticorp.py` and `distill/distill_chnsentocorp_with_propeller_server.py`. The first one is used in exactly the same way as in offline distillation.
Step 2 uses
```script
python3 -m propeller.tools.start_server -p 8113 -m ${teacher_dir}/best/inference/ &
```
to start an ERNIE prediction service.
Step 3 starts the synchronized training of the student model:
```script
python3 ./distill/distill_chnsentocorp_with_propeller_server.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--teacher_vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--teacher_max_seqlen 128 \
--server_batch_size 64 \
--teacher_host tcp://localhost:8113 \
--num_coroutine 10
```
The script tokenizes the augmented data under the `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug` directory into characters and sends requests to the `propeller` service. `--num_coroutine` specifies the request concurrency, `--teacher_host` the IP and port of the service, and `--server_batch_size` the batch size of the requests; each training batch is split into several chunks of `--server_batch_size` for the actual requests.
# Results
We classify real application scenarios into two cases:
### Case#1 The user provides unlabeled data<a name="case1"></a>
|Model | Low-quality comment detection [Classification \| ACC] | Chinese sentiment [Classification \| ACC] |Question detection [Classification \| ACC]|Search QA matching [Matching \| positive:negative ratio]|
|---|---|---|---|---|
|ERNIE-Finetune | 90.6% | 96.2% | 97.5% | 4.25 |
|Non-ERNIE baseline (BOW)| 80.8% | 94.7% | 93.0% | 1.83 |
|**+ data distillation** | 87.2% | 95.8% | 96.3% | 3.30 |
### Case#2 The user provides no unlabeled data (data generated via augmentation)<a name="case2"></a>
|Model |ChnSentiCorp |
|---|---|
|ERNIE-Finetune |95.4% |
|Non-ERNIE baseline (BOW)|90.1%|
|**+ data distillation** |91.4%|
|Non-ERNIE baseline (LSTM)|91.2%|
|**+ data distillation**|93.9%|
# FAQ
### FAQ1: `Client call failed` error when predicting and distilling at the same time
The error printed in the terminal is the client-side log; the server-side log appears before it. This is usually caused by the server running out of GPU memory. In that case, use `--server_batch_size` in the student fine-tuning script to explicitly control the batch size of the requests sent to the service.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
from random import random
from functools import reduce, partial
import logging
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import log
import propeller.paddle as propeller
from propeller.paddle.data import Dataset
from optimization import optimization
import utils.data
log.setLevel(logging.DEBUG)
class ClassificationBowModel(propeller.train.Model):
"""propeller Model wraper for paddle-ERNIE """
def __init__(self, config, mode, run_config):
self.config = config
self.mode = mode
self.run_config = run_config
self._param_initializer = F.initializer.TruncatedNormal(
scale=config.initializer_range)
self._emb_dtype = "float32"
self._word_emb_name = "word_embedding"
def forward(self, features):
text_ids_a, = features
def bow(ids):
embed = L.embedding(
input=ids,
size=[self.config.vocab_size, self.config.emb_size],
dtype=self._emb_dtype,
param_attr=F.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
zero = L.fill_constant(shape=[1], dtype='int64', value=0)
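# build a 0/1 mask that zeroes out pad positions before sum-pooling (pad id is assumed to be 0)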
pad = L.cast(L.logical_not(L.equal(ids, zero)), 'float32')
sumed = L.reduce_sum(embed * pad, dim=1)
sumed = L.softsign(sumed)
return sumed
sumed = bow(text_ids_a)
fced = L.fc(
input=sumed,
size=self.config.emb_size,
act='tanh',
param_attr=F.ParamAttr(
name="middle_fc.w_0", initializer=self._param_initializer),
bias_attr="middle_fc.b_0")
logits = L.fc(
input=fced,
size=self.config.num_label,
act=None,
param_attr=F.ParamAttr(
name="pooler_fc.w_0", initializer=self._param_initializer),
bias_attr="pooler_fc.b_0")
if self.mode is propeller.RunMode.PREDICT:
probs = L.softmax(logits)
return probs
else:
return logits
def loss(self, predictions, labels):
labels = L.softmax(labels)
loss = L.softmax_with_cross_entropy(predictions, labels, soft_label=True)
loss = L.mean(loss)
return loss
def backward(self, loss):
scheduled_lr, _ = optimization(
loss=loss,
warmup_steps=int(self.run_config.max_steps * self.config.warmup_proportion),
num_train_steps=self.run_config.max_steps,
learning_rate=self.config.learning_rate,
train_program=F.default_main_program(),
startup_prog=F.default_startup_program(),
weight_decay=self.config.weight_decay,
scheduler="linear_warmup_decay",)
propeller.summary.scalar('lr', scheduled_lr)
def metrics(self, predictions, labels):
predictions = L.argmax(predictions, axis=1)
labels = L.argmax(labels, axis=1)
#predictions = L.unsqueeze(predictions, axes=[1])
acc = propeller.metrics.Acc(labels, predictions)
#auc = propeller.metrics.Auc(labels, predictions)
return {'acc': acc}
if __name__ == '__main__':
parser = propeller.ArgumentParser('DAN model with Paddle')
parser.add_argument('--max_seqlen', type=int, default=128)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--unsupervise_data_dir', type=str, required=True)
parser.add_argument('--data_dir', type=str)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb'))}
unk_id = vocab['[UNK]']
char_tokenizer = utils.data.CharTokenizer(vocab.keys())
space_tokenizer = utils.data.SpaceTokenizer(vocab.keys())
supervise_feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
propeller.data.LabelColumn('label'),
])
def before(text_a, label):
sentence_a = text_a[: args.max_seqlen]
return sentence_a, label
def after(sentence_a, label):
batch_size = sentence_a.shape[0]
onehot_label = np.zeros([batch_size, hparams.num_label], dtype=np.float32)
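# put a large logit at the gold class so that L.softmax(labels) in loss() turns it into an (almost) one-hot soft label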
onehot_label[np.arange(batch_size), label] = 9999.
sentence_a, = utils.data.expand_dims(sentence_a)
return sentence_a, onehot_label
train_ds = supervise_feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
unsup_train_ds = supervise_feature_column.build_dataset('unsup_train', data_dir=args.unsupervise_data_dir, shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
dev_ds = supervise_feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
train_ds = utils.data.interleave(train_ds, unsup_train_ds)
shapes = ([-1, args.max_seqlen, 1], [-1, hparams.num_label])
types = ('int64', 'float32')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
'''
from tqdm import tqdm
for slots in tqdm(train_ds):
pass
'''
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
propeller.train.train_and_eval(
model_class_or_model_fn=ClassificationBowModel,
params=hparams,
run_config=run_config,
train_dataset=train_ds,
eval_dataset={'dev': dev_ds},
exporters=[best_exporter])
print('dev_acc3\t%.5f' % (best_exporter._best['dev']['acc']))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
from random import random
from functools import reduce, partial
import logging
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import log
import propeller.paddle as propeller
from propeller.paddle.data import Dataset
from propeller.service.client import InferenceClient
from optimization import optimization
import utils.data
log.setLevel(logging.DEBUG)
class ClassificationBowModel(propeller.train.Model):
"""propeller Model wraper for paddle-ERNIE """
def __init__(self, config, mode, run_config):
self.config = config
self.mode = mode
self.run_config = run_config
self._param_initializer = F.initializer.TruncatedNormal(
scale=config.initializer_range)
self._emb_dtype = "float32"
self._word_emb_name = "word_embedding"
def forward(self, features):
text_ids_a, = features
def bow(ids):
embed = L.embedding(
input=ids,
size=[self.config.vocab_size, self.config.emb_size],
dtype=self._emb_dtype,
param_attr=F.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
zero = L.fill_constant(shape=[1], dtype='int64', value=0)
pad = L.cast(L.logical_not(L.equal(ids, zero)), 'float32')
sumed = L.reduce_sum(embed * pad, dim=1)
sumed = L.softsign(sumed)
return sumed
sumed = bow(text_ids_a)
fced = L.fc(
input=sumed,
size=self.config.emb_size,
act='tanh',
param_attr=F.ParamAttr(
name="middle_fc.w_0", initializer=self._param_initializer),
bias_attr="middle_fc.b_0")
logits = L.fc(
input=fced,
size=self.config.num_label,
act=None,
param_attr=F.ParamAttr(
name="pooler_fc.w_0", initializer=self._param_initializer),
bias_attr="pooler_fc.b_0")
if self.mode is propeller.RunMode.PREDICT:
probs = L.softmax(logits)
return probs
else:
return logits
def loss(self, predictions, labels):
labels = L.softmax(labels)
loss = L.softmax_with_cross_entropy(predictions, labels, soft_label=True)
loss = L.mean(loss)
return loss
def backward(self, loss):
scheduled_lr, _ = optimization(
loss=loss,
warmup_steps=int(self.run_config.max_steps * self.config.warmup_proportion),
num_train_steps=self.run_config.max_steps,
learning_rate=self.config.learning_rate,
train_program=F.default_main_program(),
startup_prog=F.default_startup_program(),
weight_decay=self.config.weight_decay,
scheduler="linear_warmup_decay",)
propeller.summary.scalar('lr', scheduled_lr)
def metrics(self, predictions, labels):
predictions = L.argmax(predictions, axis=1)
labels = L.argmax(labels, axis=1)
#predictions = L.unsqueeze(predictions, axes=[1])
acc = propeller.metrics.Acc(labels, predictions)
#auc = propeller.metrics.Auc(labels, predictions)
return {'acc': acc}
if __name__ == '__main__':
parser = propeller.ArgumentParser('DAN model with Paddle')
parser.add_argument('--max_seqlen', type=int, default=128)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--teacher_vocab_file', type=str, required=True)
parser.add_argument('--teacher_max_seqlen', type=int, default=128)
parser.add_argument('--data_dir', type=str)
parser.add_argument('--server_batch_size', type=int, default=64)
parser.add_argument('--num_coroutine', type=int, default=1)
parser.add_argument('--teacher_host', type=str, required=True)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
teacher_vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.teacher_vocab_file, 'rb'))}
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb'))}
teacher_sep_id = teacher_vocab['[SEP]']
teacher_cls_id = teacher_vocab['[CLS]']
teacher_unk_id = teacher_vocab['[UNK]']
unk_id = vocab['[UNK]']
char_tokenizer = utils.data.CharTokenizer(vocab.keys())
space_tokenizer = utils.data.SpaceTokenizer(vocab.keys())
supervise_feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
propeller.data.LabelColumn('label'),
])
unsupervise_feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
propeller.data.TextColumn('teacher_text_a', unk_id=teacher_unk_id, vocab_dict=teacher_vocab, tokenizer=char_tokenizer),
])
def before(text_a, label):
sentence_a = text_a[: args.max_seqlen]
return sentence_a, label
def after(sentence_a, label):
batch_size = sentence_a.shape[0]
onehot_label = np.zeros([batch_size, hparams.num_label], dtype=np.float32)
onehot_label[np.arange(batch_size), label] = 9999.
sentence_a, = utils.data.expand_dims(sentence_a)
return sentence_a, onehot_label
train_ds = supervise_feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
dev_ds = supervise_feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
def unsuperve_before(text_a, teacher_text_a):
teacher_sentence, teacher_segments = utils.data.build_1_pair(teacher_text_a, max_seqlen=args.teacher_max_seqlen, cls_id=teacher_cls_id, sep_id=teacher_sep_id)
sentence_a = text_a[: args.max_seqlen]
return sentence_a, teacher_sentence, teacher_segments
client = InferenceClient(args.teacher_host, batch_size=args.server_batch_size, num_coroutine=args.num_coroutine)
log.info('teacher host %s' % args.teacher_host)
def ask_teacher_for_label(sentence_a, teacher_sentence, teacher_segments):
sentence_a, teacher_sentence, teacher_segments = utils.data.expand_dims(sentence_a, teacher_sentence, teacher_segments)
teacher_label, = client(teacher_sentence, teacher_segments)
teacher_label = teacher_label[:, :]
return sentence_a, teacher_label
unsup_train_ds = unsupervise_feature_column.build_dataset('unsup_train', data_dir=os.path.join(args.data_dir, 'unsup_train_aug'), shuffle=True, repeat=True, use_gz=False) \
.buffered(100) \
.map(unsuperve_before) \
.padded_batch(hparams.batch_size, (0, 0, 0)) \
.map(ask_teacher_for_label)
train_ds = utils.data.interleave(train_ds, unsup_train_ds)
shapes = ([-1, args.max_seqlen, 1], [-1, hparams.num_label])
types = ('int64', 'float32')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
'''
from tqdm import tqdm
for slots in tqdm(train_ds):
pass
'''
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
propeller.train.train_and_eval(
model_class_or_model_fn=ClassificationBowModel,
params=hparams,
run_config=run_config,
train_dataset=train_ds,
eval_dataset={'dev': dev_ds},
exporters=[best_exporter])
print('dev_acc3\t%.5f' % (best_exporter._best['dev']['acc']))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
import logging
from random import random
from functools import reduce, partial
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from model.ernie import ErnieModel
from optimization import optimization
import utils.data
from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
class ClassificationErnieModel(propeller.train.Model):
"""propeller Model wraper for paddle-ERNIE """
def __init__(self, hparam, mode, run_config):
self.hparam = hparam
self.mode = mode
self.run_config = run_config
def forward(self, features):
src_ids, sent_ids = features
dtype = 'float16' if self.hparam['fp16'] else 'float32'
zero = L.fill_constant([1], dtype='int64', value=0)
input_mask = L.cast(L.logical_not(L.equal(src_ids, zero)), dtype) # assume pad id == 0
#input_mask = L.unsqueeze(input_mask, axes=[2])
d_shape = L.shape(src_ids)
seqlen = d_shape[1]
batch_size = d_shape[0]
pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])
pos_ids = L.expand(pos_ids, [batch_size, 1])
pos_ids = L.unsqueeze(pos_ids, axes=[2])
pos_ids = L.cast(pos_ids, 'int64')
pos_ids.stop_gradient = True
input_mask.stop_gradient = True
task_ids = L.zeros_like(src_ids) + self.hparam.task_id  # task ids are not used at the moment
task_ids.stop_gradient = True
bert = ErnieModel(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
task_ids=task_ids,
input_mask=input_mask,
config=self.hparam,
use_fp16=self.hparam['fp16']
)
cls_feats = bert.get_pooled_output()
cls_feats = L.dropout(
x=cls_feats,
dropout_prob=0.1,
dropout_implementation="upscale_in_train"
)
logits = L.fc(
input=cls_feats,
size=self.hparam['num_label'],
param_attr=F.ParamAttr(
name="cls_out_w",
initializer=F.initializer.TruncatedNormal(scale=0.02)),
bias_attr=F.ParamAttr(
name="cls_out_b", initializer=F.initializer.Constant(0.))
)
propeller.summary.histogram('pred', logits)
if self.mode is propeller.RunMode.PREDICT:
probs = L.softmax(logits)
return probs
else:
return logits
def loss(self, predictions, labels):
ce_loss, probs = L.softmax_with_cross_entropy(
logits=predictions, label=labels, return_softmax=True)
#L.Print(ce_loss, message='per_example_loss')
loss = L.mean(x=ce_loss)
return loss
def backward(self, loss):
scheduled_lr, _ = optimization(
loss=loss,
warmup_steps=int(self.run_config.max_steps * self.hparam['warmup_proportion']),
num_train_steps=self.run_config.max_steps,
learning_rate=self.hparam['learning_rate'],
train_program=F.default_main_program(),
startup_prog=F.default_startup_program(),
weight_decay=self.hparam['weight_decay'],
scheduler="linear_warmup_decay",)
propeller.summary.scalar('lr', scheduled_lr)
def metrics(self, predictions, label):
predictions = L.argmax(predictions, axis=1)
predictions = L.unsqueeze(predictions, axes=[1])
acc = propeller.metrics.Acc(label, predictions)
#auc = propeller.metrics.Auc(label, predictions)
return {'acc': acc}
if __name__ == '__main__':
parser = propeller.ArgumentParser('DAN model with Paddle')
parser.add_argument('--max_seqlen', type=int, default=128)
parser.add_argument('--data_dir', type=str, required=True)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--do_predict', action='store_true')
parser.add_argument('--warm_start_from', type=str)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb'))}
sep_id = vocab['[SEP]']
cls_id = vocab['[CLS]']
unk_id = vocab['[UNK]']
tokenizer = utils.data.CharTokenizer(vocab.keys())
def tokenizer_func(inputs):
'''avoid pickle error'''
ret = tokenizer(inputs)
return ret
if not args.do_predict:
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('title',unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),
propeller.data.LabelColumn('label'),
])
def before(seg_a, label):
sentence, segments = utils.data.build_1_pair(seg_a, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id)
return sentence, segments, label
def after(sentence, segments, label):
sentence, segments, label = utils.data.expand_dims(sentence, segments, label)
return sentence, segments, label
log.debug(os.path.join(args.data_dir, 'train'))
train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0, 0)) \
.map(after)
dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0, 0)) \
.map(after)
shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1], [-1, 1])
types = ('int64', 'int64', 'int64')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
varname_to_warmstart = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
warm_start_dir = args.warm_start_from
ws = propeller.WarmStartSetting(
predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
from_dir=warm_start_dir
)
best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['eval']['acc'] > old['eval']['acc'])
propeller.train.train_and_eval(
model_class_or_model_fn=ClassificationErnieModel,
params=hparams,
run_config=run_config,
train_dataset=train_ds,
eval_dataset=dev_ds,
warm_start_setting=ws,
exporters=[best_exporter])
print('dev_acc\t%.5f' % (best_exporter._best['eval']['acc']))
else:
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('title',unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),
propeller.data.LabelColumn('label'),
])
def before(seg_a):
sentence, segments = utils.data.build_1_pair(seg_a, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id)
return sentence, segments
def after(sentence, segments):
sentence, segments = utils.data.expand_dims(sentence, segments)
return sentence, segments
predict_ds = feature_column.build_dataset_from_stdin('predict') \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1])
types = ('int64', 'int64')
predict_ds.data_shapes = shapes
predict_ds.data_types = types
finetuned_model = propeller.Learner(ClassificationErnieModel, run_config, hparams)
for logits, in finetuned_model.predict(predict_ds, ckpt=-1): # ckpt=-1 means last step
print(np.argmax(logits))
set -x
export PYTHONPATH=.:$PYTHONPATH
output_dir=./output/distill
teacher_dir=${output_dir}/teacher
student_dir=${output_dir}/student
# 1. finetune teacher
CUDA_VISIBLE_DEVICES=0 \
python3 -u ./distill/finetune_chnsenticorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${teacher_dir}'",
"max_steps": '$((10 * 9600 / 32))',
"save_steps": 100,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam ${MODEL_PATH}/ernie_config.json \
--hparam '{ # model definition
"sent_type_vocab_size": None, # default term in official config
"use_task_id": False,
"task_id": 0,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"fp16": 0,
"learning_rate": 0.00005,
"num_label": 2,
"batch_size": 32
}'
(($?!=0)) && echo "Something goes wrong at Step 1, please check" && exit -1
# 2. start a prediction server
export CUDA_VISIBLE_DEVICES=0
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 |awk -F"\t" '{print $2}' |python3 -u ./distill/finetune_chnsenticorp.py \
--do_predict \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${teacher_dir}'",
"log_steps": 10,
}' \
--hparam ${MODEL_PATH}/ernie_config.json \
--hparam '{ # model definition
"sent_type_vocab_size": None, # default term in official config
"use_task_id": False,
"task_id": 0,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"fp16": 0,
"learning_rate": 0.00005,
"num_label": 2,
"batch_size": 100
}' > prediction_label
(($?!=0)) && echo "Something goes wrong at Step 2, please check" && exit -1
mkdir prediction_output
paste ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 prediction_label |awk -F"\t" '{print $2"\t"$3}' > prediction_output/part.0
# 3. learn from teacher
export CUDA_VISIBLE_DEVICES=0
python3 ./distill/distill_chnsentocorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--unsupervise_data_dir ./prediction_output/ \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${student_dir}'",
"max_steps": '$((100 * 9600 / 100))',
"save_steps": 1000,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam '{
"num_label": 2,
"vocab_size": 35000,
"emb_size": 128,
"initializer_range": 0.02,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.00,
"fp16": 0,
"learning_rate": 1e-4,
"batch_size": 100
}'
(($?!=0)) && echo "Something goes wrong at Step 3, please check" && exit -1
set -x
export PYTHONPATH=.:$PYTHONPATH
output_dir=./output/distill
teacher_dir=${output_dir}/teacher
student_dir=${output_dir}/student
# 1. finetune teacher
CUDA_VISIBLE_DEVICES=0 \
python3 -u ./distill/finetune_chnsenticorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${teacher_dir}'",
"max_steps": '$((10 * 9600 / 32))',
"save_steps": 100,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam ${MODEL_PATH}/ernie_config.json \
--hparam '{ # model definition
"sent_type_vocab_size": None, # default term in official config
"use_task_id": False,
"task_id": 0,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"fp16": 0,
"learning_rate": 0.00005,
"num_label": 2,
"batch_size": 32
}'
(($?!=0)) && echo "Something goes wrong at Step 1, please check" && exit -1
# 2. start a prediction server
CUDA_VISIBLE_DEVICES=1 \
python3 -m propeller.tools.start_server -p 8113 -m ${teacher_dir}/best/inference/ &
echo $! > pid.server
sleep 10
# 3. learn from teacher
export CUDA_VISIBLE_DEVICES=0
python3 ./distill/distill_chnsentocorp_with_propeller_server.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--teacher_vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--teacher_max_seqlen 128 \
--server_batch_size 64 \
--teacher_host tcp://localhost:8113 \
--num_coroutine 10 \
--run_config '{
"model_dir": "'${student_dir}'",
"max_steps": '$((100 * 9600 / 100))',
"save_steps": 1000,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam '{ # model definition
"num_label": 2,
"vocab_size": 35000,
"emb_size": 128,
"initializer_range": 0.02,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.00,
"fp16": 0,
"learning_rate": 1e-4,
"batch_size": 100
}'
(($?!=0)) && echo "Something goes wrong at Step 2, please check" && exit -1
ps -ef|grep 'propeller.tools.start_server' |awk '{print $2}'|xargs kill -9
......@@ -212,6 +212,5 @@ def predict(exe,
except fluid.core.EOFException:
test_pyreader.reset()
break
return res
......@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("predict_batch_size", int, 8, "Total examples' number in batch for predict. see also --in_tokens.")
data_g.add_arg("predict_batch_size", int, None, "Total examples' number in batch for predict. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
......
......@@ -137,14 +137,20 @@ def start_procs(args):
cmd = [sys.executable, "-u",
args.training_script] + args.training_script_args
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
logdir = "%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id)
try:
os.mkdir(os.path.dirname(logdir))
except OSError:
pass
fn = open(logdir, "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
log.info('subprocess launched, check log at %s' % logdir)
else:
process = subprocess.Popen(cmd, env=current_env)
log.info('subprocess launched')
procs.append(process)
try:
......
......@@ -139,12 +139,19 @@ def start_procs(args):
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
logdir = "%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id)
try:
os.mkdir(os.path.dirname(logdir))
except OSError:
pass
fn = open(logdir, "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
log.info('subprocess launched, check log at %s' % logdir)
else:
process = subprocess.Popen(cmd, env=current_env)
log.info('subprocess launched')
procs.append(process)
try:
......
[Simplified Chinese](./README.md)|English
# Introducing Propeller
This doc introduces Propeller, a high-level Paddle API for general ML. Propeller encapsulates the following actions:
- training
- evaluation
- prediction
- exporting for serving
Propeller provides the following benefits:
- You can run Propeller-based models on a local host or in a distributed multi-server environment without changing your model. Furthermore, you can run Propeller-based models on CPUs or GPUs without recoding your model.
- Propeller simplifies sharing implementations between model developers.
- Propeller does many things for you (logging, hot-starting, ...).
- Propeller builds the Program and PyReader for you.
- Propeller provides a safe distributed training loop that controls how and when to:
- build the Program
- initialize variables
- create checkpoint files and recover from failures
- save visualizable results
## Install
```script
cd propeller && pip install .
```
## Getting Started
```python
# Define model
class BowModel(propeller.Model):
    def __init__(self, config, mode):
        self.embedding = Embedding(config['emb_size'], config['vocab_size'])
        self.fc1 = FC(config['hidden_size'])
        self.fc2 = FC(config['hidden_size'])

    def forward(self, features):
        q, t = features
        q_emb = softsign(self.embedding(q))
        t_emb = softsign(self.embedding(t))
        q_emb = self.fc1(q_emb)
        t_emb = self.fc2(t_emb)
        prediction = dot(q_emb, t_emb)
        return prediction

    def loss(self, predictions, label):
        return sigmoid_cross_entropy_with_logits(predictions, label)

    def backward(self, loss):
        opt = AdamOptimizer(1.e-3)
        opt.minimize(loss)

    def metrics(self, predictions, label):
        auc = propeller.metrics.Auc(predictions, label)
        return {'auc': auc}

# hyper params come from files / command line prompt / env vars
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)

# Define data
# `FeatureColumns` helps you to organize training/evaluation files.
feature_column = propeller.data.FeatureColumns(columns=[
    propeller.data.TextColumn('query', vocab='./vocab'),
    propeller.data.TextColumn('title', vocab='./vocab'),
    propeller.data.LabelColumn('label'),
])
train_ds = feature_column.build_dataset(data_dir='./data', shuffle=True, repeat=True)
eval_ds = feature_column.build_dataset(data_dir='./data', shuffle=False, repeat=False)

# Start training!
propeller.train_and_eval(BowModel, hparams, run_config, train_ds, eval_ds)
```
For more details, see example/toy/.
## Main Features
1. train_and_eval
instantiates the training model from the user-specified `propeller.Model` class in two modes (TRAIN and EVAL), and
performs training and evaluation
2. FeatureColumns
`FeatureColumns` is used to organize training data. With customizable `Column` properties, it adapts to many ML tasks (NLP/CV...).
`FeatureColumns` also does the preprocessing for you (tokenization, vocab lookup, serialization, batching, etc.)
3. Dataset
`FeatureColumns` generates a `Dataset`, or you can call `propeller.Dataset.from_generator_func` to build your own `Dataset` (see the sketch after this list).
4. Summary
To trace a tensor histogram during training, simply:
```python
propeller.summary.histogram('loss', tensor)
```
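A tiny sketch of item 3 above (only `from_generator_func` is named in the text; the generator protocol shown here, yielding tuples of numpy arrays, is an assumption):

```python
import numpy as np
import propeller

def gen():
    # hypothetical generator yielding one int64 feature per example
    for i in range(4):
        yield np.array([i], dtype='int64'),

ds = propeller.Dataset.from_generator_func(gen)
```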
## Contributing
1. This project is in alpha stage; any contribution is welcome. Feel free to create a PR.
Simplified Chinese|[English](./README.en.md)
# Introducing paddle-propeller
This document introduces Propeller, a high-level Paddle API that greatly simplifies machine learning programming. Propeller encapsulates the following operations:
- training
- evaluation
- prediction
- exporting for serving (deployment)
Propeller provides the following advantages:
- You can run Propeller-based models on a local host or in a distributed multi-server environment without changing your model. Furthermore, you can run Propeller-based models on CPUs or GPUs without recoding your model.
- Propeller simplifies sharing implementations between model developers.
- Focus only on the model implementation and data input; there is no need to care about auxiliary code (saving, hot-starting, logging, etc.)
- Propeller builds the Program and PyReader for you.
- Propeller provides a safe distributed training loop that controls how and when to:
  - build the Program
  - initialize variables
  - handle exceptions
  - create checkpoint files and recover from failures
  - save visualizable summary results
## Install
    cd propeller && pip install .
## Getting Started
```python
# Define the training model
class BowModel(propeller.Model):
def __init__(self, config, mode):
self.embedding = Embedding(config['emb_size'], config['vocab_size'])
self.fc1 = FC(config['hidden_size'])
        self.fc2 = FC(config['hidden_size'])
def forward(self, features):
q, t = features
q_emb = softsign(self.embedding(q))
t_emb = softsign(self.embedding(t))
q_emb = self.fc1(q_emb)
        t_emb = self.fc2(t_emb)
        prediction = dot(q_emb, t_emb)
return prediction
def loss(self, predictions, label):
return sigmoid_cross_entropy_with_logits(predictions, label)
def backward(self, loss):
opt = AdamOptimizer(1.e-3)
        opt.minimize(loss)
def metrics(self, predictions, label):
        auc = propeller.metrics.Auc(predictions, label)
return {'auc': auc}
# Hyperparameters can come from files / environment variables / the command line
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
# Define data:
# `FeatureColumns` manages training/prediction files and serializes them to binary automatically.
feature_column = propeller.data.FeatureColumns(columns=[
propeller.data.TextColumn('query', vocab='./vocab'),
propeller.data.TextColumn('title', vocab='./vocab'),
propeller.data.LabelColumn('label'),
])
train_ds = feature_column.build_dataset(data_dir='./data', shuffle=True, repeat=True)
eval_ds = feature_column.build_dataset(data_dir='./data', shuffle=False, repeat=False)
# Start training!
propeller.train_and_eval(BowModel, hparams, run_config, train_ds, eval_ds)
```
For more details, see example/toy/.
## Main Components
1. train_and_eval
    Instantiates the training model from the user-provided `propeller.Model` class in two modes: 1. TRAIN mode, 2. EVAL mode.
    It then starts training, with evaluation performed alongside.
2. FeatureColumns
    `FeatureColumns` manages the training data. Customizable `Column`s adapt it to many ML tasks (NLP/CV...).
    `FeatureColumns` automatically batch-preprocesses the provided training data (tokenization, vocab lookup, etc.), serializes it to binary, and generates the dataset used for training.
3. Dataset
    `FeatureColumns` generates a `Dataset`, or you can call `propeller.Dataset.from_generator_func` to build your own `Dataset` and combine it with shuffle / interleave / padded_batch / repeat, etc., to meet custom needs; see the sketch after this list.
4. Summary
    To log-trace certain parameters during training, simply:
```python
propeller.summary.histogram('loss', tensor)
```
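The sketch below wires the actual `Column` classes from `feature_column.py` into the gz-record pipeline; the paths and the `unk_id` value are illustrative:
```python
from propeller.paddle.data.feature_column import (FeatureColumns, TextColumn,
                                                  LabelColumn)

fc = FeatureColumns(columns=[
    TextColumn('query', unk_id=0, vocab_file='./vocab'),
    TextColumn('title', unk_id=0, vocab_file='./vocab'),
    LabelColumn('label'),
])
# the first call converts ./data/* into length-prefixed gzipped protobuf
# records under ./data_gz/, then streams and parses them
train_ds = fc.build_dataset('train', data_dir='./data',
                            shuffle=True, repeat=True)
train_ds = train_ds.padded_batch(batch_size=32, pad_value=0)
```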
## Contributing
1. This project is in its early stage; contributions are welcome!
2. Functional programming style is welcome.
## TODO
1. Automatic inference of dataset output_types / output_shapes
2. Automatic hyperparameter search
3. propeller server
4. ...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
import six
from time import time
__version__ = '0.1'
log = logging.getLogger(__name__)
stream_hdl = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
try:
from colorlog import ColoredFormatter
fancy_formatter = ColoredFormatter(
fmt='%(log_color)s[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
stream_hdl.setFormatter(fancy_formatter)
except ImportError:
stream_hdl.setFormatter(formatter)
log.setLevel(logging.INFO)
log.addHandler(stream_hdl)
log.propagate = False
from propeller.types import *
from propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools
import six
from six.moves import zip, map, filter
import numpy as np
from propeller.util import map_structure
log = logging.getLogger(__name__)
__all__ = ['Dataset']
@contextmanager
def open_file(filename, format=None):
if format is None:
fd = open(filename, 'rb')
elif format == 'GZIP':
fd = gzip.open(filename, 'rb')
else:
        raise ValueError('unknown file format %s' % format)
yield fd
fd.close()
def open_record(filename):
def gen():
with open_file(filename, format='GZIP') as f:
while True:
data = f.read(struct.calcsize('i'))
if not len(data):
                    return  # plain return; raising StopIteration inside a generator is an error under PEP 479
l, = struct.unpack('i', data)
data = f.read(l)
yield data
return gen
def shuffle_func(dataset, buffer_size):
def gen():
buf = []
iterable = dataset()
try:
while len(buf) < buffer_size:
buf.append(next(iterable))
while 1:
i = random.randint(0, buffer_size - 1)
n = next(iterable)
yield buf[i]
buf[i] = n
except StopIteration:
if len(buf):
random.shuffle(buf)
for i in buf:
yield i
return gen
def interleave_func(iterable, map_fn, cycle_length, block_length):
def gen():
ls = itertools.tee(iterable(), cycle_length)
buf = []
for i, j in enumerate(ls):
j = itertools.islice(j, i, None, cycle_length)
j = map(map_fn, j)
j = (jjj for jj in j for jjj in jj) #flatten
buf.append(j)
for tup in six.moves.zip_longest(*buf):
for ii in (i for i in tup if i is not None):
yield ii
return gen
def repeat_func(dataset, n):
def gen():
iterable = dataset()
if n >= 0:
ret = itertools.chain(*itertools.tee(iterable, n))
else:
ret = itertools.cycle(iterable)
for i in ret:
yield i
return gen
def filter_func(dataset, fn):
def gen():
for i in dataset():
if isinstance(i, tuple) or isinstance(i, list):
if fn(*i) is True:
yield i
else:
if fn(i) is True:
yield i
return gen
def map_func(dataset, fn):
def gen():
for i in dataset():
if isinstance(i, tuple) or isinstance(i, list):
yield fn(*i)
else:
yield fn(i)
return gen
def shard_func(dataset, num_shards, index):
def gen():
iterable = dataset()
ret = itertools.islice(iterable, index, None, num_shards)
for i in ret:
yield i
return gen
def take_func(dataset, count):
def gen():
iterable = dataset()
ret = itertools.islice(iterable, count)
for i in ret:
yield i
return gen
def buffered_func(dataset, size):
"""
Creates a buffered data reader.
The buffered data reader will read and save data entries into a
buffer. Reading from the buffered data reader will proceed as long
as the buffer is not empty.
    :param dataset: the dataset (a generator function) to read from.
    :type dataset: callable
:param size: max buffer size.
:type size: int
:returns: the buffered data reader.
"""
class EndSignal():
pass
end = EndSignal()
def read_worker(r, q):
for d in r:
q.put(d)
q.put(end)
def data_reader():
r = dataset()
q = multiprocessing.Queue(maxsize=size)
t = multiprocessing.Process(
target=read_worker, args=(
r,
q, ))
t.daemon = True
t.start()
e = q.get()
while e != end:
yield e
e = q.get()
return data_reader
def padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
if not isinstance(batch_size, int):
raise ValueError('unknown batch_size: %s' % repr(batch_size))
def gen():
iterable = dataset()
pad_value_t = pad_value
while True:
buf = list(itertools.islice(iterable, batch_size))
if not len(buf):
                return  # plain return; raising StopIteration inside a generator is an error under PEP 479
buf = list(zip(*buf)) # transpose
if type(pad_value_t) not in [list, tuple]:
pad_value_t = [pad_value_t] * len(buf)
padded = []
assert len(buf) == len(
pad_value_t), 'pad_value [%d] != element size[%d]' % (
len(pad_value_t), len(buf))
for e, pv in zip(buf, pad_value_t):
elem = e[0]
if (not np.isscalar(elem)) and elem.shape != ():
max_len = max(map(len,
e)) if max_seqlen is None else max_seqlen
e = map(lambda i: np.pad(i, [0, max_len - len(i)], 'constant', constant_values=pv) if max_len >= len(i) else i[: max_len], e)
padded.append(np.stack(list(e)))
yield padded
return gen
class Dataset(object):
@classmethod
def from_generator_func(cls, gen, data_shapes=None, data_types=None):
if not inspect.isgeneratorfunction(gen):
raise ValueError('expect generator function, got %s' % repr(gen))
    def wrapper(): # compat with py3.7+ (PEP 479): swallow StopIteration raised inside user generators
try:
for item in gen():
yield item
except RuntimeError as e:
if str(e) != 'generator raised StopIteration':
raise e
ret = cls()
ret.generator = wrapper
ret.data_shapes = data_shapes
ret.data_types = data_types
return ret
@classmethod
def from_file(cls, filename, format=None):
if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename)
def gen():
with open_file(filename, format) as f:
for line in f:
yield line
ret = cls()
ret.generator = gen
ret.data_shapes = []
ret.data_types = str
return ret
@classmethod
def from_record_file(cls, filename):
if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename)
gen = open_record(filename)
ret = cls()
ret.generator = gen
ret.data_shapes = []
ret.data_types = str
return ret
@classmethod
def from_list(cls, ls):
if not isinstance(ls, list):
raise ValueError('expect list, got %s' % repr(ls))
def gen():
for i in ls:
yield i
ret = cls()
ret.generator = gen
ret.data_shapes = []
ret.data_types = str
return ret
def __init__(self):
self.name = None
self._data_shapes = None
self._data_types = None
self.generator = None
self.pyreader = None
def __repr__(self):
return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
self.name, self._data_shapes, self._data_types)
def __eq__(self, other):
return self.name == other.name and \
self._data_shapes == other._data_shapes and \
self._data_types == other._data_types
def __iter__(self):
return self.generator()
#def __call__(self):
# return self.generator()
def _infer_shapes_and_types(self):
if self.generator is not None and self.name is not None:
log.info('Try to infer data shapes & types from generator')
first_value = next(self.generator())
shapes, types = [], []
for v in first_value:
if not isinstance(v, np.ndarray):
raise ValueError(
'dataset generator should use numpy elements, got %s' %
first_value)
shapes.append(v.shape)
types.append(v.dtype.name)
self._data_shapes = shapes
self._data_types = types
log.info('Dataset `%s` has data_shapes: %s data_types: %s' %
(self.name, repr(shapes), repr(types)))
else:
raise ValueError(
'Try to infer data shapes or types from incomplete Dataset')
@property
def data_shapes(self):
if self._data_shapes is None:
self._infer_shapes_and_types()
return self._data_shapes
else:
return self._data_shapes
@data_shapes.setter
def data_shapes(self, val):
self._data_shapes = val
@property
def data_types(self):
if self._data_types is None:
self._infer_shapes_and_types()
return self._data_types
else:
return self._data_types
@data_types.setter
def data_types(self, val):
self._data_types = val
def apply(self, transform_func):
#input_shapes = transform_func.input_shapes
#input_types = transform_func.input_types
#data_shapes = transform_func.data_shapes
#data_types = transform_func.data_types
#assert input_shapes == self._data_shapes
#assert input_types = self._data_types
ret_gen = transform_func(self.generator)
ret = type(self).from_generator_func(ret_gen)
if self.name is not None:
ret.name = self.name
#ret.data_shapes = data_shapes
#ret.data_types = data_types
return ret
def shuffle(self, buffer_size):
func = functools.partial(shuffle_func, buffer_size=buffer_size)
return self.apply(func)
def repeat(self, n=-1):
func = functools.partial(repeat_func, n=n)
return self.apply(func)
def map(self, fn):
func = functools.partial(map_func, fn=fn)
return self.apply(func)
def filter(self, fn):
func = functools.partial(filter_func, fn=fn)
return self.apply(func)
def shard(self, num_shards, index):
func = functools.partial(
shard_func, num_shards=num_shards, index=index)
return self.apply(func)
def interleave(self, map_fn, cycle_length, block_length):
func = functools.partial(
interleave_func,
map_fn=map_fn,
cycle_length=cycle_length,
block_length=block_length)
return self.apply(func)
def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
func = functools.partial(
padded_batch_func,
batch_size=batch_size,
pad_value=pad_value,
max_seqlen=max_seqlen)
return self.apply(func)
def take(self, count=1):
func = functools.partial(take_func, count=count)
return self.apply(func)
def buffered(self, size=10):
func = functools.partial(buffered_func, size=size)
return self.apply(func)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import six
from propeller.types import *
from propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
from propeller.paddle import data
from propeller.paddle import train
from propeller.paddle.train import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
_global_collection = None
class Key(object):
"""predefine collection keys"""
SUMMARY_SCALAR = 1
SUMMARY_HISTOGRAM = 2
SKIP_OPTIMIZE = 3
class Collections(object):
"""global collections to record everything"""
def __init__(self):
self.col = {}
def __enter__(self):
global _global_collection
_global_collection = self
return self
def __exit__(self, err_type, err_value, trace):
global _global_collection
_global_collection = None
def add(self, key, val):
self.col.setdefault(key, []).append(val)
def get(self, key):
return self.col.get(key, None)
def default_collection():
global _global_collection
if _global_collection is None:
_global_collection = Collections()
return _global_collection
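# Usage sketch (illustrative; `loss_tensor` stands in for a real fluid Variable):
#
#   from propeller.paddle.collection import default_collection, Key
#   default_collection().add(Key.SUMMARY_SCALAR, ('loss', loss_tensor))
#   default_collection().get(Key.SUMMARY_SCALAR)  # -> [('loss', loss_tensor)]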
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
from propeller.paddle.data.functional import *
from propeller.paddle.data.feature_column import *
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Protocol messages for describing input data Examples for machine learning
// model training or inference.
syntax = "proto3";
import "propeller/paddle/data/feature.proto";
package propeller;
message Example {
Features features = 1;
};
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
};
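// Illustrative text-format Example with a single int64 feature named "label"
// (message shapes per feature.proto, imported above):
//   features {
//     feature { key: "label" value { int64_list { value: 1 } } }
//   }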
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/example.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from propeller.paddle.data import feature_pb2 as propeller_dot_paddle_dot_data_dot_feature__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/example.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/example.proto\x12\tpropeller\x1a#propeller/paddle/data/feature.proto\"0\n\x07\x45xample\x12%\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x13.propeller.Features\"g\n\x0fSequenceExample\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.propeller.Features\x12.\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x17.propeller.FeatureListsb\x06proto3'
),
dependencies=[
propeller_dot_paddle_dot_data_dot_feature__pb2.DESCRIPTOR,
])
_EXAMPLE = _descriptor.Descriptor(
name='Example',
full_name='propeller.Example',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='features',
full_name='propeller.Example.features',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=87,
serialized_end=135, )
_SEQUENCEEXAMPLE = _descriptor.Descriptor(
name='SequenceExample',
full_name='propeller.SequenceExample',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='context',
full_name='propeller.SequenceExample.context',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='feature_lists',
full_name='propeller.SequenceExample.feature_lists',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=137,
serialized_end=240, )
_EXAMPLE.fields_by_name[
'features'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'context'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'feature_lists'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURELISTS
DESCRIPTOR.message_types_by_name['Example'] = _EXAMPLE
DESCRIPTOR.message_types_by_name['SequenceExample'] = _SEQUENCEEXAMPLE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Example = _reflection.GeneratedProtocolMessageType(
'Example',
(_message.Message, ),
dict(
DESCRIPTOR=_EXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.Example)
))
_sym_db.RegisterMessage(Example)
SequenceExample = _reflection.GeneratedProtocolMessageType(
'SequenceExample',
(_message.Message, ),
dict(
DESCRIPTOR=_SEQUENCEEXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.SequenceExample)
))
_sym_db.RegisterMessage(SequenceExample)
# @@protoc_insertion_point(module_scope)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package propeller;
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
map<string, Feature> feature = 1;
};
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
map<string, FeatureList> feature_list = 1;
};
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import struct
from six.moves import zip, map
import itertools
import gzip
from functools import partial
import multiprocessing
import six
import logging
import numpy as np
from glob import glob
from propeller.paddle.train import distribution
from propeller.data.functional import interleave_func
from propeller.paddle.data.functional import Dataset
from propeller.paddle.data import example_pb2, feature_pb2
log = logging.getLogger(__name__)
__all__ = [
'FeatureColumns', 'TextColumn', 'TextIDColumn', 'LabelColumn',
'basic_tokenizer', 'Column'
]
def basic_tokenizer(sen):
seg = sen.split(b' ')
    seg = filter(lambda i: i != b'', seg)  # drop empty tokens produced by consecutive spaces
return seg
class Column():
def __init__(self, name):
pass
def raw_to_proto(self, raw):
return feature_pb2.Feature()
@property
def output_shapes(self):
pass
@property
def output_types(self):
pass
def proto_to_instance(self, proto):
raise NotImplementedError()
def raw_to_instance(self, raw):
raise NotImplementedError()
class LabelColumn(Column):
def __init__(self, name, vocab_dict=None, vocab_file=None):
self.name = name
self.vocab = None
if vocab_file:
self.vocab = {
j.strip(): i
for i, j in enumerate(open(vocab_file, 'rb').readlines())
}
if vocab_dict:
self.vocab = vocab_dict
@property
def output_shapes(self):
return [1]
@property
def output_types(self):
return 'int64'
def raw_to_proto(self, raw):
if self.vocab is None:
ids = [int(raw)]
else:
ids = [self.vocab[raw]]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
ret = np.array(feature.int64_list.value[0], dtype=np.int64)
return ret
def raw_to_instance(self, raw):
if self.vocab is None:
ids = int(raw)
else:
ids = self.vocab[raw]
return ids
class TextColumn(Column):
def __init__(self,
name,
unk_id,
vocab_file=None,
vocab_dict=None,
tokenizer=basic_tokenizer):
self.name = name
self.tokenizer = tokenizer
self.unk_id = unk_id
if not (vocab_file or vocab_dict):
raise ValueError('at least specify vocab_file or vocab_dict')
if vocab_file:
self.vocab = {
j.strip(): i
for i, j in enumerate(open(vocab_file, 'rb').readlines())
}
if vocab_dict:
self.vocab = vocab_dict
@property
def output_shapes(self):
return [-1]
@property
def output_types(self):
return 'int64'
def raw_to_proto(self, raw):
ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
return np.array(ids, dtype=np.int64)
class TextIDColumn(Column):
def __init__(self, name):
self.name = name
@property
def output_shapes(self):
return [-1]
@property
def output_types(self):
return 'int64'
def raw_to_proto(self, raw):
ids = [int(s) for s in raw.split(b' ')]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
ret = np.array([int(i) for i in raw.split(b' ')], dtype=np.int64)
return ret
class FeatureColumns(object):
def __init__(self, columns, pad_id=0):
self._columns = columns
def raw_files(self, raw_dir):
return [os.path.join(raw_dir, p) for p in os.listdir(raw_dir)]
def gz_files(self, gz_dir):
return None if gz_dir is None else [
os.path.join(gz_dir, p) for p in os.listdir(gz_dir)
]
def _make_gz_dataset(self, raw_dir, gz_dir):
assert raw_dir or gz_dir, 'data_dir not specified when using gz mode'
if raw_dir is not None:
assert os.path.exists(raw_dir), 'raw_dir not exists: %s' % raw_dir
raw_file = os.listdir(raw_dir)
if gz_dir is None:
gz_dir = '%s_gz' % raw_dir.rstrip('/')
if not os.path.exists(gz_dir):
os.mkdir(gz_dir)
if raw_dir is not None:
if len(raw_file) != 0:
log.debug('try making gz')
pool = multiprocessing.Pool()
args = [(os.path.join(raw_dir, f), os.path.join(gz_dir, f),
self._columns, b'\t') for f in raw_file]
pool.map(_make_gz, args)
pool.terminate()
else:
assert len(
os.listdir(gz_dir)
            ) != 0, 'cannot find gz file or raw-txt file at [%s] and [%s]' % (
raw_dir, gz_dir)
return gz_dir
def _read_gz_dataset(self,
gz_files,
shuffle=False,
repeat=True,
shard=False,
**kwargs):
if len(gz_files) == 0:
raise ValueError('reading gz from empty file list: %s' % gz_files)
log.info('reading gz from %s' % '\n'.join(gz_files))
dataset = Dataset.from_list(gz_files)
if repeat:
dataset = dataset.repeat()
if shard and distribution.status.mode == distribution.DistributionMode.NCCL:
log.info('Apply dataset sharding in distribution env')
            dataset = dataset.shard(distribution.status.num_replica,
                                    distribution.status.replica_id)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(gz_files))
fn = partial(
interleave_func,
map_fn=lambda filename: Dataset.from_record_file(filename),
cycle_length=len(gz_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_gz(record_str): # function that takes python_str as input
ex = example_pb2.Example()
ex.ParseFromString(record_str)
ret = []
fea_dict = ex.features.feature
for c in self._columns:
ins = c.proto_to_instance(fea_dict[c.name])
ret.append(ins)
return ret
dataset = dataset.map(_parse_gz)
return dataset
def _read_txt_dataset(self,
data_files,
shuffle=False,
repeat=True,
**kwargs):
log.info('reading raw files from %s' % '\n'.join(data_files))
dataset = Dataset.from_list(data_files)
if repeat:
dataset = dataset.repeat()
if shuffle:
dataset = dataset.shuffle(buffer_size=len(data_files))
fn = partial(
interleave_func,
map_fn=lambda filename: Dataset.from_file(filename),
cycle_length=len(data_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_txt_file(
record_str): # function that takes python_str as input
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_txt_file)
return dataset
def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs):
log.info('reading raw files stdin')
def gen():
if six.PY3:
source = sys.stdin.buffer
else:
source = sys.stdin
while True:
line = source.readline()
if len(line) == 0:
break
yield line,
dataset = Dataset.from_generator_func(gen)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_stdin(record_str):
'''function that takes python_str as input'''
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_stdin)
return dataset
def _prepare_dataset(self,
dataset,
map_func_before_batch=None,
map_func_after_batch=None,
shuffle_buffer_size=None,
batch_size=1,
pad_id=0,
prefetch=None,
**kwargs):
if map_func_before_batch is not None:
dataset = dataset.map(map_func_before_batch)
if batch_size:
dataset = dataset.padded_batch(batch_size, pad_id)
if map_func_after_batch is not None:
dataset = dataset.map(map_func_after_batch)
return dataset
def build_dataset(self,
name,
use_gz=True,
data_dir=None,
gz_dir=None,
data_file=None,
**kwargs):
if use_gz:
gz_dir = self._make_gz_dataset(data_dir, gz_dir)
gz_files = self.gz_files(gz_dir)
ds = self._read_gz_dataset(gz_files, **kwargs)
else:
if data_dir is not None:
data_files = self.raw_files(data_dir)
elif data_file is not None:
data_files = [data_file]
else:
raise ValueError('data_dir or data_files not specified')
ds = self._read_txt_dataset(data_files, **kwargs)
ds.name = name
return ds
def build_dataset_from_stdin(self, name, **kwargs):
ds = self._read_stdin_dataset(**kwargs)
ds.name = name
return ds
def _make_gz(args):
try:
from_file, to_file, columns, sep = args
if os.path.exists(to_file):
return
with open(from_file, 'rb') as fin, gzip.open(to_file, 'wb') as fout:
log.debug('making gz %s => %s' % (from_file, to_file))
for i, line in enumerate(fin):
line = line.strip(b'\n').split(sep)
#if i % 10000 == 0:
# log.debug('making gz %s => %s [%d]' % (from_file, to_file, i))
if len(line) != len(columns):
log.error('columns not match at %s, got %d, expect %d' %
(from_file, len(line), len(columns)))
continue
features = {}
for l, c in zip(line, columns):
features[c.name] = c.raw_to_proto(l)
example = example_pb2.Example(features=feature_pb2.Features(
feature=features))
serialized = example.SerializeToString()
l = len(serialized)
data = struct.pack('i%ds' % l, l, serialized)
fout.write(data)
log.debug('done making gz %s => %s' % (from_file, to_file))
except Exception as e:
log.exception(e)
raise e
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/feature.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/feature.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/feature.proto\x12\tpropeller\"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\"\x95\x01\n\x07\x46\x65\x61ture\x12*\n\nbytes_list\x18\x01 \x01(\x0b\x32\x14.propeller.BytesListH\x00\x12*\n\nfloat_list\x18\x02 \x01(\x0b\x32\x14.propeller.FloatListH\x00\x12*\n\nint64_list\x18\x03 \x01(\x0b\x32\x14.propeller.Int64ListH\x00\x42\x06\n\x04kind\"\x81\x01\n\x08\x46\x65\x61tures\x12\x31\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32 .propeller.Features.FeatureEntry\x1a\x42\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.propeller.Feature:\x02\x38\x01\"2\n\x0b\x46\x65\x61tureList\x12#\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x12.propeller.Feature\"\x9a\x01\n\x0c\x46\x65\x61tureLists\x12>\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32(.propeller.FeatureLists.FeatureListEntry\x1aJ\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.propeller.FeatureList:\x02\x38\x01\x62\x06proto3'
))
_BYTESLIST = _descriptor.Descriptor(
name='BytesList',
full_name='propeller.BytesList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.BytesList.value',
index=0,
number=1,
type=12,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=50,
serialized_end=76, )
_FLOATLIST = _descriptor.Descriptor(
name='FloatList',
full_name='propeller.FloatList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FloatList.value',
index=0,
number=1,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=78,
serialized_end=108, )
_INT64LIST = _descriptor.Descriptor(
name='Int64List',
full_name='propeller.Int64List',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Int64List.value',
index=0,
number=1,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=110,
serialized_end=140, )
_FEATURE = _descriptor.Descriptor(
name='Feature',
full_name='propeller.Feature',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='bytes_list',
full_name='propeller.Feature.bytes_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='float_list',
full_name='propeller.Feature.float_list',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='int64_list',
full_name='propeller.Feature.int64_list',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='kind',
full_name='propeller.Feature.kind',
index=0,
containing_type=None,
fields=[]),
],
serialized_start=143,
serialized_end=292, )
_FEATURES_FEATUREENTRY = _descriptor.Descriptor(
name='FeatureEntry',
full_name='propeller.Features.FeatureEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.Features.FeatureEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Features.FeatureEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=358,
serialized_end=424, )
_FEATURES = _descriptor.Descriptor(
name='Features',
full_name='propeller.Features',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.Features.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[_FEATURES_FEATUREENTRY, ],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=295,
serialized_end=424, )
_FEATURELIST = _descriptor.Descriptor(
name='FeatureList',
full_name='propeller.FeatureList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.FeatureList.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=426,
serialized_end=476, )
_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor(
name='FeatureListEntry',
full_name='propeller.FeatureLists.FeatureListEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.FeatureLists.FeatureListEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FeatureLists.FeatureListEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=559,
serialized_end=633, )
_FEATURELISTS = _descriptor.Descriptor(
name='FeatureLists',
full_name='propeller.FeatureLists',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature_list',
full_name='propeller.FeatureLists.feature_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[_FEATURELISTS_FEATURELISTENTRY, ],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=479,
serialized_end=633, )
_FEATURE.fields_by_name['bytes_list'].message_type = _BYTESLIST
_FEATURE.fields_by_name['float_list'].message_type = _FLOATLIST
_FEATURE.fields_by_name['int64_list'].message_type = _INT64LIST
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'bytes_list'])
_FEATURE.fields_by_name[
'bytes_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'float_list'])
_FEATURE.fields_by_name[
'float_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'int64_list'])
_FEATURE.fields_by_name[
'int64_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURES_FEATUREENTRY.fields_by_name['value'].message_type = _FEATURE
_FEATURES_FEATUREENTRY.containing_type = _FEATURES
_FEATURES.fields_by_name['feature'].message_type = _FEATURES_FEATUREENTRY
_FEATURELIST.fields_by_name['feature'].message_type = _FEATURE
_FEATURELISTS_FEATURELISTENTRY.fields_by_name[
'value'].message_type = _FEATURELIST
_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS
_FEATURELISTS.fields_by_name[
'feature_list'].message_type = _FEATURELISTS_FEATURELISTENTRY
DESCRIPTOR.message_types_by_name['BytesList'] = _BYTESLIST
DESCRIPTOR.message_types_by_name['FloatList'] = _FLOATLIST
DESCRIPTOR.message_types_by_name['Int64List'] = _INT64LIST
DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE
DESCRIPTOR.message_types_by_name['Features'] = _FEATURES
DESCRIPTOR.message_types_by_name['FeatureList'] = _FEATURELIST
DESCRIPTOR.message_types_by_name['FeatureLists'] = _FEATURELISTS
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
BytesList = _reflection.GeneratedProtocolMessageType(
'BytesList',
(_message.Message, ),
dict(
DESCRIPTOR=_BYTESLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.BytesList)
))
_sym_db.RegisterMessage(BytesList)
FloatList = _reflection.GeneratedProtocolMessageType(
'FloatList',
(_message.Message, ),
dict(
DESCRIPTOR=_FLOATLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FloatList)
))
_sym_db.RegisterMessage(FloatList)
Int64List = _reflection.GeneratedProtocolMessageType(
'Int64List',
(_message.Message, ),
dict(
DESCRIPTOR=_INT64LIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Int64List)
))
_sym_db.RegisterMessage(Int64List)
Feature = _reflection.GeneratedProtocolMessageType(
'Feature',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURE,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Feature)
))
_sym_db.RegisterMessage(Feature)
Features = _reflection.GeneratedProtocolMessageType(
'Features',
(_message.Message, ),
dict(
FeatureEntry=_reflection.GeneratedProtocolMessageType(
'FeatureEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURES_FEATUREENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features.FeatureEntry)
)),
DESCRIPTOR=_FEATURES,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features)
))
_sym_db.RegisterMessage(Features)
_sym_db.RegisterMessage(Features.FeatureEntry)
FeatureList = _reflection.GeneratedProtocolMessageType(
'FeatureList',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureList)
))
_sym_db.RegisterMessage(FeatureList)
FeatureLists = _reflection.GeneratedProtocolMessageType(
'FeatureLists',
(_message.Message, ),
dict(
FeatureListEntry=_reflection.GeneratedProtocolMessageType(
'FeatureListEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists.FeatureListEntry)
)),
DESCRIPTOR=_FEATURELISTS,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists)
))
_sym_db.RegisterMessage(FeatureLists)
_sym_db.RegisterMessage(FeatureLists.FeatureListEntry)
_FLOATLIST.fields_by_name['value']._options = None
_INT64LIST.fields_by_name['value']._options = None
_FEATURES_FEATUREENTRY._options = None
_FEATURELISTS_FEATURELISTENTRY._options = None
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.data.functional import Dataset as DatasetBase
log = logging.getLogger(__name__)
class Dataset(DatasetBase):
def placeholders(self):
if self.name is None:
raise ValueError('can not get feature from unnamed Dataset')
ret = []
for i, (shape,
types) in enumerate(zip(self.data_shapes, self.data_types)):
ret.append(
L.data(
'%s_placeholder_%d' % (self.name, i),
shape=shape,
append_batch_size=False,
dtype=types))
return ret
def features(self):
'''start point of net building. call this in a program scope'''
if self.name is None:
raise ValueError('can not get feature from unnamed Dataset')
if len(self.data_shapes) != len(self.data_types):
raise ValueError(
                'Dataset shapes and types do not match: shapes: %s types: %s' %
(repr(self._data_shapes), repr(self._data_types)))
return self.placeholders()
def start(self, places=F.cuda_places()):
#assert self.pyreader is not None, 'use Dataset.features to build net first, then start dataset'
def gen():
try:
for idx, i in enumerate(self.generator()):
yield i
except Exception as e:
log.exception(e)
raise e
r = F.io.PyReader(
feed_list=self.placeholders(), capacity=50, iterable=True)
r.decorate_batch_generator(gen, places=places)
return r()
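# Usage sketch (illustrative), following the iterable-PyReader pattern;
# `exe` and `program` are assumed to be a fluid Executor and Program built
# on the placeholders returned by features():
#
#   ds.name = 'train'
#   feed_list = ds.features()   # build the net on these placeholders
#   for feed in ds.start(F.cuda_places()):
#       exe.run(program, feed=feed)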
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import logging
import six
import asyncio
import threading
import grpc
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import propeller.paddle.service.utils as serv_utils
from concurrent.futures import ThreadPoolExecutor
import paddle.fluid as F
from time import sleep, time
log = logging.getLogger(__name__)
def profile(msg):
def decfn(fn):
def retfn(*args, **kwargs):
start = time()
ret = fn(*args, **kwargs)
end = time()
log.debug('%s timecost: %.5f' % (msg, end - start))
return ret
return retfn
return decfn
def serve(model_dir, host, num_concurrent=None):
if six.PY2:
        raise RuntimeError('propeller service works in python3 only')
num_worker = len(F.cuda_places(
)) if num_concurrent is None else num_concurrent
pool = ThreadPoolExecutor(num_worker)
class Predictor(object):
def __init__(self, did):
log.debug('create predictor on card %d' % did)
config = F.core.AnalysisConfig(model_dir)
config.enable_use_gpu(5000, did)
self._predictor = F.core.create_paddle_predictor(config)
@profile('paddle')
def __call__(self, args):
for i, a in enumerate(args):
a.name = 'placeholder_%d' % i
res = self._predictor.run(args)
return res
predictor_context = {}
class InferenceService(interface_pb2_grpc.InferenceServicer):
@profile('service')
def Infer(self, request, context):
try:
slots = request.slots
current_thread = threading.current_thread()
log.debug('%d slots received dispatch to thread %s' %
(len(slots), current_thread))
if current_thread not in predictor_context:
did = list(pool._threads).index(current_thread)
log.debug('spawning worker thread %d' % did)
predictor = Predictor(did)
predictor_context[current_thread] = predictor
else:
predictor = predictor_context[current_thread]
slots = [serv_utils.slot_to_paddlearray(s) for s in slots]
ret = predictor(slots)
response = [serv_utils.paddlearray_to_slot(r) for r in ret]
except Exception as e:
log.exception(e)
raise e
return interface_pb2.Slots(slots=response)
server = grpc.server(pool)
interface_pb2_grpc.add_InferenceServicer_to_server(InferenceService(),
server)
server.add_insecure_port(host)
server.start()
log.info('server started on %s...' % host)
try:
while True:
sleep(100000)
except KeyboardInterrupt as e:
pass
    log.info('server stopped...')
if __name__ == '__main__':
from propeller import log
log.setLevel(logging.DEBUG)
serve(
'/home/work/chenxuyi/playground/grpc_play/ernie2.0/',
'10.255.138.19:8334',
num_concurrent=3)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import struct
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import paddle.fluid.core as core
def slot_to_paddlearray(slot):
if slot.type == interface_pb2.Slot.FP32:
type_str = 'f'
dtype = core.PaddleDType.FLOAT32
elif slot.type == interface_pb2.Slot.INT32:
type_str = 'i'
dtype = core.PaddleDType.INT32
elif slot.type == interface_pb2.Slot.INT64:
type_str = 'q'
dtype = core.PaddleDType.INT64
else:
        raise RuntimeError('unknown type %s' % slot.type)
ret = core.PaddleTensor()
ret.shape = slot.dims
ret.dtype = dtype
num = len(slot.data) // struct.calcsize(type_str)
arr = struct.unpack('%d%s' % (num, type_str), slot.data)
ret.data = core.PaddleBuf(arr)
return ret
def paddlearray_to_slot(arr):
if arr.dtype == core.PaddleDType.FLOAT32:
dtype = interface_pb2.Slot.FP32
type_str = 'f'
arr_data = arr.data.float_data()
elif arr.dtype == core.PaddleDType.INT32:
dtype = interface_pb2.Slot.INT32
type_str = 'i'
arr_data = arr.data.int32_data()
elif arr.dtype == core.PaddleDType.INT64:
dtype = interface_pb2.Slot.INT64
type_str = 'q'
arr_data = arr.data.int64_data()
else:
        raise RuntimeError('unknown type %s' % arr.dtype)
data = struct.pack('%d%s' % (len(arr_data), type_str), *arr_data)
pb = interface_pb2.Slot(type=dtype, dims=list(arr.shape), data=data)
return pb
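# Round-trip sketch (illustrative): an FP32 slot of shape [2] holding [1., 2.]
#
#   slot = interface_pb2.Slot(type=interface_pb2.Slot.FP32, dims=[2],
#                             data=struct.pack('2f', 1.0, 2.0))
#   tensor = slot_to_paddlearray(slot)   # -> core.PaddleTensor
#   slot2 = paddlearray_to_slot(tensor)  # -> back to an interface_pb2.Slot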
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import paddle.fluid as F
from propeller.paddle.collection import default_collection, Key
def scalar(name, tensor):
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True
default_collection().add(Key.SUMMARY_SCALAR, (name, tensor))
def histogram(name, tensor):
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True
default_collection().add(Key.SUMMARY_HISTOGRAM, (name, tensor))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
from time import time
log = logging.getLogger(__name__)
from propeller.paddle.train.monitored_executor import *
from propeller.paddle.train.trainer import *
from propeller.paddle.train.hooks import *
from propeller.train.model import Model
from propeller.paddle.train import exporter
from propeller.paddle.train import distribution
from propeller.paddle.train import metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import functools
import six
import logging
from time import sleep
import paddle.fluid as F
import paddle.fluid.layers as L
log = logging.getLogger(__name__)
import propeller.util
__all__ = ['init_distribuition_env', 'status']
status = None
class DistributionMode(object):
LOCAL = 0
NCCL = 1
class DistributionStatus(object):
def __init__(self, config):
if config is None:
self._mode = DistributionMode.LOCAL
self._env = None
self._this = None
else:
try:
self._mode = DistributionMode.NCCL
cluster = config['cluster']
task = config['task']['type']
idx = int(config['task']['index'])
self._this = cluster[task][idx]
self._env = cluster['chief'] + cluster['worker']
if len(set(self._env)) != len(self._env):
raise ValueError('duplicate host in dis_config %s' %
config)
except KeyError as e:
raise ValueError(
'PROPELLER_DISCONFIG wrong: %s not found in %s' %
                    (e, repr(config)))
@property
def mode(self):
return self._mode
@property
def num_replica(self):
if self._mode == DistributionMode.LOCAL:
return 1
elif self._mode == DistributionMode.NCCL:
return len(self._env)
else:
            raise ValueError('Got unknown distribution mode %s' %
                             repr(self._mode))
@property
def replica_id(self):
if self._mode == DistributionMode.LOCAL:
return 0
elif self._mode == DistributionMode.NCCL:
return self._env.index(self._this)
else:
            raise ValueError('Got unknown distribution mode %s' %
                             repr(self._mode))
@property
def is_master(self):
if self._mode == DistributionMode.LOCAL:
return True
elif self._mode == DistributionMode.NCCL:
return self.replica_id == 0
else:
            raise ValueError('got unknown distribution mode %s' %
                             repr(self._mode))
dis_config = propeller.util._get_dict_from_environ_or_json_or_file(
None, 'PROPELLER_DISCONFIG')
status = DistributionStatus(dis_config)
def run_on_master(func):
"""skip function in distribution env"""
@functools.wraps(func)
def f(*arg, **kwargs):
"""f"""
if status is None:
            raise ValueError('distribution mode unknown at this point')
if status.mode == DistributionMode.LOCAL:
r = func(*arg, **kwargs)
elif status.mode == DistributionMode.NCCL:
if status.is_master:
r = func(*arg, **kwargs)
else:
r = 0 # skip function
#MPI.COMM_WORLD.Barrier()
return r
return f
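# Usage sketch (illustrative): decorate side-effecting helpers so they only
# run on replica 0 under NCCL distribution, e.g.
#   @run_on_master
#   def save_vocab(path):          # hypothetical helper
#       ...
# Non-master replicas receive the placeholder return value 0 instead.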
def init_distribuition_env(program):
if status.mode == DistributionMode.LOCAL:
log.info('Initializing local training')
elif status.mode == DistributionMode.NCCL:
config = F.DistributeTranspilerConfig()
config.mode = "nccl2"
F.DistributeTranspiler(config=config).transpile(
status.replica_id,
trainers=','.join(status._env),
current_endpoint=status._this,
program=program.train_program,
startup_program=program.startup_program)
        log.info('Initializing distributed training with config %s' %
                 (repr(dis_config)))
if status.is_master:
sleep(30)
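# Illustrative PROPELLER_DISCONFIG value (an assumption consistent with the
# parsing in DistributionStatus above; actual deployments may differ):
#   {"cluster": {"chief": ["10.0.0.1:7000"], "worker": ["10.0.0.2:7000"]},
#    "task": {"type": "worker", "index": 0}}
# The 'chief' + 'worker' hosts form the replica list; task type/index selects
# which endpoint belongs to this process.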
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import os
import itertools
import six
import abc
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.paddle.train import Saver
from propeller.types import InferenceSpec
log = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class Exporter():
@abc.abstractmethod
    def export(self, exe, program, eval_model_spec, eval_result, state):
raise NotImplementedError()
class BestExporter(Exporter):
def __init__(self, export_dir, cmp_fn):
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
log.debug('New evaluate result: %s \nold: %s' %
(repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
eval_program = program.train_program
            # FIXME: all eval datasets have the same names/types/shapes for now, so every eval program is the same
saver = Saver(
self._export_dir,
exe,
program=eval_program,
max_ckpt_to_keep=1)
saver.save(state)
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
class BestInferenceModelExporter(Exporter):
def __init__(self, export_dir, cmp_fn):
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
log.debug('New evaluate result: %s \nold: %s' %
(repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
if eval_model_spec.inference_spec is None:
                raise ValueError('model_fn did not return an InferenceSpec')
inf_sepc_dict = eval_model_spec.inference_spec
if not isinstance(inf_sepc_dict, dict):
inf_sepc_dict = {'inference': inf_sepc_dict}
for inf_sepc_name, inf_sepc in six.iteritems(inf_sepc_dict):
if not isinstance(inf_sepc, InferenceSpec):
                    raise ValueError('unknown inference spec type: %s' %
                                     repr(inf_sepc))
save_dir = os.path.join(self._export_dir, inf_sepc_name)
log.debug('[Best Exporter]: save inference model: "%s" to %s' %
(inf_sepc_name, save_dir))
feed_var = [i.name for i in inf_sepc.inputs]
fetch_var = inf_sepc.outputs
eval_program = program.train_program
startup_prog = F.Program()
F.io.save_inference_model(
save_dir,
feed_var,
fetch_var,
exe,
main_program=eval_program)
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
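# Usage sketch (illustrative): cmp_fn receives the previous best and the new
# eval results as keyword args and returns True to trigger an export, e.g.
#   exporter = BestExporter(
#       './best_model',
#       cmp_fn=lambda old, new: new['acc'] > old['acc'])
# where 'acc' is a hypothetical metric key produced by your model_fn's
# metrics dict.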
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import six
import os
import itertools
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import util
from propeller.paddle.train import distribution
from propeller.paddle.train.metrics import Metrics
__all__ = [
'RunHook', 'TqdmProgressBarHook', 'TqdmNotebookProgressBarHook',
'CheckpointSaverHook', 'LoggingHook', 'StopAtStepHook', 'EvalHook'
]
log = logging.getLogger(__name__)
class RunHook(object):
def __init__(self):
pass
def before_train(self):
pass
def before_run(self, state):
return []
def after_run(self, res_list, state):
pass
def should_stop(self, state):
return False
def after_train(self):
pass
class TqdmProgressBarHook(RunHook):
def __init__(self, max_steps, desc=None):
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
class TqdmLogginHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except:
self.handleError(record)
tqdm_hdl = TqdmLogginHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
        self.tqdm = tqdm.tqdm(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class TqdmNotebookProgressBarHook(RunHook):
def __init__(self, max_steps, desc=None):
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
class TqdmLogginHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except:
self.handleError(record)
tqdm_hdl = TqdmLogginHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
        self.tqdm = tqdm.tqdm_notebook(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
self.tqdm.refresh()
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class LoggingHook(RunHook):
def __init__(self,
loss,
per_step=10,
skip_step=100,
summary_writer=None,
summary_record=None):
if per_step is None or skip_step is None:
            raise ValueError('wrong step argument, per_step: %s skip_step: %s'
                             % (per_step, skip_step))
self.loss = loss
self.per_step = per_step
self.skip_step = skip_step
self.summary_record = summary_record
self.writer = summary_writer
self.last_state = None
def before_train(self):
if self.summary_record:
if self.summary_record.scalar:
self.s_name, self.s_tolog = zip(*self.summary_record.scalar)
else:
self.s_name, self.s_tolog = [], []
if self.summary_record.histogram:
self.h_name, self.h_tolog = zip(*self.summary_record.histogram)
else:
self.h_name, self.h_tolog = [], []
def before_run(self, state):
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
ret = [self.loss]
if self.summary_record:
ret += self.s_tolog
ret += self.h_tolog
return ret
else:
return []
def after_run(self, res_list, state):
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
if not self.summary_record:
return
loss = float(res_list[0])
s_np = res_list[1:1 + len(self.s_name)]
h_np = res_list[1 + len(self.s_name):1 + len(self.s_name) + len(
self.h_name)]
if self.last_state is not None:
speed = (state.gstep - self.last_state.gstep) / (
state.time - self.last_state.time)
else:
speed = -1.
self.last_state = state
# log to tensorboard
if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_scalar(name, t, state.gstep)
for name, t in zip(self.h_name, h_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_histogram(name, t, state.gstep)
if speed > 0.:
self.writer.add_scalar('global_step', speed, state.gstep)
# log to stdout
log.debug('\t'.join([
'step: %d' % state.gstep,
'steps/sec: %.5f' % speed,
'loss: %.5f' % loss,
'' if self.summary_record is None else ' '.join(
map(lambda t: '%s:%s' % t, zip(self.s_name, s_np))),
]))
class StopAtStepHook(RunHook):
def __init__(self, stop_global_step, stop_step):
self._stop_gstep = stop_global_step
self._stop_step = stop_step
def should_stop(self, state):
if (self._stop_gstep and state.gstep >= self._stop_gstep) or \
(self._stop_step and state.step >= self._stop_step):
log.info('StopAtStepHook called stop')
return True
else:
return False
class EvalHook(RunHook):
"""hook this on a eval Executor"""
def __init__(self, metrics, summary_writer=None):
self.writer = summary_writer
self._result = None
if not isinstance(metrics, dict):
raise ValueError('metrics should be dict, got %s' % repr(metrics))
for k, m in six.iteritems(metrics):
if not isinstance(m, Metrics):
raise ValueError(
'metrics %s should be instance of propeller.Metrics, got %s'
% (k, repr(m)))
if len(metrics):
self.names = list(metrics.keys())
self.metrics = list(metrics.values())
else:
self.names, self.metrics = [], []
def before_train(self):
for m in self.metrics:
m.reset()
def before_run(self, state):
ls = [m.tensor for m in self.metrics]
for i in ls:
if not (isinstance(i, list) or isinstance(i, tuple)):
raise ValueError(
'metrics should return tuple or list of tensors, got %s' %
repr(i))
for ii in i:
if not isinstance(ii, F.framework.Variable):
                    raise ValueError(
                        'metrics tensor should be a fluid Variable, got %s of type %s'
                        % (repr(ii), type(ii)))
ls_flt, self.schema = util.flatten(ls)
#log.debug(ls_flt)
return ls_flt
def after_run(self, res_list, state):
res = util.unflatten(res_list, self.schema)
for r, m in zip(res, self.metrics):
m.update(r)
@property
def result(self):
return self._result
def after_train(self):
printable = []
self._result = {}
for n, m in zip(self.names, self.metrics):
val = m.eval()
self._result[n] = val
return self.result
class CheckpointSaverHook(RunHook):
def __init__(self, saver, per_step=10, skip_step=100):
self.saver = saver
self.per_step = per_step
self.skip_step = skip_step
def after_run(self, res_list, state):
if state.gstep % self.per_step == 0 and \
state.step > self.skip_step:
self.saver.save(state)
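# Writing a custom hook (illustrative sketch): subclass RunHook, return extra
# fetch targets from before_run and consume their values in after_run.
#   class NormLoggerHook(RunHook):        # hypothetical example
#       def __init__(self, tensor):
#           self.tensor = tensor
#       def before_run(self, state):
#           return [self.tensor]          # fetched alongside other hooks
#       def after_run(self, res_list, state):
#           log.debug('step %d value %s' % (state.gstep, res_list[0]))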
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import numpy as np
import itertools
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
import sklearn.metrics
log = logging.getLogger(__name__)
__all__ = [
'Metrics', 'F1', 'Recall', 'Precision', 'Mrr', 'Mean', 'Acc', 'ChunkF1',
'RecallAtPrecision'
]
class Metrics(object):
def __init__(self):
self.saver = []
@property
def tensor(self):
pass
def update(self, *args):
pass
def eval(self):
pass
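# Contract (illustrative summary): a Metrics subclass exposes
#   .tensor       -> tuple/list of fluid Variables fetched every eval step
#   .update(args) -> accumulates the fetched numpy values
#   .eval()       -> folds the accumulated state into a final scalar
# EvalHook in propeller.paddle.train.hooks drives exactly this cycle.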
class Mean(Metrics):
def __init__(self, t):
self.t = t
self.reset()
def reset(self):
self.saver = np.array([])
@property
def tensor(self):
self.t.persistable = True
return self.t,
def update(self, args):
t, = args
t = t.reshape([-1])
self.saver = np.concatenate([self.saver, t])
def eval(self):
return self.saver.mean()
class Ppl(Mean):
def eval(self):
return np.exp(self.saver.mean())
class Acc(Mean):
def __init__(self, label, pred):
self.eq = L.equal(pred, label)
self.reset()
@property
def tensor(self):
self.eq.persistable = True
return self.eq,
class MSE(Mean):
def __init__(self, label, pred):
diff = pred - label
self.mse = diff * diff
self.reset()
@property
def tensor(self):
self.mse.persistable = True
return self.mse,
class Cosine(Mean):
def __init__(self, label, pred):
self.cos = L.cos_sim(label, pred)
self.reset()
@property
def tensor(self):
self.cos.persistable = True
return self.cos,
class Precision(Metrics):
def __init__(self, label, pred):
self.label = label
self.pred = pred
self.reset()
def reset(self):
self.label_saver = np.array([], dtype=np.bool)
self.pred_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
self.label.persistable = True
self.pred.persistable = True
return self.label, self.pred
def update(self, args):
label, pred = args
label = label.reshape([-1]).astype(np.bool)
pred = pred.reshape([-1]).astype(np.bool)
if label.shape != pred.shape:
            raise ValueError(
                'Metrics precision: input shape mismatch: label:%s pred:%s' %
                (label, pred))
self.label_saver = np.concatenate([self.label_saver, label])
self.pred_saver = np.concatenate([self.pred_saver, pred])
    def eval(self):
        tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
        p = self.pred_saver.astype(np.int64).sum()  # predicted positives
        return tp / p
class Recall(Precision):
    def eval(self):
        tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
        t = self.label_saver.astype(np.int64).sum()  # actual positives
        return tp / t
class F1(Precision):
    def eval(self):
        tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
        t = self.label_saver.astype(np.int64).sum()  # actual positives
        p = self.pred_saver.astype(np.int64).sum()  # predicted positives
        precision = tp / (p + 1.e-6)
        recall = tp / (t + 1.e-6)
        return 2 * precision * recall / (precision + recall + 1.e-6)
class Auc(Metrics):
def __init__(self, label, pred):
self.pred = pred
self.label = label
self.reset()
def reset(self):
self.pred_saver = np.array([], dtype=np.float32)
self.label_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
self.pred.persistable = True
self.label.persistable = True
return [self.pred, self.label]
def update(self, args):
pred, label = args
pred = pred.reshape([-1]).astype(np.float32)
label = label.reshape([-1]).astype(np.bool)
self.pred_saver = np.concatenate([self.pred_saver, pred])
self.label_saver = np.concatenate([self.label_saver, label])
def eval(self):
fpr, tpr, thresholds = sklearn.metrics.roc_curve(
self.label_saver.astype(np.int64), self.pred_saver)
auc = sklearn.metrics.auc(fpr, tpr)
return auc
class RecallAtPrecision(Auc):
def __init__(self, label, pred, precision=0.9):
super(RecallAtPrecision, self).__init__(label, pred)
self.precision = precision
def eval(self):
self.pred_saver = self.pred_saver.reshape(
[self.label_saver.size, -1])[:, -1]
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
self.label_saver, self.pred_saver)
for p, r in zip(precision, recall):
if p > self.precision:
return r
class PrecisionAtThreshold(Auc):
def __init__(self, label, pred, threshold=0.5):
super().__init__(label, pred)
self.threshold = threshold
def eval(self):
infered = self.pred_saver > self.threshold
correct_num = np.array(infered & self.label_saver).sum()
infer_num = infered.sum()
return correct_num / (infer_num + 1.e-6)
class Mrr(Metrics):
def __init__(self, qid, label, pred):
self.qid = qid
self.label = label
self.pred = pred
self.reset()
def reset(self):
self.qid_saver = np.array([], dtype=np.int64)
self.label_saver = np.array([], dtype=np.int64)
self.pred_saver = np.array([], dtype=np.float32)
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError(
                'Mrr dimension mismatch: qid[%s] label[%s], pred[%s]' %
                (qid.shape, label.shape, pred.shape))
self.qid_saver = np.concatenate(
[self.qid_saver, qid.reshape([-1]).astype(np.int64)])
self.label_saver = np.concatenate(
[self.label_saver, label.reshape([-1]).astype(np.int64)])
self.pred_saver = np.concatenate(
[self.pred_saver, pred.reshape([-1]).astype(np.float32)])
def eval(self):
def key_func(tup):
return tup[0]
def calc_func(tup):
ranks = [
1. / (rank + 1.)
for rank, (_, l, p) in enumerate(
sorted(
tup, key=lambda t: t[2], reverse=True)) if l != 0
]
            ranks = ranks[0] if ranks else 0.  # guard queries with no positive label
return ranks
mrr_for_qid = [
calc_func(tup)
for _, tup in itertools.groupby(
sorted(
zip(self.qid_saver, self.label_saver, self.pred_saver),
key=key_func),
key=key_func)
]
mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid))
return mrr
class ChunkF1(Metrics):
def __init__(self, label, pred, seqlen, num_label):
self.label = label
self.pred = pred
self.seqlen = seqlen
self.null_index = num_label - 1
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
def _extract_bio_chunk(self, seq):
chunks = []
cur_chunk = None
for index in range(len(seq)):
tag = seq[index]
tag_type = tag // 2
tag_pos = tag % 2
if tag == self.null_index:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = None
continue
if tag_pos == 0:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = {}
cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
else:
if cur_chunk is None:
cur_chunk = {
"st": index,
"en": index + 1,
"type": tag_type
}
continue
if cur_chunk["type"] == tag_type:
cur_chunk["en"] = index + 1
else:
chunks.append(cur_chunk)
cur_chunk = {
"st": index,
"en": index + 1,
"type": tag_type
}
if cur_chunk is not None:
chunks.append(cur_chunk)
return chunks
def reset(self):
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
@property
def tensor(self):
self.pred.persistable = True
self.label.persistable = True
self.seqlen.persistable = True
return [self.pred, self.label, self.seqlen]
def update(self, args):
pred, label, seqlen = args
pred = pred.reshape([-1]).astype(np.int32).tolist()
label = label.reshape([-1]).astype(np.int32).tolist()
seqlen = seqlen.reshape([-1]).astype(np.int32).tolist()
max_len = 0
for l in seqlen:
max_len = max(max_len, l)
for i in range(len(seqlen)):
seq_st = i * max_len + 1
seq_en = seq_st + (seqlen[i] - 2)
pred_chunks = self._extract_bio_chunk(pred[seq_st:seq_en])
label_chunks = self._extract_bio_chunk(label[seq_st:seq_en])
self.pred_cnt += len(pred_chunks)
self.label_cnt += len(label_chunks)
pred_index = 0
label_index = 0
while label_index < len(label_chunks) and pred_index < len(
pred_chunks):
if pred_chunks[pred_index]['st'] < label_chunks[label_index][
'st']:
pred_index += 1
elif pred_chunks[pred_index]['st'] > label_chunks[label_index][
'st']:
label_index += 1
else:
if pred_chunks[pred_index]['en'] == label_chunks[label_index]['en'] \
and pred_chunks[pred_index]['type'] == label_chunks[label_index]['type']:
self.correct_cnt += 1
pred_index += 1
label_index += 1
def eval(self):
if self.pred_cnt == 0:
precision = 0.0
else:
precision = 1.0 * self.correct_cnt / self.pred_cnt
if self.label_cnt == 0:
recall = 0.0
else:
recall = 1.0 * self.correct_cnt / self.label_cnt
if self.correct_cnt == 0:
f1 = 0.0
else:
f1 = 2 * precision * recall / (precision + recall)
return np.float32(f1)
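# Tag layout assumed by _extract_bio_chunk (inferred from the arithmetic
# above): with num_label = 2 * num_types + 1, tag // 2 is the entity type and
# tag % 2 distinguishes B (0) from I (1); the last id is the O tag. E.g. for
# num_label = 5: 0 = B-type0, 1 = I-type0, 2 = B-type1, 3 = I-type1, 4 = O.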
class PNRatio(Metrics):
def __init__(self, qid, label, pred):
self.qid = qid
self.label = label
self.pred = pred
self.saver = {}
def reset(self):
self.saver = {}
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
                             % (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
p = 0
n = 0
for qid, outputs in self.saver.items():
for i in range(0, len(outputs)):
l1, p1 = outputs[i]
for j in range(i + 1, len(outputs)):
l2, p2 = outputs[j]
if l1 > l2:
if p1 > p2:
p += 1
elif p1 < p2:
n += 1
elif l1 < l2:
if p1 < p2:
p += 1
elif p1 > p2:
n += 1
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class BinaryPNRatio(PNRatio):
def __init__(self, qid, label, pred):
super(BinaryPNRatio, self).__init__(qid, label, pred)
def eval(self):
p = 0
n = 0
for qid, outputs in self.saver.items():
pos_set = []
neg_set = []
for label, score in outputs:
if label == 1:
pos_set.append(score)
else:
neg_set.append(score)
for ps in pos_set:
for ns in neg_set:
if ps > ns:
p += 1
elif ps < ns:
n += 1
else:
continue
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class PrecisionAtK(Metrics):
def __init__(self, qid, label, pred, k=1):
self.qid = qid
self.label = label
self.pred = pred
self.k = k
self.saver = {}
def reset(self):
self.saver = {}
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
                             % (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
right = 0
total = 0
for v in self.saver.values():
v = sorted(v, key=lambda x: x[1], reverse=True)
k = min(self.k, len(v))
for i in range(k):
if v[i][0] == 1:
right += 1
break
total += 1
return np.float32(1.0 * right / total)
#class SemanticRecallMetrics(Metrics):
# def __init__(self, qid, vec, type_id):
# self.qid = qid
# self.vec = vec
# self.type_id = type_id
# self.reset()
#
# def reset(self):
# self.saver = []
#
# @property
# def tensor(self):
# return [self.qid, self.vec, self.type_id]
#
# def update(self, args):
# qid, vec, type_id = args
# self.saver.append((qid, vec, type_id))
#
# def eval(self):
# dic = {}
# for qid, vec, type_id in self.saver():
# dic.setdefault(i, {}).setdefault(k, []).append(vec)
#
# for qid in dic:
# assert len(dic[qid]) == 3
# qvec = np.arrray(dic[qid][0])
# assert len(qvec) == 1
# ptvec = np.array(dic[qid][1])
# ntvec = np.array(dic[qid][2])
#
# np.matmul(qvec, np.transpose(ptvec))
# np.matmul(qvec, np.transpose(ntvec))
#
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import json
from functools import reduce
import six
from time import time
import shutil
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import util
from propeller.types import StopException, ProgramPair
from propeller.paddle.train import hooks
from . import distribution
log = logging.getLogger(__name__)
__all__ = ['MonitoredExecutor', 'Saver']
class RunState(object):
@classmethod
def from_str(cls, s):
j = json.loads(s)
ret = RunState()
ret._gstep = j['global_step']
ret._time = j['time']
ret._step = 0
return ret
def __init__(self):
self._gstep = 0
self._step = 0
self._time = time()
@property
def gstep(self):
return self._gstep
@property
def step(self):
return self._step
@property
def time(self):
return self._time
def __repr__(self):
return repr({'global_step': self._gstep, 'time': self._time})
def serialize(self):
return json.dumps({'global_step': self._gstep, 'time': self._time})
def next(self):
ret = RunState()
ret._gstep = self._gstep + 1
ret._step = self._step + 1
ret._time = time()
return ret
class Saver(object):
def __init__(self,
save_dir,
exe,
program,
save_prefix='model',
max_ckpt_to_keep=None):
if exe is not None:
assert isinstance(
exe, F.Executor
), 'expect normal executor to save, got executor of type %s' % repr(
type(exe))
self._exe = exe
self._program = program
self._save_dir = save_dir
self._save_prefix = save_prefix
self._max_ckpt_to_keep = 10 if max_ckpt_to_keep is None else max_ckpt_to_keep
self.ckpt_info_path = os.path.join(save_dir, 'ckpt_info')
if os.path.exists(self.ckpt_info_path):
self.ckpt_list = [
p.strip() for p in open(self.ckpt_info_path).readlines()
]
log.debug('ckpt_list in this Saver: %s' % (self.ckpt_list))
else:
self.ckpt_list = []
@property
def last_ckpt(self):
return self.ckpt_list[-1] if len(self.ckpt_list) else None
def save(self, state):
save_name = '%s_%d' % (self._save_prefix, state.gstep)
save_dir = os.path.join(self._save_dir, save_name)
tmp_dir = os.path.join(self._save_dir, 'tmp')
try:
shutil.rmtree(save_dir)
shutil.rmtree(tmp_dir)
except OSError:
pass
log.debug('saving step %d to %s' % (state.gstep, save_dir))
F.io.save_persistables(self._exe, tmp_dir, self._program)
shutil.move(tmp_dir, save_dir)
meta = state.serialize()
open(os.path.join(save_dir, 'meta'), 'w').write(meta)
self.ckpt_list.append(save_name)
if len(self.ckpt_list) > self._max_ckpt_to_keep:
ckpt_to_keep = self.ckpt_list[-self._max_ckpt_to_keep:]
ckpt_to_remove = set(self.ckpt_list) - set(ckpt_to_keep)
self.ckpt_list = ckpt_to_keep
for ckpt in ckpt_to_remove:
ckpt_dir = os.path.join(self._save_dir, ckpt)
if os.path.exists(ckpt_dir):
shutil.rmtree(ckpt_dir)
                    log.debug('number of ckpts exceeds %d, cleaning up: %s' %
                              (self._max_ckpt_to_keep, ckpt_dir))
open(self.ckpt_info_path, 'w').write('\n'.join(self.ckpt_list))
def restore(self, ckpt=-1):
if not isinstance(ckpt, (int, ) + six.string_types):
raise ValueError('ckpt type not understood %s' % repr(ckpt))
if isinstance(ckpt, int):
try:
ckpt = self.ckpt_list[ckpt]
except IndexError:
raise ValueError('invalid restore ckpt number %d' % ckpt)
if isinstance(ckpt, six.string_types):
try:
ckpt = self.ckpt_list.index(ckpt)
except ValueError:
raise ValueError('ckpt: %s not in ckpt list: %s' %
(ckpt, self.ckpt_list))
path = os.path.join(self._save_dir, self.ckpt_list[ckpt])
meta_file = os.path.join(path, 'meta')
if not os.path.exists(meta_file):
raise RuntimeError('meta not found in restore dir: %s' % path)
state = RunState.from_str(open(meta_file).read())
log.info('restore from ckpt %s, ckpt-status: %s' % (path, repr(state)))
def fn(v):
vpath = os.path.join(path, v.name)
if F.io.is_persistable(v):
if os.path.exists(vpath):
return True
else:
log.warning('var %s not found in checkpoint, ignored' %
v.name)
return False
F.io.load_vars(
self._exe, path, main_program=self._program, predicate=fn)
return state
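# Usage sketch (illustrative): pair save/restore around a fluid program.
#   saver = Saver('./ckpt', exe, program=main_prog, max_ckpt_to_keep=3)
#   saver.save(state)        # writes ./ckpt/model_<gstep> plus a meta file
#   state = saver.restore()  # -1 (default) reloads the newest checkpoint
# `exe`, `main_prog` and `state` are assumed to be an F.Executor, the program
# whose persistables were saved, and a RunState.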
class MonitoredExecutor(object):
"""A wrapper handling the train loop"""
def __init__(
self,
executor,
program,
loss=None, #must set in train
state=None,
run_config=None, #none if not load
run_hooks=[],
warm_start_setting=None):
        if not isinstance(executor, F.Executor):
            raise ValueError(
                'expect an F.Executor; ParallelExecutor (PE) is no longer supported'
            )
        if isinstance(executor, F.ParallelExecutor):
            raise ValueError('ParallelExecutor is deprecated, use Executor')
self._exe = executor
self._hooks = run_hooks
self._state = RunState() # might be overwrite in freeze
self._program = program
self._loss = loss
self._warm_start_setting = warm_start_setting
self._saver = None # will set in prepare
self.result = None # will set after train
if run_config is not None:
self._model_dir = run_config.model_dir
self._save_dir = run_config.model_dir
self._save_steps = run_config.save_steps
self._skip_steps = run_config.skip_steps if run_config.skip_steps else 100
self._save_prefix = 'model'
self._max_ckpt = run_config.max_ckpt
@property
def state(self):
return self._state
def init_or_restore_variables(self):
        # The order of these 2 steps really matters
# 1. init train
F.Executor(F.cuda_places()[0]).run(self._program.startup_program)
# 2. restore param
if self._warm_start_setting is not None:
if not os.path.exists(self._warm_start_setting.from_dir):
raise ValueError('warm start dir not exists: %s' %
self._warm_start_setting.from_dir)
log.info("warm start from %s" % self._warm_start_setting.from_dir)
if self._warm_start_setting.predicate_fn is not None:
def fn(v):
ret = self._warm_start_setting.predicate_fn(v)
if ret:
log.info('warm start: %s' % v.name)
return ret
F.io.load_vars(
F.Executor(F.cuda_places()[0]),
self._warm_start_setting.from_dir,
main_program=self._program.train_program,
predicate=fn)
else:
raise NotImplementedError()
self._saver = Saver(
self._model_dir,
F.Executor(F.cuda_places()[0]),
program=self._program.train_program,
max_ckpt_to_keep=self._max_ckpt)
if self._saver.last_ckpt is not None:
self._state = self._saver.restore()
def freeze(self):
if self._loss is None:
log.debug('will not freeze a program without loss')
return
if isinstance(self._program.train_program, F.compiler.CompiledProgram):
log.debug('program has already been built')
return
exec_strategy = F.ExecutionStrategy()
exec_strategy.num_threads = 4 #2 for fp32 4 for fp16
exec_strategy.use_experimental_executor = True
        exec_strategy.num_iteration_per_drop_scope = 10  # dropping scopes only every 10 iterations matters for throughput
build_strategy = F.BuildStrategy()
build_strategy.remove_unnecessary_lock = False
#build_strategy.fuse_broadcast_ops = True
build_strategy.num_trainers = distribution.status.num_replica
build_strategy.trainer_id = distribution.status.replica_id
build_strategy.memory_optimize = True
log.info('replica id %d of %d' % (distribution.status.replica_id,
distribution.status.num_replica))
program = F.CompiledProgram(
self._program.train_program).with_data_parallel(
loss_name=self._loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
self._program = ProgramPair(
train_program=program,
startup_program=self._program.startup_program)
def __enter__(self):
log.debug('freezing program')
self.freeze()
log.debug('done freezing')
log.info('********** Start Loop ************')
# TODO init
self.result = None
for h in self._hooks:
log.debug('train loop has hook %s' % h)
h.before_train()
return self
def run(self, fetch_list=[], *args, **kwargs):
#log.debug('Executor running step %d' % self._state.gstep)
if self._hooks:
fetch_list = [fetch_list]
for h in self._hooks:
#log.debug('calling hook.before_run %s' % h)
fetch = h.before_run(self._state)
fetch_list.append(fetch)
fetch_list_len = map(len, fetch_list)
fetch_list, schema = util.flatten(fetch_list)
fetch_list = [
f.name if not isinstance(f, six.string_types) else f
for f in fetch_list
]
            #if len(set(fetch_list)) != len(fetch_list):
            #    log.error('unexpected behavior when fetch list contains identical tensors %s' % fetch_list)
res = self._exe.run(self._program.train_program,
fetch_list=fetch_list,
*args,
**kwargs)
res = [self.merge_result(r) for r in res]
#log.debug(res)
res = util.unflatten(res, schema)
ret, res = res[0], res[1:]
for r, h in zip(res, self._hooks):
#log.debug('calling hook.after_run')
h.after_run(r, self._state)
if any(map(lambda i: i.should_stop(self._state), self._hooks)):
raise StopException('hook call stop')
else:
ret = self._exe.run(self._program.train_program,
fetch_list=fetch_list,
*args,
**kwargs)
self._state = self._state.next()
return ret
def __exit__(self, err_type, err_value, trace):
if (err_type is None) or isinstance(err_value, (
F.core.EOFException, StopException, KeyboardInterrupt)):
try:
log.info('********** Stop Loop ************')
self.result = []
for h in self._hooks:
self.result.append(h.after_train())
except Exception as e:
                log.exception('error occurred after loop %s' % repr(e))
else:
log.info('********** Interupt Loop ************')
log.exception('error occur during loop %s: %s' %
(err_type, err_value))
def merge_result(self, ls):
dev_count = len(self._program.train_program._places) if isinstance(
self._program.train_program, F.compiler.CompiledProgram) else 1
if dev_count == 1:
return ls
else:
shape = (-1, ls.shape[0] // dev_count) + ls.shape[1:]
ret = np.reshape(ls, shape).mean(axis=0)
return ret
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import itertools
import six
import inspect
from collections import namedtuple
from contextlib import contextmanager
from six.moves import zip, map
import logging
from time import time
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.types import RunMode, StopException, SummaryRecord, StopException, ModelSpec, InferenceSpec, ProgramPair, RunConfig
from propeller.paddle import summary, collection
from propeller.paddle.data.functional import Dataset
from propeller.paddle.train import distribution
from propeller.train.model import Model
from propeller.paddle.train.monitored_executor import Saver
from propeller.paddle.train import hooks, metrics
from propeller.paddle.train.monitored_executor import MonitoredExecutor
log = logging.getLogger(__name__)
__all__ = ['train_and_eval', 'Learner']
def get_summary_writer(path):
summary_writer = None
try:
from tensorboardX import SummaryWriter
if distribution.status.is_master:
summary_writer = SummaryWriter(os.path.join(path))
except ImportError:
log.warning('tensorboardX not installed, will not log to tensorboard')
return summary_writer
def log_eval_result(name, eval_result, swriter, state):
log.debug(eval_result)
printable = []
for n, val in six.iteritems(eval_result):
        assert val.shape == (), 'metrics eval should return a scalar'
printable.append('{}\t{}'.format(n, val))
if swriter is not None:
swriter.add_scalar(n, val, state.gstep)
log.debug('write to tensorboard %s' % swriter.logdir)
if len(printable):
log.info('*** eval res: %10s ***' % name)
for p in printable:
log.info(p)
log.info('******************************')
def build_net(model_fn, features, mode, params, run_config):
model_spec = model_fn(
features=features, mode=mode, params=params, run_config=run_config)
if mode == RunMode.TRAIN:
        if not isinstance(model_spec.loss, F.framework.Variable):
            raise ValueError('model_spec.loss should be a Variable, got %s' %
                             repr(model_spec.loss))
if not (model_spec.loss.shape == () or model_spec.loss.shape == (1, )):
            raise ValueError('expect scalar loss, got %s' %
                             repr(model_spec.loss.shape))
model_spec.loss.persistable = True
elif mode == RunMode.EVAL:
if not isinstance(model_spec.metrics, dict):
raise ValueError('model_spec.metrics should be dict, got %s' %
repr(model_spec.metrics))
elif mode == RunMode.PREDICT:
if not isinstance(model_spec.predictions, (list, tuple)):
            raise ValueError('model_spec.predictions should be a list, got %s'
                             % repr(model_spec.predictions))
else:
        raise ValueError('unknown mode %s' % mode)
return model_spec
class Learner(object):
def __init__(self,
model_class_or_model_fn,
run_config,
params=None,
warm_start_setting=None):
        '''
        model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` can be specified in 2 ways:
            1. subclass of propeller.train.Model which implements:
                1. \_\_init\_\_ (hyper_param, mode, run_config)
                2. forward (features) => (prediction)
                3. backward (loss) => None
                4. loss (prediction) => (loss)
                5. metrics (optional) (prediction) => (dict of propeller.Metrics)
            2. a model_fn that takes the following args:
                1. features
                2. param
                3. mode
                4. run_config(optional)
               and returns a `propeller.ModelSpec`
        params: any python object, will be passed to your `model_fn` or `propeller.train.Model`
        run_config (propeller.RunConfig): run_config.max_steps should not be None.
        warm_start_setting (propeller.WarmStartSetting): Optional. warm start variables will overwrite model variables.
        '''
if run_config.model_dir is None:
raise ValueError('model_dir should specified in run_config')
        if inspect.isclass(model_class_or_model_fn) and issubclass(
                model_class_or_model_fn, Model):
def model_fn(features, mode, params, run_config):
if mode != RunMode.PREDICT:
fea, label = features[:-1], features[-1]
else:
fea = features
model = model_class_or_model_fn(
params, mode, run_config=run_config)
pred = model.forward(fea)
if isinstance(pred, F.framework.Variable):
prediction = [pred]
else:
prediction = pred
if mode == RunMode.TRAIN:
loss = model.loss(pred, label)
model.backward(loss)
return ModelSpec(
loss=loss, predictions=prediction, mode=mode)
elif mode == RunMode.EVAL:
loss = model.loss(pred, label)
me = model.metrics(pred, label)
inf_spec = InferenceSpec(inputs=fea, outputs=prediction)
if 'loss' not in me:
me['loss'] = metrics.Mean(loss)
return ModelSpec(
loss=loss,
predictions=prediction,
metrics=me,
mode=mode,
inference_spec=inf_spec)
elif mode == RunMode.PREDICT:
return ModelSpec(predictions=prediction, mode=mode)
else:
raise RuntimeError('unknown run mode %s' % mode)
elif inspect.isfunction(model_class_or_model_fn):
model_fn = model_class_or_model_fn
else:
raise ValueError('unknown model %s' % model_class_or_model_fn)
self.model_fn = model_fn
self.params = params
self.run_config = run_config
self.warm_start_setting = warm_start_setting
def build_for_train(self, train_dataset):
train_dataset.name = 'train'
train_program = F.Program()
startup_prog = F.Program()
with F.program_guard(train_program, startup_prog):
with F.unique_name.guard():
with collection.Collections() as collections:
log.info('Building Train Graph...')
fea = train_dataset.features()
model_spec = build_net(self.model_fn, fea, RunMode.TRAIN,
self.params, self.run_config)
log.info('Building Train Graph: Done')
scalars = collections.get(collection.Key.SUMMARY_SCALAR)
histograms = collections.get(collection.Key.SUMMARY_HISTOGRAM)
skip_optimize_ops = collections.get(
collection.Key.SKIP_OPTIMIZE)
skip_opt = set()
if skip_optimize_ops is not None:
skip_opt |= set(skip_optimize_ops)
if scalars is not None:
skip_opt |= {t for _, t in scalars}
if histograms is not None:
skip_opt |= {t for _, t in histograms}
skip_opt = list(skip_opt)
log.info(
'Train with: \n> Run_config: %s\n> Params: %s\n> Train_model_spec: %s\n'
% (repr(self.run_config), repr(self.params), repr(model_spec)))
summary_record = SummaryRecord(
scalar=collections.get(collection.Key.SUMMARY_SCALAR),
histogram=collections.get(collection.Key.SUMMARY_HISTOGRAM), )
return ProgramPair(
train_program=train_program,
startup_program=startup_prog), model_spec, summary_record
def build_for_eval(self, ds):
ds.name = 'eval'
program = F.Program()
startup_prog = F.Program()
with F.program_guard(program, startup_prog):
#share var with Train net
with F.unique_name.guard():
log.info('Building Eval Graph')
fea = ds.features()
model_spec = build_net(self.model_fn, fea, RunMode.EVAL,
self.params, self.run_config)
log.info('Done')
program = program.clone(for_test=True)
        log.info(
            'Eval with: \n> Run_config: %s\n> Params: %s\n> Eval_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))
return ProgramPair(
train_program=program, startup_program=startup_prog), model_spec
def build_for_predict(self, ds):
ds.name = 'predict'
program = F.Program()
startup_prog = F.Program()
with F.program_guard(program, startup_prog):
#share var with Train net
with F.unique_name.guard():
log.info('Building Predict Graph')
fea = ds.features()
model_spec = build_net(self.model_fn, fea, RunMode.PREDICT,
self.params, self.run_config)
log.info('Done')
program = program.clone(for_test=True)
        log.info(
            'Predict with: \n> Run_config: %s\n> Params: %s\n> Predict_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))
return ProgramPair(
train_program=program, startup_program=startup_prog), model_spec
def train(self, train_ds, train_hooks=[]):
if not isinstance(train_ds, Dataset):
raise ValueError('expect dataset to be instance of Dataset, got %s'
% repr(train_ds))
train_program, model_spec, summary_record = self.build_for_train(
train_ds)
train_run_hooks = [
hooks.StopAtStepHook(self.run_config.max_steps,
self.run_config.run_steps),
hooks.LoggingHook(
model_spec.loss,
summary_record=summary_record,
summary_writer=get_summary_writer(
os.path.join(self.run_config.model_dir, 'train_history')),
per_step=self.run_config.log_steps,
skip_step=self.run_config.skip_steps),
]
train_run_hooks.extend(train_hooks)
train_executor = F.Executor(F.cuda_places()[0])
mon_exe = MonitoredExecutor(
train_executor,
train_program,
loss=model_spec.loss,
run_config=self.run_config,
run_hooks=train_run_hooks,
warm_start_setting=self.warm_start_setting)
        distribution.init_distribuition_env(
            train_program)  # initialize the distributed training env for the train program
mon_exe.init_or_restore_variables()
if distribution.status.is_master:
mon_exe._hooks.append(
hooks.CheckpointSaverHook(
mon_exe._saver,
per_step=mon_exe._save_steps,
skip_step=mon_exe._skip_steps))
try:
with mon_exe:
for data in train_ds.start():
mon_exe.run(feed=data)
except (StopException, F.core.EOFException) as e:
pass
return mon_exe.result
def evaluate(self, eval_dataset, eval_hooks=[]):
if not isinstance(eval_dataset, Dataset):
raise ValueError('expect dataset to be instance of Dataset, got %s'
% repr(eval_dataset))
program, model_spec = self.build_for_eval(eval_dataset)
single_card_place = F.cuda_places()[0]
eval_executor = F.Executor(single_card_place)
eval_hooks = [
hooks.StopAtStepHook(self.run_config.eval_max_steps,
self.run_config.eval_max_steps),
hooks.EvalHook(model_spec.metrics, )
]
mon_exe = MonitoredExecutor(
eval_executor,
program,
run_config=self.run_config,
run_hooks=eval_hooks)
mon_exe.init_or_restore_variables()
try:
with mon_exe:
for data in eval_dataset.start(places=[single_card_place]):
mon_exe.run(feed=data)
except (StopException, F.core.EOFException) as e:
pass
_, eval_result = mon_exe.result
summary_writer = get_summary_writer(
os.path.join(self.run_config.model_dir, 'eval_history'))
log_eval_result('eval', eval_result, summary_writer, mon_exe.state)
return mon_exe.result
def predict(self, predict_dataset, ckpt=None, steps=-1, split_batch=True):
        '''
        Perform prediction.
        will call `model_fn` and initiate the user-specified model in `propeller.RunMode.PREDICT` mode
        Args:
            predict_dataset (propeller.data.Dataset): should not `shuffle` or `repeat`
            steps (int): steps to predict, if -1 is specified, will stop when `StopException` is raised in `predict_dataset`
            split_batch (bool): if True, prediction of each example in a batch is returned.
        Yields:
            Evaluated values of predictions tensors.
        '''
if not isinstance(predict_dataset, Dataset):
raise ValueError('expect dataset to be instance of Dataset, got %s'
% repr(predict_dataset))
program, model_spec = self.build_for_predict(predict_dataset)
single_card_place = F.cuda_places()[0]
executor = F.Executor(single_card_place)
pred_run_config = RunConfig(
run_steps=steps if steps == -1 else None,
model_dir=self.run_config.model_dir)
mon_exe = MonitoredExecutor(
executor,
program,
run_config=pred_run_config, )
mon_exe.init_or_restore_variables()
try:
with mon_exe:
                log.info('Running predict from dir: %s' % repr(mon_exe.state))
single_card_place = F.cuda_places()[0]
for data in predict_dataset.start(places=[single_card_place]):
res = mon_exe.run(fetch_list=model_spec.predictions,
feed=data)
if split_batch:
res = map(lambda i: i.tolist(), res)
res = zip(*res) # transpose
for r in res:
yield r
else:
yield list(map(lambda i: i.tolist(), res))
except (StopException, F.core.EOFException) as e:
pass
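# Usage sketch (illustrative; `model_fn`, `config`, `train_ds`, `dev_ds` and
# `test_ds` are hypothetical objects satisfying the docstring above):
#   learner = Learner(model_fn, config, params={'lr': 1e-4})
#   learner.train(train_ds)
#   learner.evaluate(dev_ds)
#   for pred in learner.predict(test_ds):
#       print(pred)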
def train_and_eval(_shit=None,
model_class_or_model_fn=None,
params=None,
run_config=None,
train_dataset=None,
eval_dataset=None,
warm_start_setting=None,
train_hooks=[],
eval_hooks=[],
exporters=[]):
    '''
    Perform the train and evaluate procedure.
    will call `model_fn` and initiate the user-specified model in `propeller.RunMode.TRAIN` and `propeller.RunMode.EVAL` mode
    Args:
        model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` can be specified in 2 ways:
            1. subclass of propeller.train.Model which implements:
                1. \_\_init\_\_ (hyper_param, mode, run_config)
                2. forward (features) => (prediction)
                3. backward (loss) => None
                4. loss (prediction) => (loss)
                5. metrics (optional) (prediction) => (dict of propeller.Metrics)
            2. a model_fn that takes the following args:
                1. features
                2. param
                3. mode
                4. run_config(optional)
               and returns a `propeller.ModelSpec`
        params: any python object, will be passed to your `model_fn` or `propeller.train.Model`
        run_config (propeller.RunConfig): run_config.max_steps should not be None.
        train_dataset (propeller.paddle.data.Dataset): training will stop if global_step > run_config.max_steps.
        eval_dataset (propeller.paddle.data.Dataset|dict): Optional, if a dict of propeller.data.Dataset is specified, evaluation will be performed on every evaluation set and results reported.
        warm_start_setting (propeller.WarmStartSetting): Optional. warm start variables will overwrite model variables.
        train_hooks (list of propeller.paddle.train.RunHook): Optional.
        eval_hooks (list of propeller.paddle.train.RunHook): Optional.
        exporters (list of propeller.paddle.train.Exporter): Optional.
    '''
if _shit is not None:
raise ValueError('specify keyword args to this function')
if model_class_or_model_fn is None or params is None or run_config is None or train_dataset is None:
raise ValueError(
'some argument is None: model_class_or_model_fn:%s params:%s run_config:%s train_dataset:%s'
% (model_class_or_model_fn, params, run_config, train_dataset))
    #init distribution env if env var PROPELLER_DISCONFIG is set
if train_dataset is None:
raise ValueError('train dataset not specified')
if eval_dataset is None:
        raise ValueError('eval dataset not specified')
if not isinstance(eval_dataset, (dict, Dataset)):
        raise ValueError(
            'Eval dataset should be a propeller.Dataset or a dict of them, got: %s'
            % eval_dataset)
if isinstance(eval_dataset, Dataset):
eval_dataset = {'eval': eval_dataset}
ds_list = list(eval_dataset.values())
for ds in ds_list:
ds.name = 'eval'
first = ds_list[0]
for d in ds_list[1:]:
if not first.__eq__(d):
raise ValueError(
'eval dataset has different output_shapes or types: %s' %
repr(ds_list))
est = Learner(
model_class_or_model_fn,
run_config,
params,
warm_start_setting=warm_start_setting)
class EvalHookOnTrainLoop(hooks.RunHook):
def __init__(self):
self.program, self.model_spec = est.build_for_eval(
list(eval_dataset.values())[
0]) #eval_datasets must have same output shapes
self.summary_writers = {
ds_name: get_summary_writer(
os.path.join(
os.path.join(run_config.model_dir, 'eval_history'),
ds_name))
for ds_name in eval_dataset
}
def after_run(self, _, state):
if state.step > run_config.skip_steps and state.gstep % run_config.eval_steps == 0:
eval_results = {}
for name, ds in six.iteritems(eval_dataset):
ehooks = [
hooks.StopAtStepHook(est.run_config.eval_max_steps,
est.run_config.eval_max_steps),
hooks.EvalHook(
self.model_spec.metrics,
summary_writer=self.summary_writers[name], )
]
single_card_place = F.cuda_places()[0]
eval_executor = F.Executor(single_card_place)
mon_exe = MonitoredExecutor(
eval_executor,
self.program,
run_config=est.run_config,
run_hooks=ehooks + eval_hooks)
try:
with mon_exe:
for data in ds.start(places=[single_card_place]):
mon_exe.run(feed=data)
except (StopException, F.core.EOFException) as e:
pass
hook_results = mon_exe.result
eval_res = hook_results[
1] # hook_results: [StopAtStepHook, EvalHook, ...]
eval_results[name] = eval_res
log_eval_result(name, eval_res, self.summary_writers[name],
state)
for exporter in exporters:
exporter.export(eval_executor, self.program,
self.model_spec, eval_results, state)
else:
eval_results = {}
return eval_results
if distribution.status.is_master:
train_hooks.append(EvalHookOnTrainLoop())
res = est.train(train_dataset, train_hooks=train_hooks)
return res
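# Usage sketch (illustrative): arguments must be passed by keyword, e.g.
#   train_and_eval(
#       model_class_or_model_fn=MyModel,    # hypothetical Model subclass
#       params=hparams,
#       run_config=cfg,
#       train_dataset=train_ds,
#       eval_dataset={'dev': dev_ds},
#       train_hooks=[], eval_hooks=[], exporters=[])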
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import asyncio
import threading
import math
import zmq
import zmq.asyncio
import numpy as np
from propeller import log
import propeller.service.utils as serv_utils
class InferenceBaseClient(object):
def __init__(self, address):
self.context = zmq.Context()
self.address = address
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(address)
log.info("Connecting to server... %s" % address)
def __call__(self, *args):
for arg in args:
if not isinstance(arg, np.ndarray):
raise ValueError('expect ndarray slot data, got %s' %
repr(arg))
request = serv_utils.nparray_list_serialize(args)
self.socket.send(request)
reply = self.socket.recv()
ret = serv_utils.nparray_list_deserialize(reply)
return ret
class InferenceClient(InferenceBaseClient):
def __init__(self, address, batch_size=128, num_coroutine=10, timeout=10.):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
context = zmq.asyncio.Context()
self.socket_pool = [
context.socket(zmq.REQ) for _ in range(num_coroutine)
]
log.info("Connecting to server... %s" % address)
for socket in self.socket_pool:
socket.connect(address)
self.num_coroutine = num_coroutine
self.batch_size = batch_size
self.timeout = int(timeout * 1000)
#yapf: disable
def __call__(self, *args):
for arg in args:
if not isinstance(arg, np.ndarray):
raise ValueError('expect ndarray slot data, got %s' %
repr(arg))
num_tasks = math.ceil(1. * args[0].shape[0] / self.batch_size)
rets = [None] * num_tasks
async def get(coroutine_idx=0, num_coroutine=1):
socket = self.socket_pool[coroutine_idx]
while coroutine_idx < num_tasks:
begin = coroutine_idx * self.batch_size
end = (coroutine_idx + 1) * self.batch_size
arr_list = [arg[begin:end] for arg in args]
request = serv_utils.nparray_list_serialize(arr_list)
try:
await socket.send(request)
await socket.poll(self.timeout, zmq.POLLIN)
reply = await socket.recv(zmq.NOBLOCK)
ret = serv_utils.nparray_list_deserialize(reply)
except Exception as e:
log.exception(e)
ret = None
rets[coroutine_idx] = ret
coroutine_idx += num_coroutine
futures = [
get(i, self.num_coroutine) for i in range(self.num_coroutine)
]
self.loop.run_until_complete(asyncio.wait(futures))
for r in rets:
if r is None:
raise RuntimeError('Client call failed')
return [np.concatenate(col, 0) for col in zip(*rets)]
#yapf: enable
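# Usage sketch (illustrative; the address and array below are hypothetical):
#   client = InferenceClient('tcp://localhost:5556', batch_size=64,
#                            num_coroutine=4)
#   logits, = client(np.zeros([256, 128], dtype=np.int64))
# The input is split into batches, queried concurrently over ZMQ, and the
# per-batch replies are concatenated back along axis 0.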
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package interface;
service Inference {
rpc Infer(Slots) returns (Slots){}
}
message Slots {
repeated Slot slots = 1;
}
message Slot {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
// Tensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
}
Type type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
bytes data = 3;
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: interface.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='interface.proto',
package='interface',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n\x0finterface.proto\x12\tinterface\"\'\n\x05Slots\x12\x1e\n\x05slots\x18\x01 \x03(\x0b\x32\x0f.interface.Slot\"\xb8\x01\n\x04Slot\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.interface.Slot.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"p\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x32:\n\tInference\x12-\n\x05Infer\x12\x10.interface.Slots\x1a\x10.interface.Slots\"\x00\x62\x06proto3'
))
_SLOT_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='interface.Slot.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT16',
index=1,
number=1,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT32',
index=2,
number=2,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT64',
index=3,
number=3,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6, serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T',
index=7,
number=19,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='UINT8',
index=8,
number=20,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT8',
index=9,
number=21,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=144,
serialized_end=256, )
_sym_db.RegisterEnumDescriptor(_SLOT_TYPE)
_SLOTS = _descriptor.Descriptor(
name='Slots',
full_name='interface.Slots',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='slots',
full_name='interface.Slots.slots',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=30,
serialized_end=69, )
_SLOT = _descriptor.Descriptor(
name='Slot',
full_name='interface.Slot',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='interface.Slot.type',
index=0,
number=1,
type=14,
cpp_type=8,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='dims',
full_name='interface.Slot.dims',
index=1,
number=2,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='data',
full_name='interface.Slot.data',
index=2,
number=3,
type=12,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b(""),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[_SLOT_TYPE, ],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=72,
serialized_end=256, )
_SLOTS.fields_by_name['slots'].message_type = _SLOT
_SLOT.fields_by_name['type'].enum_type = _SLOT_TYPE
_SLOT_TYPE.containing_type = _SLOT
DESCRIPTOR.message_types_by_name['Slots'] = _SLOTS
DESCRIPTOR.message_types_by_name['Slot'] = _SLOT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Slots = _reflection.GeneratedProtocolMessageType(
'Slots',
(_message.Message, ),
{
'DESCRIPTOR': _SLOTS,
'__module__': 'interface_pb2'
# @@protoc_insertion_point(class_scope:interface.Slots)
})
_sym_db.RegisterMessage(Slots)
Slot = _reflection.GeneratedProtocolMessageType(
'Slot',
(_message.Message, ),
{
'DESCRIPTOR': _SLOT,
'__module__': 'interface_pb2'
# @@protoc_insertion_point(class_scope:interface.Slot)
})
_sym_db.RegisterMessage(Slot)
_INFERENCE = _descriptor.ServiceDescriptor(
name='Inference',
full_name='interface.Inference',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=258,
serialized_end=316,
methods=[
_descriptor.MethodDescriptor(
name='Infer',
full_name='interface.Inference.Infer',
index=0,
containing_service=None,
input_type=_SLOTS,
output_type=_SLOTS,
serialized_options=None, ),
])
_sym_db.RegisterServiceDescriptor(_INFERENCE)
DESCRIPTOR.services_by_name['Inference'] = _INFERENCE
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import os
import logging
import six
from time import sleep, time
import multiprocessing
import zmq
""" Never Never Never import paddle.fluid in main process, or any module would import fluid.
"""
log = logging.getLogger(__name__)
def profile(msg):
def decfn(fn):
def retfn(*args, **kwargs):
start = time()
ret = fn(*args, **kwargs)
end = time()
log.debug('%s timecost: %.5f' % (msg, end - start))
return ret
return retfn
return decfn
class Predictor(object):
def __init__(self, model_dir, device_idx=0):
import paddle.fluid as F
log.debug('create predictor on card %d' % device_idx)
config = F.core.AnalysisConfig(model_dir)
config.enable_use_gpu(5000, device_idx)  # 5000 MB initial GPU memory pool
self._predictor = F.core.create_paddle_predictor(config)
@profile('paddle')
def __call__(self, args):
for i, a in enumerate(args):
a.name = 'placeholder_%d' % i
res = self._predictor.run(args)
return res
def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
try:
log.debug("run_worker %s" % device_idx)
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv(
"CUDA_VISIBLE_DEVICES").split(",")[device_idx]
log.debug('cuda_env %s' % os.environ["CUDA_VISIBLE_DEVICES"])
import paddle.fluid as F
from propeller.service import interface_pb2
import propeller.service.utils as serv_utils
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.connect(endpoint)
#socket.bind(endpoint)
log.debug("Predictor building %s" % device_idx)
predictor = Predictor(model_dir, 0)  # device 0 of the restricted CUDA_VISIBLE_DEVICES
log.debug("Predictor %s" % device_idx)
except Exception as e:
log.exception(e)
return  # cannot serve without a socket and predictor
while True:
# Wait for next request from client
try:
message = socket.recv()
log.debug("get message %s" % device_idx)
slots = interface_pb2.Slots()
slots.ParseFromString(message)
pts = [serv_utils.slot_to_paddlearray(s) for s in slots.slots]
ret = predictor(pts)
slots = interface_pb2.Slots(
slots=[serv_utils.paddlearray_to_slot(r) for r in ret])
socket.send(slots.SerializeToString())
except Exception as e:
log.exception(e)
socket.send(str(e).encode('utf8'))  # Exception has no .message in py3
class InferencePredictor(object):
def __init__(self, backend_addr, model_dir, n_devices=1):
self.backend_addr = backend_addr
self.model_dir = model_dir
self.n_devices = n_devices
self.children = []
def start(self):
for device_idx in range(self.n_devices):
p = multiprocessing.Process(
target=run_worker,
args=(self.model_dir, device_idx, self.backend_addr))
p.start()
self.children.append(p)
return self
def join(self):
for p in self.children:
p.join()
def term(self):
for p in self.children:
log.debug("terminating children %s" % repr(p))
p.terminate()
class InferenceProxy(object):
def __init__(self):
self.backend = None
self.frontend = None
def listen(self, frontend_addr, backend_addr):
log.info("InferenceProxy starting...")
try:
context = zmq.Context(1)
# Socket facing clients
self.frontend = context.socket(zmq.ROUTER)
self.frontend.bind(frontend_addr)
# Socket facing services
self.backend = context.socket(zmq.DEALER)
self.backend.bind(backend_addr)
log.info("Queue init done")
zmq.device(zmq.QUEUE, self.frontend, self.backend)
except Exception as e:
log.exception(e)
log.info("Bringing down zmq device")
finally:
log.debug('terminating proxy')
if self.frontend is not None:
self.frontend.close()
if self.backend is not None:
self.backend.close()
context.term()
class InferenceServer(object):
def __init__(self, model_dir, n_devices):
self.model_dir = model_dir
self.n_devices = n_devices
def listen(self, port):
frontend_addr = "tcp://*:%s" % port
backend_addr = "ipc://backend.ipc"
predictor = InferencePredictor(backend_addr, self.model_dir,
self.n_devices).start()
try:
proxy = InferenceProxy()
proxy.listen(frontend_addr, backend_addr)
predictor.join()
except KeyboardInterrupt:
log.debug('terminating server')
predictor.term()
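# Usage sketch: serving a saved inference model programmatically. Hedged:
# the model path and port are illustrative; CUDA_VISIBLE_DEVICES must list
# at least `n_devices` GPUs.
server = InferenceServer('./ernie_inference_model', n_devices=2)
server.listen(8888)  # blocks; KeyboardInterrupt terminates the workers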
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import struct
from propeller.service import interface_pb2
def slot_to_numpy(slot):
if slot.type == interface_pb2.Slot.FP32:
dtype = np.float32
type_str = 'f'
elif slot.type == interface_pb2.Slot.INT32:
type_str = 'i'
dtype = np.int32
elif slot.type == interface_pb2.Slot.INT64:
dtype = np.int64
type_str = 'q'
else:
raise RuntimeError('unknown type %s' % slot.type)
num = len(slot.data) // struct.calcsize(type_str)
arr = struct.unpack('%d%s' % (num, type_str), slot.data)
shape = slot.dims
ret = np.array(arr, dtype=dtype).reshape(shape)
return ret
def numpy_to_slot(arr):
if arr.dtype == np.float32:
dtype = interface_pb2.Slot.FP32
elif arr.dtype == np.int32:
dtype = interface_pb2.Slot.INT32
elif arr.dtype == np.int64:
dtype = interface_pb2.Slot.INT64
else:
raise RuntimeError('unknown type %s' % arr.dtype)
pb = interface_pb2.Slot(
type=dtype, dims=list(arr.shape), data=arr.tobytes())
return pb
def slot_to_paddlearray(slot):
import paddle.fluid.core as core
if slot.type == interface_pb2.Slot.FP32:
type_str = 'f'
dtype = core.PaddleDType.FLOAT32
elif slot.type == interface_pb2.Slot.INT32:
type_str = 'i'
dtype = core.PaddleDType.INT32
elif slot.type == interface_pb2.Slot.INT64:
type_str = 'q'
dtype = core.PaddleDType.INT64
else:
raise RuntimeError('unknown type %s' % slot.type)
ret = core.PaddleTensor()
ret.shape = slot.dims
ret.dtype = dtype
num = len(slot.data) // struct.calcsize(type_str)
arr = struct.unpack('%d%s' % (num, type_str), slot.data)
ret.data = core.PaddleBuf(arr)
return ret
def paddlearray_to_slot(arr):
import paddle.fluid.core as core
if arr.dtype == core.PaddleDType.FLOAT32:
dtype = interface_pb2.Slot.FP32
type_str = 'f'
arr_data = arr.data.float_data()
elif arr.dtype == core.PaddleDType.INT32:
dtype = interface_pb2.Slot.INT32
type_str = 'i'
arr_data = arr.data.int32_data()
elif arr.dtype == core.PaddleDType.INT64:
dtype = interface_pb2.Slot.INT64
type_str = 'q'
arr_data = arr.data.int64_data()
else:
raise RuntimeError('unknown type %s' % arr.dtype)
data = struct.pack('%d%s' % (len(arr_data), type_str), *arr_data)
pb = interface_pb2.Slot(type=dtype, dims=list(arr.shape), data=data)
return pb
def nparray_list_serialize(arr_list):
slot_list = [numpy_to_slot(arr) for arr in arr_list]
slots = interface_pb2.Slots(slots=slot_list)
return slots.SerializeToString()
def nparray_list_deserialize(string):
slots = interface_pb2.Slots()
slots.ParseFromString(string)
return [slot_to_numpy(slot) for slot in slots.slots]
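# Round-trip sanity check for the helpers above (illustrative values):
a = np.random.rand(2, 3).astype(np.float32)
b = np.arange(4, dtype=np.int64).reshape(2, 2)
wire = nparray_list_serialize([a, b])
a2, b2 = nparray_list_deserialize(wire)
assert np.allclose(a, a2) and (b == b2).all()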
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import struct
import logging
import argparse
import numpy as np
import collections
from distutils import dir_util
import pickle
#from utils import print_arguments
import paddle.fluid as F
from paddle.fluid.proto import framework_pb2
log = logging.getLogger(__name__)
formatter = logging.Formatter(
fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
console = logging.StreamHandler()
console.setFormatter(formatter)
log.addHandler(console)
log.setLevel(logging.DEBUG)
def gen_arr(data, dtype):
num = len(data) // struct.calcsize(dtype)
arr = struct.unpack('%d%s' % (num, dtype), data)
return arr
def parse(filename):
with open(filename, 'rb') as f:
read = lambda fmt: struct.unpack(fmt, f.read(struct.calcsize(fmt)))
_, = read('I') # version
lodsize, = read('Q')
if lodsize != 0:
log.warning('LoD tensor is not supported, skipped!')
return None
_, = read('I') # version
pbsize, = read('i')
data = f.read(pbsize)
proto = framework_pb2.VarType.TensorDesc()
proto.ParseFromString(data)
log.info('type: [%s] dim %s' % (proto.data_type, proto.dims))
if proto.data_type == framework_pb2.VarType.FP32:
arr = np.array(
gen_arr(f.read(), 'f'), dtype=np.float32).reshape(proto.dims)
elif proto.data_type == framework_pb2.VarType.INT64:
arr = np.array(
gen_arr(f.read(), 'q'), dtype=np.int64).reshape(proto.dims)
elif proto.data_type == framework_pb2.VarType.INT32:
arr = np.array(
gen_arr(f.read(), 'i'), dtype=np.int32).reshape(proto.dims)
elif proto.data_type == framework_pb2.VarType.INT8:
arr = np.array(
gen_arr(f.read(), 'B'), dtype=np.int8).reshape(proto.dims)
else:
raise RuntimeError('Unknown dtype %s' % proto.data_type)
return arr
def show(arr):
print(repr(arr))
def dump(arr, path):
path = os.path.join(args.to, path)
log.info('dump to %s' % path)
try:
os.makedirs(os.path.dirname(path))
except FileExistsError:
pass
with open(path, 'wb') as f:
    pickle.dump(arr, f, protocol=4)
def list_dir(dir_or_file):
if os.path.isfile(dir_or_file):
return [dir_or_file]
else:
return [
os.path.join(i, kk) for i, _, k in os.walk(dir_or_file) for kk in k
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('mode', choices=['show', 'dump'], type=str)
parser.add_argument('file_or_dir', type=str)
parser.add_argument('-t', "--to", type=str, default=None)
parser.add_argument('-v', "--verbose", action='store_true')
args = parser.parse_args()
files = list_dir(args.file_or_dir)
parsed_arr = map(parse, files)
if args.mode == 'show':
for arr in parsed_arr:
if arr is not None:
show(arr)
elif args.mode == 'dump':
if args.to is None:
raise ValueError('--to dir_name not specified')
for arr, path in zip(parsed_arr, files):
if arr is not None:
dump(arr, path.replace(args.file_or_dir, ''))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import sys
import os
import argparse
import logging
import logging.handlers
from propeller.service.server import InferenceServer
from propeller import log
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_dir', type=str, required=True)
parser.add_argument('-p', '--port', type=int, required=True)
parser.add_argument('-v', '--verbose', action='store_true')
args = parser.parse_args()
if args.verbose:
log.setLevel(logging.DEBUG)
n_devices = len(os.getenv("CUDA_VISIBLE_DEVICES", "0").split(","))
server = InferenceServer(args.model_dir, n_devices)
log.info('propeller server listening on port %d' % args.port)
server.listen(args.port)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import six
import logging
import os
import itertools
import json
import abc
import numpy as np
@six.add_metaclass(abc.ABCMeta)
class Model():
def __init__(self, config, mode):
"""
Args:
config (dict): hyper param
mode (propeller.RunMode): will create `TRAIN` and `EVAL` models in propeller.train_and_eval
"""
self.mode = mode
@abc.abstractmethod
def forward(self, features):
"""
Args:
features (list of Tensor): depends on your Dataset.output_shapes
Returns:
return (Tensor):
"""
pass
@abc.abstractmethod
def loss(self, predictions, label):
"""
Args:
predictions (Tensor): result of `self.forward`
label (Tensor): depends on your Dataset.output_shapes
Returns:
return (paddle scalar): loss
"""
pass
@abc.abstractmethod
def backward(self, loss):
"""
Call in TRAIN mode
Args:
loss (Tensor): result of `self.loss`
Returns:
None
"""
pass
@abc.abstractmethod
def metrics(self, predictions, label):
"""
Call in EVAL mode
Args:
predictions (Tensor): result of `self.forward`
label (Tensor): depends on your Dataset.output_shapes
Returns:
(dict): k-v map like: {"metrics_name": propeller.Metrics }
"""
return {}
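# A minimal subclass sketch. Hedged: the fluid layer calls and the
# hyper-parameter names (vocab_size, emb_dim, num_labels, lr) are
# illustrative, not one of the repo's own models.
import paddle.fluid.layers as L

class BowModel(Model):
    def __init__(self, config, mode):
        super(BowModel, self).__init__(config, mode)
        self.config = config

    def forward(self, features):
        ids, = features
        emb = L.embedding(ids, size=[self.config.vocab_size, self.config.emb_dim])
        pooled = L.reduce_mean(emb, dim=1)  # bag-of-words pooling
        return L.fc(pooled, size=self.config.num_labels)

    def loss(self, predictions, label):
        return L.reduce_mean(
            L.softmax_with_cross_entropy(predictions, label))

    def backward(self, loss):
        import paddle.fluid as F
        F.optimizer.Adam(learning_rate=self.config.lr).minimize(loss)

    def metrics(self, predictions, label):
        # would normally return {"name": <propeller metric>}; left empty here
        return {}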
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import json
from collections import namedtuple
class RunMode(object):
TRAIN = 1
PREDICT = 2
EVAL = 3
class HParams(object):
def __init__(self, **kwargs):
for k, v in kwargs.items():
self.__dict__[k] = v
def __contains__(self, key):
return key in self.__dict__
def __getitem__(self, key):
if key not in self.__dict__:
raise ValueError('key(%s) not in HParams.' % key)
return self.__dict__[key]
def __repr__(self):
return repr(self.to_dict())
def __setitem__(self, key, val):
self.__dict__[key] = val
@staticmethod
def from_json(json_str):
d = json.loads(json_str)
if type(d) != dict:
raise ValueError('json object must be dict.')
return HParams.from_dict(d)
def get(self, key, default=None):
return self.__dict__.get(key, default)
@staticmethod
def from_dict(d):
if type(d) != dict:
raise ValueError('input must be dict.')
hp = HParams(**d)
return hp
def to_json(self):
return json.dumps(self.__dict__)
def to_dict(self):
return self.__dict__
def join(self, other):
if not isinstance(other, HParams):
raise ValueError('input must be HParams instance.')
self.__dict__.update(**other.__dict__)
return self
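# Example (illustrative values): later entries win on join()
hp = HParams(lr=1e-4, batch_size=32)
hp = hp.join(HParams(batch_size=64, warmup=1000))
assert hp['batch_size'] == 64 and hp.get('missing', 0) == 0
print(hp.to_json())  # e.g. {"lr": 0.0001, "batch_size": 64, "warmup": 1000}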
SummaryRecord = namedtuple('SummaryRecord', ['scalar', 'histogram'])
WarmStartSetting = namedtuple('WarmStartSetting', ['predicate_fn', 'from_dir'])
RunConfig = namedtuple('RunConfig', [
'model_dir', 'run_steps', 'max_steps', 'save_steps', 'eval_steps',
'eval_max_steps', 'skip_steps', 'log_steps', 'max_ckpt'
])
RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
ProgramPair = namedtuple('ProgramPair', ['train_program', 'startup_program'])
InferenceSpec = namedtuple('InferenceSpec', ['inputs', 'outputs'])
ModelSpec = namedtuple('ModelSpec', [
'loss',
'predictions',
'metrics',
'mode',
'inference_spec',
])
ModelSpec.__new__.__defaults__ = (None, ) * len(ModelSpec._fields)
class StopException(Exception):
pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import six
import re
import json
import argparse
import itertools
import logging
from functools import reduce
from propeller.types import RunConfig
from propeller.types import HParams
log = logging.getLogger(__name__)
def ArgumentParser(name):
parser = argparse.ArgumentParser('propeller model')
parser.add_argument('--run_config', type=str, default='')
parser.add_argument(
'--hparam', type=str, nargs='*', action='append', default=[['']])
return parser
def _get_dict_from_environ_or_json_or_file(args, env_name):
if args == '':
return None
if args is None:
s = os.environ.get(env_name)
else:
s = args
if os.path.exists(s):
s = open(s).read()
if isinstance(s, six.string_types):
try:
r = eval(s)  # accepts Python dict literals as well as strict JSON
except SyntaxError as e:
raise ValueError('json parse error: %s \n>Got json: %s' %
(repr(e), s))
return r
else:
return s
def parse_file(filename):
d = _get_dict_from_environ_or_json_or_file(filename, None)
if d is None:
raise ValueError('file(%s) not found' % filename)
return d
def parse_runconfig(args=None):
d = _get_dict_from_environ_or_json_or_file(args.run_config,
'PROPELLER_RUNCONFIG')
if d is None:
raise ValueError('run_config not found')
return RunConfig(**d)
def parse_hparam(args=None):
if args is not None:
hparam_strs = reduce(list.__add__, args.hparam)
else:
hparam_strs = [None]
hparams = [
_get_dict_from_environ_or_json_or_file(hp, 'PROPELLER_HPARAMS')
for hp in hparam_strs
]
hparams = [HParams(**h) for h in hparams if h is not None]
if len(hparams) == 0:
raise ValueError('hparam not found')
hparam = reduce(lambda x, y: x.join(y), hparams)
return hparam
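# Hedged sketch (not part of the original module): typical wiring from a
# training entry point, e.g.
#   python train.py --run_config run_config.json --hparam '{"lr": 1e-4}'
def _example_usage():
    parser = ArgumentParser('train')
    args = parser.parse_args()
    run_config = parse_runconfig(args)  # RunConfig built from file/env json
    hparam = parse_hparam(args)         # merged HParams; later --hparam wins
    return run_config, hparam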
def flatten(s):
assert is_structure(s)
schema = [len(ss) for ss in s]
flt = list(itertools.chain(*s))
return flt, schema
def unflatten(structure, schema):
start = 0
res = []
for _range in schema:
res.append(structure[start:start + _range])
start += _range
return res
def is_structure(s):
return isinstance(s, (list, tuple))
def map_structure(func, s):
if isinstance(s, list) or isinstance(s, tuple):
return [map_structure(func, ss) for ss in s]
elif isinstance(s, dict):
return {k: map_structure(func, v) for k, v in six.iteritems(s)}
else:
return func(s)
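# Worked example for the structure helpers above:
nested = [[1, 2], [3], [4, 5, 6]]
flat, schema = flatten(nested)  # [1, 2, 3, 4, 5, 6], [2, 1, 3]
restored = unflatten([x * 2 for x in flat], schema)
assert restored == map_structure(lambda x: x * 2, nested)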
......@@ -252,12 +252,6 @@ class BaseReader(object):
phase=None):
examples = self._read_tsv(input_file)
if phase == 'train':
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
examples = examples[trainer_id:(len(examples) // trainer_num) * trainer_num:trainer_num]
log.info('apply sharding %d/%d' % (trainer_id, trainer_num))
def wrapper():
all_dev_batches = []
for epoch_index in range(epoch):
......
nltk==3.4
numpy==1.14.5
pyzmq==18.0.2
scikit-learn==0.20.3
scipy==1.2.1
six==1.11.0
sklearn==0.0
......@@ -92,6 +92,8 @@ def main(args):
num_train_examples = reader.get_num_examples(args.train_set)
if args.in_tokens:
if args.batch_size < args.max_seq_len:
raise ValueError('if in_tokens=True, batch_size should be greater than max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
max_train_steps = args.epoch * num_train_examples // (
args.batch_size // args.max_seq_len) // dev_count
else:
......@@ -376,11 +378,12 @@ def main(args):
def evaluate_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps):
# evaluate dev set
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for ds in args.dev_set.split(','):
test_pyreader.decorate_tensor_provider(
reader.data_generator(
ds,
batch_size=args.predict_batch_size,
batch_size=batch_size,
epoch=1,
dev_count=1,
shuffle=False))
......@@ -403,12 +406,13 @@ def predict_wrapper(args, reader, exe, test_prog, test_pyreader, graph_vars,
test_sets = args.test_set.split(',')
save_dirs = args.test_save.split(',')
assert len(test_sets) == len(save_dirs)
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for test_f, save_f in zip(test_sets, save_dirs):
test_pyreader.decorate_tensor_provider(
reader.data_generator(
test_f,
batch_size=args.predict_batch_size,
batch_size=batch_size,
epoch=1,
dev_count=1,
shuffle=False))
......
......@@ -95,6 +95,8 @@ def main(args):
num_train_examples = reader.get_num_examples("train")
if args.in_tokens:
if args.batch_size < args.max_seq_len:
raise ValueError('if in_tokens=True, batch_size should be greater than max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
max_train_steps = args.epoch * num_train_examples // (
args.batch_size // args.max_seq_len) // dev_count
else:
......
......@@ -85,6 +85,9 @@ def main(args):
num_train_examples = reader.get_num_examples(args.train_set)
if args.in_tokens:
if args.batch_size < args.max_seq_len:
raise ValueError('if in_tokens=True, batch_size should be greater than max_seq_len, got batch_size:%d seqlen:%d' % (args.batch_size, args.max_seq_len))
max_train_steps = args.epoch * num_train_examples // (
args.batch_size // args.max_seq_len) // dev_count
else:
......@@ -297,11 +300,12 @@ def main(args):
def evaluate_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
epoch, steps):
# evaluate dev set
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for ds in args.dev_set.split(','): #single card eval
test_pyreader.decorate_tensor_provider(
reader.data_generator(
ds,
batch_size=args.predict_batch_size,
batch_size=batch_size,
epoch=1,
dev_count=1,
shuffle=False))
......@@ -318,10 +322,11 @@ def predict_wrapper(reader, exe, test_prog, test_pyreader, graph_vars,
save_dirs = args.test_save.split(',')
assert len(test_sets) == len(save_dirs), 'number of test_sets & test_save not match, got %d vs %d' % (len(test_sets), len(save_dirs))
batch_size = args.batch_size if args.predict_batch_size is None else args.predict_batch_size
for test_f, save_f in zip(test_sets, save_dirs):
test_pyreader.decorate_tensor_provider(reader.data_generator(
test_f,
batch_size=args.predict_batch_size,
batch_size=batch_size,
epoch=1,
dev_count=1,
shuffle=False))
......
......@@ -2,7 +2,7 @@ set -eux
export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
......
......@@ -2,7 +2,7 @@ set -eux
export FLAGS_eager_delete_tensor_gb=0
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python ./finetune_launch.py \
--nproc_per_node 8 \
......
......@@ -155,8 +155,8 @@ def calc_em_score(answers, prediction):
def eval_file(dataset_file, prediction_file):
ground_truth_file = json.load(open(dataset_file, 'rb'))
prediction_file = json.load(open(prediction_file, 'rb'))
ground_truth_file = json.load(open(dataset_file, 'r'))
prediction_file = json.load(open(prediction_file, 'r'))
F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
AVG = (EM + F1) * 0.5
return EM, F1, AVG, TOTAL
......
import sys
import numpy as np
import re
from propeller import log
import itertools
from propeller.paddle.data import Dataset
import six
if six.PY2:
import operator
def accumulate(iterable, func=operator.add, initial=None):
'Return running totals'
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
# accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
total = initial
if initial is None:
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = func(total, element)
yield total
else:
from itertools import accumulate
max_input_chars_per_word = 100
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def wordpiece(token, vocab, unk_token, sentencepiece_style_vocab=False):
"""call with single word"""
chars = list(token)
if len(chars) > max_input_chars_per_word:
return [unk_token], [(0, len(chars))]
is_bad = False
start = 0
sub_tokens = []
sub_pos = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start == 0 and sentencepiece_style_vocab:
substr = u'\u2581' + substr
if start > 0 and not sentencepiece_style_vocab:
substr = "##" + substr
if substr in vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
sub_pos.append((start, end))
start = end
if is_bad:
return [unk_token], [(0, len(chars))]
else:
return sub_tokens, sub_pos
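# Worked example with a toy vocabulary:
tokens, positions = wordpiece('unaffable', {'un', '##aff', '##able'},
                              unk_token='[UNK]')
assert tokens == ['un', '##aff', '##able']
assert positions == [(0, 2), (2, 5), (5, 9)]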
class SpaceTokenizer(object):
def __init__(self, vocab, lower=True):
"""
space tokenizer: maps normalized, space-separated text to a list of
vocabulary tokens; out-of-vocabulary words become [UNK]
"""
self.vocab = set(vocab)
self.lower = lower
def __call__(self, sen):
if len(sen) == 0:
return [] #empty line
sen = sen.decode('utf8')
if self.lower:
sen = sen.lower()
res = []
for s in sen.split(' '):
if s == '':  # consecutive spaces yield empty strings
continue
if s in self.vocab:
res.append(s)
else:
res.append('[UNK]')
return res
class CharTokenizer(object):
def __init__(self, vocab, lower=True):
"""
char tokenizer (English wordpiece):
normalized text (space-separated or not) => list of word pieces
"""
self.vocab = set(vocab)
#self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
self.pat = re.compile(r'\S')
self.lower = lower
def __call__(self, sen):
if len(sen) == 0:
return [] #empty line
sen = sen.decode('utf8')
if self.lower:
sen = sen.lower()
res = []
for match in self.pat.finditer(sen):
words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]')
res.extend(words)
return res
def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
token_type_a = np.zeros_like(seg_a, dtype=np.int64)
token_type_b = np.ones_like(seg_b, dtype=np.int64)
sen_emb = np.concatenate([[cls_id], seg_a, [sep_id], seg_b, [sep_id]], 0)
token_type_emb = np.concatenate([[0], token_type_a, [0], token_type_b, [1]], 0)
seqlen = sen_emb.shape[0]
# truncate from position 0 (random start disabled)
random_begin = 0  # np.random.randint(0, np.maximum(0, seqlen - max_seqlen) + 1)
sen_emb = sen_emb[random_begin: random_begin + max_seqlen]
token_type_emb = token_type_emb[random_begin: random_begin + max_seqlen]
return sen_emb, token_type_emb
def build_1_pair(seg_a, max_seqlen, cls_id, sep_id):
token_type_a = np.zeros_like(seg_a, dtype=np.int64)
sen_emb = np.concatenate([[cls_id], seg_a, [sep_id]], 0)
token_type_emb = np.concatenate([[0], token_type_a, [0]], 0)
seqlen = sen_emb.shape[0]
# truncate from position 0 (random start disabled)
random_begin = 0  # np.random.randint(0, np.maximum(0, seqlen - max_seqlen) + 1)
sen_emb = sen_emb[random_begin: random_begin + max_seqlen]
token_type_emb = token_type_emb[random_begin: random_begin + max_seqlen]
return sen_emb, token_type_emb
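# Worked example of the pair builder above (ids are illustrative):
cls_id, sep_id = 1, 2
seg_a = np.array([5, 6, 7], dtype=np.int64)
seg_b = np.array([8, 9], dtype=np.int64)
sen, token_type = build_2_pair(seg_a, seg_b, 16, cls_id, sep_id)
# sen        == [1, 5, 6, 7, 2, 8, 9, 2]
# token_type == [0, 0, 0, 0, 0, 1, 1, 1]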
def expand_dims(*args):
func = lambda i: np.expand_dims(i, -1)
ret = [func(i) for i in args]
return ret
def interleave(ds1, ds2):
def gen():
for i, j in six.moves.zip_longest(iter(ds1), iter(ds2)):
if i is not None:
yield i
if j is not None:
yield j
return Dataset.from_generator_func(gen)
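# Worked example: interleave alternates elements and skips the exhausted side.
def gen1():
    for i in [1, 2, 3]:
        yield i

def gen2():
    for c in ['a', 'b']:
        yield c

mixed = interleave(Dataset.from_generator_func(gen1),
                   Dataset.from_generator_func(gen2))
# iterating `mixed` yields: 1, 'a', 2, 'b', 3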