Unverified commit 30b892e3, authored by Meiyim, committed via GitHub

Merge pull request #335 from Meiyim/dev

Dev
......@@ -3,21 +3,25 @@ English | [简体中文](./README.zh.md)
## ERNIE 2.0: A Continual Pre-training Framework for Language Understanding
* [Pre-training Tasks](#pre-training-tasks)
* [Word-aware Tasks](#word-aware-tasks)
* [Knowledge Masking Task](#knowledge-masking-task)
* [Capitalization Prediction Task](#capitalization-prediction-task)
* [Token-Document Relation Prediction Task](#token-document-relation-prediction-task)
* [Structure-aware Tasks](#structure-aware-tasks)
* [Sentence Reordering Task](#sentence-reordering-task)
* [Sentence Distance Task](#sentence-distance-task)
* [Semantic-aware Tasks](#semantic-aware-tasks)
* [Discourse Relation Task](#discourse-relation-task)
* [IR Relevance Task](#ir-relevance-task)
* [ERNIE 1.0: <strong>E</strong>nhanced <strong>R</strong>epresentation through k<strong>N</strong>owledge <strong>I</strong>nt<strong>E</strong>gration](#ernie-10-enhanced-representation-through-knowledge-integration)
* [Compare the ERNIE 1.0 and ERNIE 2.0](#compare-the-ernie-10-and-ernie-20)
* [Results on English Datasets](#results-on-english-datasets)
* [Results on Chinese Datasets](#results-on-chinese-datasets)
* [Pre-training Tasks](#pre-training-tasks)
* [Word-aware Tasks](#word-aware-tasks)
* [Knowledge Masking Task](#knowledge-masking-task)
* [Capitalization Prediction Task](#capitalization-prediction-task)
* [Token-Document Relation Prediction Task](#token-document-relation-prediction-task)
* [Structure-aware Tasks](#structure-aware-tasks)
* [Sentence Reordering Task](#sentence-reordering-task)
* [Sentence Distance Task](#sentence-distance-task)
* [Semantic-aware Tasks](#semantic-aware-tasks)
* [Discourse Relation Task](#discourse-relation-task)
* [IR Relevance Task](#ir-relevance-task)
* [ERNIE 1.0: <strong>E</strong>nhanced <strong>R</strong>epresentation through k<strong>N</strong>owledge <strong>I</strong>nt<strong>E</strong>gration](#ernie-10-enhanced-representation-through-knowledge-integration)
* [Compare the ERNIE 1.0 and ERNIE 2.0](#compare-the-ernie-10-and-ernie-20)
* [Results](#results)
* [Results on English Datasets](#results-on-english-datasets)
* [Results on Chinese Datasets](#results-on-chinese-datasets)
* [Release Notes](#release-notes)
* [Communication](#communication)
* [Usage](#usage)
![ernie2.0_paper](.metas/ernie2.0_paper.png)
......@@ -109,21 +113,6 @@ Integrating both phrase information and named entity information enables the mod
| **Structure-aware** | | ✅ Sentence Reordering | ✅ Sentence Reordering <br> ✅ Sentence Distance |
| **Semantic-aware** | ✅ Next Sentence Prediction | ✅ Discourse Relation | ✅ Discourse Relation <br> ✅ IR Relevance |
## Release Notes
- Aug 21, 2019: features update: fp16 finetuning, multiprocess finetuning.
- July 30, 2019: release ERNIE 2.0
- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
- Mar 18, 2019: update ERNIE_stable.tgz
- Mar 15, 2019: release ERNIE 1.0
## Communication
- [Github Issues](https://github.com/PaddlePaddle/ERNIE/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 760439550 (ERNIE discussion group).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Results
......@@ -626,6 +615,21 @@ LCQMC is a Chinese question semantic matching corpus published in COLING2018. [u
BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equivalence identification. This dataset was published in EMNLP 2018. [url: https://www.aclweb.org/anthology/D18-1536]
```
## Release Notes
- Aug 21, 2019: features update: fp16 finetuning, multiprocess finetuning.
- July 30, 2019: release ERNIE 2.0
- Apr 10, 2019: update ERNIE_stable-1.0.1.tar.gz, update config and vocab
- Mar 18, 2019: update ERNIE_stable.tgz
- Mar 15, 2019: release ERNIE 1.0
## Communication
- [Github Issues](https://github.com/PaddlePaddle/ERNIE/issues): bug reports, feature requests, install issues, usage issues, etc.
- QQ discussion group: 760439550 (ERNIE discussion group).
- [Forums](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## Usage
* [Install PaddlePaddle](#install-paddlepaddle)
......@@ -645,7 +649,8 @@ BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equiv
* [Machine Reading Comprehension](#machine-reading-comprehension)
* [Pre-training with ERNIE 1.0](#pre-training-with-ernie-10)
* [Data Preprocessing](#data-preprocessing)
* [PreTrain ERNIE1.0](#pretrain-ernie10)
* [Pretrain ERNIE1.0](#pretrain-ernie10)
* [Distillation](#distillation)
* [FAQ](#faq)
* [FAQ1: How to get sentence/tokens embedding of ERNIE?](#faq1-how-to-get-sentencetokens-embedding-of-ernie)
* [FAQ2: How to predict on new data with Fine-tuning model?](#faq2-how-to-predict-on-new-data-with-fine-tuning-model)
......@@ -654,7 +659,7 @@ BQ Corpus (Bank Question corpus) is a Chinese corpus for sentence semantic equiv
* [FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.](#faq5-can-not-find-library-libncclso-please-try-to-add-the-lib-path-to-ld_library_path)
## Install PaddlePaddle
### Install PaddlePaddle
This code base has been tested with Paddle Fluid 1.5.1 under Python 2.
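If you still need to install a matching release, one hedged option for a GPU machine is shown below; the exact package name and build tag depend on your CUDA/cuDNN versions, so consult the official install guide for the right one:

```script
pip install paddlepaddle-gpu==1.5.1
```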
......@@ -671,11 +676,15 @@ If you have been armed with certain level of deep learning knowledge, and it hap
For more information about PaddlePaddle, please refer to the [PaddlePaddle Github](https://github.com/PaddlePaddle/Paddle) or the [Official Website](https://www.paddlepaddle.org.cn/) for details.
The other dependencies of ERNIE are listed in `requirements.txt`; you can install them with
```script
pip install -r requirements.txt
```
## Pre-trained Models & Datasets
### Pre-trained Models & Datasets
### Models
#### Models
| Model | Description |
| :------------------------------------------------- | :----------------------------------------------------------- |
......@@ -685,23 +694,23 @@ For more information about paddlepadde, Please refer to [PaddlePaddle Github](ht
| [ERNIE 2.0 Base for English](https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz) | with params, config and vocabs |
| [ERNIE 2.0 Large for English](https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz) | with params, config and vocabs |
### Datasets
#### Datasets
#### English Datasets
##### English Datasets
Download the [GLUE data](https://gluebenchmark.com/tasks) by running [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) and unpack it to some directory `${TASK_DATA_PATH}`.
After the dataset is downloaded, run `sh ./script/en_glue/preprocess/cvt.sh $TASK_DATA_PATH` to convert the data format for training. If everything goes well, a folder named `glue_data_processed` will be created with all the converted data in it.
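For example (assuming the gist above has been saved locally as `download_glue_data.py`; its flag names may differ between versions):

```script
python download_glue_data.py --data_dir $TASK_DATA_PATH --tasks all
sh ./script/en_glue/preprocess/cvt.sh $TASK_DATA_PATH
# the converted data is placed in the glue_data_processed folder
```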
#### Chinese Datasets
##### Chinese Datasets
You can download Chinese Datasets from [here](https://ernie.bj.bcebos.com/task_data_zh.tgz)
## Fine-tuning
#### Fine-tuning
### Batchsize and GPU Settings
##### Batchsize and GPU Settings
In our experiments, we found that the batch size matters for different tasks. To make it easier to reproduce our results, we list the batch size and number of GPU cards used for each task here:
......@@ -728,7 +737,7 @@ In our experiments, we found that the batch size is important for different task
\* *For MNLI and QNLI we used 32GB V100 cards; for the other tasks we used 22GB P40 cards.*
### Multiprocessing and fp16 auto mix-precision finetune
#### Multiprocessing and fp16 auto mix-precision finetune
Multiprocessing finetuning can be enabled simply by launching your finetune script through `finetune_launch.py`.
With multiprocessing finetuning, Paddle can fully utilize your CPU/GPU capacity to accelerate finetuning.
......@@ -738,9 +747,9 @@ fp16 finetuning can be simply enable by specifing `--use_fp16 true` in your trai
Dynamic loss scaling is used to avoid vanishing gradients during fp16 training.
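A hedged sketch combining both features is shown below; the exact arguments accepted by `finetune_launch.py` are assumptions here, so check the script itself and the task scripts under `script/` for the real invocation:

```script
python finetune_launch.py \
    run_classifier.py \
    --use_fp16 true \
    ...   # the rest of your usual run_classifier.py arguments
```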
### Classification
#### Classification
#### Single Sentence Classification Tasks
##### Single Sentence Classification Tasks
The code used to perform classification/regression finetuning is in `run_classifier.py`. We also provide shell scripts for each task, including the best hyperparameters.
......@@ -798,7 +807,7 @@ Similarly, for the Chinese task `ChnSentCorp`, after setting the environment var
#### Sentence Pair Classification Tasks
##### Sentence Pair Classification Tasks
Take `RTE` as an example: the data should have 3 fields, `text_a text_b label`, in tsv format. Here are some example rows:
```
......@@ -834,9 +843,9 @@ testing ./data/test.tsv, save to output/test_out.5.2019-07-23-15-25-06.tsv.4.781
### Sequence Labeling
#### Sequence Labeling
#### Named Entity Recognition
##### Named Entity Recognition
Take `MSRA-NER (SIGHAN2006)` as an example: the data should have 2 fields, `text_a label`, in tsv format. Here are some example rows:
```
......@@ -853,7 +862,7 @@ Also, remember to set environmental variables like above, and run `sh script/zh_
[test evaluation] f1: 0.937390, precision: 0.925988, recall: 0.949077, elapsed time: 36.565929 s
```
### Machine Reading Comprehension
#### Machine Reading Comprehension
Take `DRCD` as an example; first convert the data into SQuAD format:
......@@ -896,9 +905,9 @@ Also, remember to set environmental variables like above, and run `sh script/zh_
```
## Pre-training with ERNIE 1.0
### Pre-training with ERNIE 1.0
### Data Preprocessing
#### Data Preprocessing
We construct the training dataset based on [Baidu Baike](https://en.wikipedia.org/wiki/Baidu_Baike), [Baidu Knows(Baidu Zhidao)](https://en.wikipedia.org/wiki/Baidu_Knows), [Baidu Tieba](https://en.wikipedia.org/wiki/Baidu_Tieba) for Chinese version ERNIE, and [Wikipedia](https://en.wikipedia.org/wiki/Wikipedia:Database_download), [Reddit](https://en.wikipedia.org/wiki/Reddit), [BookCorpus](https://github.com/soskek/bookcorpus) for English version ERNIE.
......@@ -912,7 +921,7 @@ Here are some train instances after processing (which can be found in [`data/dem
Each instance is composed of 5 fields, which are joined by `;`in one line, represented `token_ids; sentence_type_ids; position_ids; seg_labels; next_sentence_label` respectively. Especially, in the field`seg_labels`, 0 means the begin of one word, 1 means non-begin of one word, -1 means placeholder, the other number means `CLS` or `SEP`.
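A minimal sketch (not part of the repo) of splitting one such instance into its 5 fields, assuming the demo file is a gzipped text file with one instance per line:

```python
import gzip

with gzip.open('data/demo_train_set.gz', 'rt') as f:  # demo file shipped with the repo
    line = f.readline().strip()

token_ids, sentence_type_ids, position_ids, seg_labels, next_sentence_label = line.split(';')
token_ids = [int(i) for i in token_ids.split(' ')]
seg_labels = [int(i) for i in seg_labels.split(' ')]   # 0: word begin, 1: word continuation, -1: placeholder
```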
### PreTrain ERNIE 1.0
#### Pretrain ERNIE 1.0
The entry point for pretraining is [`script/zh_task/pretrain.sh`](./script/zh_task/pretrain.sh). Before running the training program, remember to add the CUDA, cuDNN and NCCL2 library paths to the environment variable LD_LIBRARY_PATH.
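A hedged example (the library locations below are placeholders; adjust them to your installation):

```script
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cudnn/lib64:/usr/local/nccl/lib:$LD_LIBRARY_PATH
sh script/zh_task/pretrain.sh
```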
......@@ -932,10 +941,15 @@ epoch: 1, progress: 1/1, step: 50, loss: 10.360563, ppl: 16398.287109, next_sent
```
### Distillation
ERNIE provides a toolkit for data distillation to further accelerate inference; see <a href="./distill/README.md">here</a> for details.
## FAQ
### FAQ
### FAQ1: How to get sentence/tokens embedding of ERNIE?
#### FAQ1: How to get sentence/tokens embedding of ERNIE?
Run ```ernie_encoder.py``` to get both the sentence embedding and the token embeddings. The input data format should be the same as that described in the [Fine-tuning](#fine-tuning) section.
......@@ -960,7 +974,7 @@ when finished running this script, `cls_emb.npy` and `top_layer_emb.npy `will b
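These files can then be inspected with numpy, for instance (array shapes depend on your input file and on the model):

```python
import numpy as np

cls_emb = np.load('cls_emb.npy')               # sentence-level ([CLS]) embeddings, one row per example
top_layer_emb = np.load('top_layer_emb.npy')   # token-level embeddings from the top encoder layer
print(cls_emb.shape, top_layer_emb.shape)
```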
### FAQ2: How to predict on new data with Fine-tuning model?
#### FAQ2: How to predict on new data with Fine-tuning model?
Take classification tasks for example, here is the script for batch prediction:
......@@ -984,18 +998,18 @@ Argument `init_checkpoint` is the path of the model, `predict_set` is the path
### FAQ3: Is the argument batch_size for one GPU card or for all GPU cards?
#### FAQ3: Is the argument batch_size for one GPU card or for all GPU cards?
For one GPU card. For example, with `batch_size` 32 on 8 cards, the effective global batch size is 256.
### FAQ4: Can not find library: libcudnn.so. Please try to add the lib path to LD_LIBRARY_PATH.
#### FAQ4: Can not find library: libcudnn.so. Please try to add the lib path to LD_LIBRARY_PATH.
Export the cuDNN library path to LD_LIBRARY_PATH, e.g.: `export LD_LIBRARY_PATH=/home/work/cudnn/cudnn_v[your cudnn version]/cuda/lib64`
### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
#### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
Download [NCCL2](https://developer.nvidia.com/nccl/nccl-download), and export the library path to LD_LIBRARY_PATH, e.g.:`export LD_LIBRARY_PATH=/home/work/nccl/lib`
......@@ -3,21 +3,25 @@
## ERNIE 2.0: A Continual Pre-training Framework for Language Understanding
* [Pre-Training 任务](#pre-training-任务)
* [Word-aware Tasks](#word-aware-tasks)
* [Knowledge Masking Task](#knowledge-masking-task)
* [Capitalization Prediction Task](#capitalization-prediction-task)
* [Token-Document Relation Prediction Task](#token-document-relation-prediction-task)
* [Structure-aware Tasks](#structure-aware-tasks)
* [Sentence Reordering Task](#sentence-reordering-task)
* [Sentence Distance Task](#sentence-distance-task)
* [Semantic-aware Tasks](#semantic-aware-tasks)
* [Discourse Relation Task](#discourse-relation-task)
* [IR Relevance Task](#ir-relevance-task)
* [ERNIE 1.0: <strong>E</strong>nhanced <strong>R</strong>epresentation through k<strong>N</strong>owledge <strong>I</strong>nt<strong>E</strong>gration](#ernie-10-enhanced-representation-through-knowledge-integration)
* [对比 ERNIE 1.0 和 ERNIE 2.0](#对比-ernie-10-和-ernie-20)
* [中文效果验证](#中文效果验证)
* [英文效果验证](#英文效果验证)
* [Pre-Training 任务](#pre-training-任务)
* [Word-aware Tasks](#word-aware-tasks)
* [Knowledge Masking Task](#knowledge-masking-task)
* [Capitalization Prediction Task](#capitalization-prediction-task)
* [Token-Document Relation Prediction Task](#token-document-relation-prediction-task)
* [Structure-aware Tasks](#structure-aware-tasks)
* [Sentence Reordering Task](#sentence-reordering-task)
* [Sentence Distance Task](#sentence-distance-task)
* [Semantic-aware Tasks](#semantic-aware-tasks)
* [Discourse Relation Task](#discourse-relation-task)
* [IR Relevance Task](#ir-relevance-task)
* [ERNIE 1.0: <strong>E</strong>nhanced <strong>R</strong>epresentation through k<strong>N</strong>owledge <strong>I</strong>nt<strong>E</strong>gration](#ernie-10-enhanced-representation-through-knowledge-integration)
* [对比 ERNIE 1.0 和 ERNIE 2.0](#对比-ernie-10-和-ernie-20)
* [效果验证](#效果验证)
* [中文效果验证](#中文效果验证)
* [英文效果验证](#英文效果验证)
* [开源记录](#开源记录)
* [技术交流](#技术交流)
* [使用](#使用)
![ernie2.0_paper](.metas/ernie2.0_paper.png)
......@@ -105,26 +109,16 @@
| **Semantic-aware** | ✅ Next Sentence Prediction | ✅ Discourse Relation | ✅ Discourse Relation <br> ✅ IR Relevance |
## 开源记录
- 2019-07-30 发布 ERNIE 2.0
- 2019-04-10 更新: update ERNIE_stable-1.0.1.tar.gz, 将模型参数、配置 ernie_config.json、vocab.txt 打包发布
- 2019-03-18 更新: update ERNIE_stable.tgz
- 2019-03-15 发布 ERNIE 1.0
## 技术交流
- [Github Issues](https://github.com/PaddlePaddle/ERNIE/issues): bug reports, feature requests, install issues, usage issues, etc.
- ERNIE QQ 群: 760439550 (ERNIE discussion group).
- [论坛](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## 效果验证
## 中文效果验证
### 中文效果验证
我们在 9 个任务上验证 ERNIE 2.0 中文模型的效果。这些任务包括:自然语言推断任务 XNLI;阅读理解任务 DRCD、DuReader、CMRC2018;命名实体识别任务 MSRA-NER (SIGHAN2006);情感分析任务 ChnSentiCorp;语义相似度任务 BQ Corpus、LCQMC;问答任务 NLPCC2016-DBQA 。任务的详情和效果会在如下章节中介绍。
### 自然语言推断任务
#### 自然语言推断任务
<table>
<tbody>
......@@ -189,7 +183,7 @@
XNLI 是由 Facebook 和纽约大学的研究者联合构建的自然语言推断数据集,包括 15 种语言的数据。我们用其中的中文数据来评估模型的语言理解能力。[链接: https://github.com/facebookresearch/XNLI]
```
### 阅读理解任务
#### 阅读理解任务
<table>
<tbody>
......@@ -318,9 +312,7 @@ CMRC2018 是中文信息学会举办的评测,评测的任务是抽取类阅
DRCD 是台达研究院发布的繁体中文阅读理解数据集,目标是从篇章中抽取出连续片段作为答案。我们在实验时先将其转换成简体中文。[链接: https://github.com/DRCKnowledgeTeam/DRCD]
```
### 命名实体识别任务
#### 命名实体识别任务
<table>
<tbody>
......@@ -377,9 +369,7 @@ DRCD 是台达研究院发布的繁体中文阅读理解数据集,目标是从
MSRA-NER (SIGHAN2006) 数据集由微软亚研院发布,其目标是识别文本中具有特定意义的实体,包括人名、地名、机构名。
```
### 情感分析任务
#### 情感分析任务
<table>
<tbody>
......@@ -436,9 +426,7 @@ MSRA-NER (SIGHAN2006) 数据集由微软亚研院发布,其目标是识别文
ChnSentiCorp 是一个中文情感分析数据集,包含酒店、笔记本电脑和书籍的网购评论。
```
### 问答任务
#### 问答任务
<table>
<tbody>
......@@ -512,9 +500,7 @@ ChnSentiCorp 是一个中文情感分析数据集,包含酒店、笔记本电
NLPCC2016-DBQA 是由国际自然语言处理和中文计算会议 NLPCC 于 2016 年举办的评测任务,其目标是从候选中找到合适的文档作为问题的答案。[链接: http://tcci.ccf.org.cn/conference/2016/dldoc/evagline2.pdf]
```
### 语义相似度
#### 语义相似度
<table>
<tbody>
......@@ -597,14 +583,14 @@ BQ Corpus 是在自然语言处理国际顶会 EMNLP 2018 发布的语义匹配
## 英文效果验证
### 英文效果验证
ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址为 https://gluebenchmark.com/ ,该评测涵盖了不同类型任务的 10 个数据集,其中包含 11 个测试集,涉及到 Accuracy、F1-score、Spearman Corr.、Pearson Corr.、Matthew Corr. 等 5 类指标。GLUE 排行榜使用每个数据集的平均分作为总体得分,并以此为依据将不同算法进行排名。
### GLUE - 验证集结果
#### GLUE - 验证集结果
| <strong>数据集</strong> | <strong>CoLA</strong> | <strong>SST-2</strong> | <strong>MRPC</strong> | <strong>STS-B</strong> | <strong>QQP</strong> | <strong>MNLI-m</strong> | <strong>QNLI</strong> | <strong>RTE</strong> |
| ----------- | ---- | ----- | ---- | ----- | ---- | ---- | ---- | ---- |
......@@ -617,7 +603,7 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
### GLUE - 测试集结果
#### GLUE - 测试集结果
| <strong>数据集</strong> | - | <strong>CoLA</strong> | <strong>SST-2</strong> | <strong>MRPC</strong> | <strong>STS-B</strong> | <strong>QQP</strong> | <strong>MNLI-m</strong> | <strong>MNLI-mm</strong> | <strong>QNLI</strong> | <strong>RTE</strong> | <strong>WNLI</strong> |<strong>AX</strong>|
| ----------- | ----- | ---- | ----- | ---- | ----- | ---- | ------ | ------- | ---- | ---- | ---- | ---- |
......@@ -631,6 +617,19 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
由于 XLNet 暂未公布 GLUE 测试集上的单模型结果,所以我们只与 BERT 进行单模型比较。上表为ERNIE 2.0 单模型在 GLUE 测试集的表现结果。
## 开源记录
- 2019-07-30 发布 ERNIE 2.0
- 2019-04-10 更新: update ERNIE_stable-1.0.1.tar.gz, 将模型参数、配置 ernie_config.json、vocab.txt 打包发布
- 2019-03-18 更新: update ERNIE_stable.tgz
- 2019-03-15 发布 ERNIE 1.0
## 技术交流
- [Github Issues](https://github.com/PaddlePaddle/ERNIE/issues): bug reports, feature requests, install issues, usage issues, etc.
- ERNIE QQ 群: 760439550 (ERNIE discussion group).
- [论坛](http://ai.baidu.com/forum/topic/list/168?pageNo=1): discuss implementations, research, etc.
## 使用
* [PaddlePaddle 安装](#paddlepaddle安装)
* [模型&amp;数据](#模型数据)
......@@ -650,6 +649,10 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
* [预训练 (ERNIE 1.0)](#预训练-ernie-10)
* [数据预处理](#数据预处理)
* [开始训练](#开始训练)
* [蒸馏](#蒸馏)
* [上线](#上线)
* [生成inference_model](#生成inference_model)
* [在线预测](#在线预测)
* [FAQ](#faq)
* [FAQ1: 如何获取输入句子/词经过 ERNIE 编码后的 Embedding 表示?](#faq1-如何获取输入句子词经过-ernie-编码后的-embedding-表示)
* [FAQ2: 如何利用 Fine-tuning 得到的模型对新数据进行批量预测?](#faq2-如何利用-fine-tuning-得到的模型对新数据进行批量预测)
......@@ -672,6 +675,11 @@ ERNIE 2.0 的英文效果验证在 GLUE 上进行。GLUE 评测的官方地址
> - [训练神经网络](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/howto/training/index_cn.html):介绍如何使用 Fluid 进行单机训练、多机训练、以及保存和载入模型变量
> - [模型评估与调试](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/howto/evaluation_and_debugging/index_cn.html):介绍在 Fluid 下进行模型评估和调试的方法
ERNIE的其他依赖列在`requirements.txt`文件中,使用以下命令安装
```script
pip install -r requirements.txt
```
## 模型&数据
......@@ -918,6 +926,36 @@ epoch: 1, progress: 1/1, step: 50, loss: 10.360563, ppl: 16398.287109, next_sent
如果用自定义的真实数据进行训练,请参照[`script/zh_task/pretrain.sh`](./script/zh_task/pretrain.sh)脚本对参数做相应修改。
## 蒸馏
ERNIE提供了通过数据蒸馏从而达到模型压缩、加速的开发套件,具体开发流程请参考 <a href="./distill/README.md">这里</a>
## 上线
完成finetune之后只需几步操作即可生成inference\_model, PaddlePaddle可以在生产环境中加载生成的预测模型并进行高效地预测。
### 生成inference\_model
运行`classify_infer.py`或者`predict_classifier.py` 脚本时通过指定 `--save_inference_model_path` 便可生成 inference_model 到指定位置。
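一个示意命令如下(除 `--save_inference_model_path` 外,其余参数假定与 finetune 时一致,具体以脚本的帮助信息为准):

```script
python -u predict_classifier.py \
    --save_inference_model_path ./inference_model \
    ...   # 其余参数与 finetune 时一致
```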
如果您采用 `propeller` 完成finetune,则 `BestInferenceExporter` 会在finetune过程中根据预测指标,挑最好的模型生成 inference_model . 使用 `propeller` 完成finetune的流程请参考 `propeller_xnli_demo.ipynb`
### 在线预测
随后您可以使用[PaddleInference C++ API](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/deploy/inference/native_infer.html)将模型的前向预测代码联编到您的生产环境中。或者您可以使用我们为您构建好的python预测引擎来完成一个简单的服务。只需将本代码库中的 `./propeller` 文件夹放入您的 `PYTHONPATH` 中并执行如下指令,便可以开启一个propeller server:
```script
python -m propeller.tools.start_server -m /path/to/saved/model -p 8888
```
您可以在 python 脚本中很方便地调用 propeller server:
```python
from propeller.service.client import InferenceClient
client = InferenceClient('tcp://localhost:8888')  # 端口与上文启动 server 时 -p 指定的端口一致
result = client(sentence_id, position_id, token_type_id, input_mask)
```
`client`的请求参数类型是numpy array,对应了save_inference_model时指定的输入tensor. 如果是使用`classify_infer.py` 生成的inference_model则请求参数有四个:(sentence_id, position_id, token_type_id, input_mask)。 如果是`propeller` 生成的inference_model, client的请求参数对应您`eval_dataset` 的元素类型。
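下面是一个构造请求的极简示意(batch 大小、序列长度与 dtype 均为假设,应与导出 inference_model 时定义的输入保持一致):

```python
import numpy as np
from propeller.service.client import InferenceClient

client = InferenceClient('tcp://localhost:8888')

batch_size, seq_len = 1, 128                                             # 假设值
sentence_id = np.zeros([batch_size, seq_len, 1], dtype=np.int64)        # token id,这里全部用 0 示意
position_id = np.arange(seq_len, dtype=np.int64).reshape([1, seq_len, 1])
token_type_id = np.zeros([batch_size, seq_len, 1], dtype=np.int64)
input_mask = np.ones([batch_size, seq_len, 1], dtype=np.float32)

probs, = client(sentence_id, position_id, token_type_id, input_mask)    # 返回 numpy array 组成的元组
print(probs)
```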
## FAQ
### FAQ1: 如何获取输入句子/词经过 ERNIE 编码后的 Embedding 表示?
......@@ -981,3 +1019,4 @@ python -u predict_classifier.py \
### FAQ5: Can not find library: libnccl.so. Please try to add the lib path to LD_LIBRARY_PATH.
需要先下载 [NCCL](https://developer.nvidia.com/nccl/nccl-download),然后在 LD_LIBRARY_PATH 中添加 NCCL 库的路径,如`export LD_LIBRARY_PATH=/home/work/nccl/lib`
* [ERNIE Slim 数据蒸馏](#ernie-slim-数据蒸馏)
* [ERNIE数据蒸馏三步](#ernie数据蒸馏三步)
* [数据增强](#数据增强)
* [使用教程](#使用教程)
* [离线蒸馏](#离线蒸馏)
* [在线蒸馏](#在线蒸馏)
* [效果验证](#效果验证)
* [Case#1 用户提供“无标注数据”](#case1)
* [Case#2 用户未提供“无标注数据”](#case2)
* [FAQ](#faq)
# ERNIE Slim 数据蒸馏
在ERNIE强大的语义理解能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测。很多工业应用场景对性能要求较高,若不能有效压缩则无法实际应用。
<img src="http://agroup-bos.cdn.bcebos.com/ae16a29d6a334c74107cebcf56bc2419d385b364" title="ERNIE数据蒸馏示意图" width="900">
因此,如上图所示,我们基于[数据蒸馏技术](https://arxiv.org/pdf/1712.04440.pdf)构建了**ERNIE Slim数据蒸馏系统**。它的原理是以数据作为桥梁,将ERNIE模型的知识迁移至小模型,以极小的效果损失换取上千倍的预测速度提升。
### ERNIE数据蒸馏三步
- **Step 1**. 使用ERNIE模型对输入标注数据对进行fine-tune,得到Teacher Model
- **Step 2**. 使用ERNIE Service对以下无监督数据进行预测:
1. 用户提供的大规模无标注数据,需与标注数据同源
2. 对标注数据进行数据增强,具体增强策略见下节
3. 对无标注数据和数据增强数据进行一定比例混合
- **Step 3.** 使用步骤2的数据训练出Student Model
### 数据增强
目前采用三种[数据增强策略](https://arxiv.org/pdf/1903.12136.pdf),对于不同的任务可以按特定的比例混合使用。三种数据增强策略包括(第一种策略的示意实现见下方代码):
1. 添加噪声:对原始样本中的词,以一定的概率(如0.1)替换为"UNK"标签
2. 同词性词替换:对原始样本中的所有词,以一定的概率(如0.1)替换为本数据集中随机一个同词性的词
3. N-sampling:从原始样本中,随机选取位置截取长度为m的片段作为新的样本,其中片段的长度m为0到原始样本长度之间的随机值
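下面给出"添加噪声"策略的一个极简示意实现(非仓库代码,替换概率与 `[UNK]` 标记均为示例假设):

```python
import random

def add_noise(tokens, unk_token='[UNK]', prob=0.1):
    """以概率 prob 将词替换为 unk_token,示意第一种数据增强策略。"""
    return [unk_token if random.random() < prob else t for t in tokens]

print(add_noise('这 家 酒店 的 服务 很 好'.split()))
```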
# 使用教程
我们采用上述3种增强策略制作了chnsenticorp的增强数据:增强后的数据为原训练数据的10倍(96000行),可以从[这里](https://ernie.bj.bcebos.com/distill_data.tar.gz)下载。将下载的 `distill` 文件夹放入 `${TASK_DATA_PATH}` 后即可执行下面的脚本开始蒸馏。
### 离线蒸馏
离线蒸馏指的是先通过训练好的ERNIE模型预测出无监督数据的label,然后student模型去学习这些label。只需执行
```script
sh ./distill/script/distill_chnsenticorp.sh
```
即可开始离线蒸馏。
该脚本会进行前述的三步:1. 在任务数据上Fine-tune。 2. 加载Fine-tune好的模型对增强数据进行打分。 3.使用Student模型进行训练。脚本采用hard-label蒸馏,在第二步中将会直接预测出ERNIE标注的label。
该脚本涉及两个python文件:`./distill/finetune_chnsenticorp.py` 负责finetune以及预测teacher模型, `distill/distill_chnsentocorp.py` 负责student模型的训练。事先构造好的增强数据放在`${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug`
在脚本的第二步中,使用 `--do_predict` 参数进入预测模式:
```script
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 |python3 -u ./distill/finetune_chnsenticorp.py \
--do_predict \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
...
```
脚本从标准输入获取明文输入,并将打分输出到标准输出。用这种方式对数据增强后的无监督训练语料进行标注。最终的标注结果放在 `prediction_output/part.0` 文件中。标注结果包含两列,第一列为明文,第二列为标注label。
在第三步开始student模型的训练:
```script
python3 ./distill/distill_chnsentocorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--unsupervise_data_dir ./prediction_output/ \
--max_seqlen 128 \
...
```
训练流程与第一步相同,`--data_dir` 指定监督数据,`--unsupervise_data_dir` 指定ERNIE标注数据。Student模型是一个简单的BOW模型,其定义位于`distill/distill_chnsentocorp.py`。用户只需改写其中的model部分即可实现定制蒸馏模型,接口骨架示意如下。
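下面是一个自定义 Student 模型的接口骨架示意(方法名参照本仓库 `distill/distill_chnsentocorp.py` 中的 BOW 实现,具体逻辑需自行补全):

```python
import propeller.paddle as propeller

class MyStudentModel(propeller.train.Model):
    def __init__(self, config, mode, run_config):
        ...                                   # 保存配置、初始化参数

    def forward(self, features):
        ...                                   # 返回 logits;PREDICT 模式下返回概率

    def loss(self, predictions, labels):
        ...                                   # 例如与 teacher 打分对齐的 soft-label 交叉熵

    def backward(self, loss):
        ...                                   # 定义优化器并最小化 loss

    def metrics(self, predictions, labels):
        return {}                             # 例如 {'acc': propeller.metrics.Acc(labels, predictions)}
```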
如果用户已经拥有了无监督数据,则可以将无监督数据放入 `${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug` 即可。
### 在线蒸馏
在某些场景下,无监督数据过大导致预测过程十分耗时,或者ERNIE预测出的分布过大而无法预先存放在磁盘中。针对这种场景我们提出一种 **在线蒸馏** 方案:采用`propeller` 进行fine-tune并使用 `BestInferenceModelExporter` 后,`propeller` 会自动将指标最好的模型保存为paddle inference model格式,随后启动一个预测服务。Student模型在训练的同时,实时地访问这个服务来获得ERNIE的预测打分。只需执行
```
sh ./distill/script/distill_chnsenticorp_with_propeller_server.sh
```
即可完成上述流程。
流程包含3步:1. finetune ERNIE模型。2. 取指标最好的ERNIE模型启动`propeller`服务。 3.在student模型的训练过程中访问服务获取teacher模型的标注。
此流程涉及两个python文件: `distill/finetune_chnsenticorp.py` 和 `distill/distill_chnsentocorp_with_propeller_server.py`。其中第一步与离线蒸馏中的用法完全一样。
第二步中使用
```script
python3 -m propeller.tools.start_server -p 8113 -m ${teacher_dir}/best/inference/ &
```
启动一个 ERNIE 预测服务。
第三步开始student模型的同步训练:
```script
python3 ./distill/distill_chnsentocorp_with_propeller_server.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--teacher_vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--teacher_max_seqlen 128 \
--server_batch_size 64 \
--teacher_host tcp://localhost:8113 \
--num_coroutine 10
```
该脚本将`${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug` 目录下的增强数据进行切字并请求`propeller` 服务。`--num_coroutine` 指定了请求的并发数,`--teacher_host` 指定了服务的端口和IP,`--server_batch_size` 指定了请求的batch_size,在实际的请求中每个batch的数据会拆分成若干个 `--server_batch_size` 大小的数据去请求服务。
# 效果验证
我们将实际应用场景分类为两种:
### Case#1 用户提供“无标注数据”<a name="case1"></a>
|模型 | 评论低质识别【分类 \| ACC】 | 中文情感【分类 \| ACC】 |问题识别【分类 \| ACC】|搜索问答匹配【匹配 \| 正逆序】|
|---|---|---|---|---|
|ERNIE-Finetune | 90.6% | 96.2% | 97.5% | 4.25 |
|非ERNIE基线(BOW)| 80.8% | 94.7% | 93.0% | 1.83 |
|**+ 数据蒸馏** | 87.2% | 95.8% | 96.3% | 3.30 |
### Case#2 用户未提供“无标注数据”(通过数据增强生成数据)<a name="case2"></a>
|模型 |ChnSentiCorp |
|---|---|
|ERNIE-Finetune |95.4% |
|非ERNIE基线(BOW)|90.1%|
|**+ 数据蒸馏** |91.4%|
|非ERNIE基线(LSTM)|91.2%|
|**+ 数据蒸馏**|93.9%|
# FAQ
### FAQ1: 预测同时蒸馏报错:`Client call failed`
终端打印的错误是client的日志,server端的日志在前面。一般来说可能是server显存超限导致。这种时候需要在student模型finetune的脚本中使用`--server_batch_size` 显式控制请求服务的batch大小。
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
from random import random
from functools import reduce, partial
import logging
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import log
import propeller.paddle as propeller
from propeller.paddle.data import Dataset
from optimization import optimization
import utils.data
log.setLevel(logging.DEBUG)
class ClassificationBowModel(propeller.train.Model):
"""propeller Model wraper for paddle-ERNIE """
def __init__(self, config, mode, run_config):
self.config = config
self.mode = mode
self.run_config = run_config
self._param_initializer = F.initializer.TruncatedNormal(
scale=config.initializer_range)
self._emb_dtype = "float32"
self._word_emb_name = "word_embedding"
def forward(self, features):
text_ids_a, = features
def bow(ids):
embed = L.embedding(
input=ids,
size=[self.config.vocab_size, self.config.emb_size],
dtype=self._emb_dtype,
param_attr=F.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
zero = L.fill_constant(shape=[1], dtype='int64', value=0)
pad = L.cast(L.logical_not(L.equal(ids, zero)), 'float32')
sumed = L.reduce_sum(embed * pad, dim=1)
sumed = L.softsign(sumed)
return sumed
sumed = bow(text_ids_a)
fced = L.fc(
input=sumed,
size=self.config.emb_size,
act='tanh',
param_attr=F.ParamAttr(
name="middle_fc.w_0", initializer=self._param_initializer),
bias_attr="middle_fc.b_0")
logits = L.fc(
input=fced,
size=self.config.num_label,
act=None,
param_attr=F.ParamAttr(
name="pooler_fc.w_0", initializer=self._param_initializer),
bias_attr="pooler_fc.b_0")
if self.mode is propeller.RunMode.PREDICT:
probs = L.softmax(logits)
return probs
else:
return logits
def loss(self, predictions, labels):
labels = L.softmax(labels)
loss = L.softmax_with_cross_entropy(predictions, labels, soft_label=True)
loss = L.mean(loss)
return loss
def backward(self, loss):
scheduled_lr, _ = optimization(
loss=loss,
warmup_steps=int(self.run_config.max_steps * self.config.warmup_proportion),
num_train_steps=self.run_config.max_steps,
learning_rate=self.config.learning_rate,
train_program=F.default_main_program(),
startup_prog=F.default_startup_program(),
weight_decay=self.config.weight_decay,
scheduler="linear_warmup_decay",)
propeller.summary.scalar('lr', scheduled_lr)
def metrics(self, predictions, labels):
predictions = L.argmax(predictions, axis=1)
labels = L.argmax(labels, axis=1)
#predictions = L.unsqueeze(predictions, axes=[1])
acc = propeller.metrics.Acc(labels, predictions)
#auc = propeller.metrics.Auc(labels, predictions)
return {'acc': acc}
if __name__ == '__main__':
parser = propeller.ArgumentParser('DAN model with Paddle')
parser.add_argument('--max_seqlen', type=int, default=128)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--unsupervise_data_dir', type=str, required=True)
parser.add_argument('--data_dir', type=str)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb'))}
unk_id = vocab['[UNK]']
char_tokenizer = utils.data.CharTokenizer(vocab.keys())
space_tokenizer = utils.data.SpaceTokenizer(vocab.keys())
supervise_feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
propeller.data.LabelColumn('label'),
])
def before(text_a, label):
sentence_a = text_a[: args.max_seqlen]
return sentence_a, label
def after(sentence_a, label):
batch_size = sentence_a.shape[0]
onehot_label = np.zeros([batch_size, hparams.num_label], dtype=np.float32)
onehot_label[np.arange(batch_size), label] = 9999.
sentence_a, = utils.data.expand_dims(sentence_a)
return sentence_a, onehot_label
train_ds = supervise_feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after) \
unsup_train_ds = supervise_feature_column.build_dataset('unsup_train', data_dir=args.unsupervise_data_dir, shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
dev_ds = supervise_feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
train_ds = utils.data.interleave(train_ds, unsup_train_ds)
shapes = ([-1, args.max_seqlen, 1], [-1, hparams.num_label])
types = ('int64', 'float32')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
'''
from tqdm import tqdm
for slots in tqdm(train_ds):
pass
'''
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
propeller.train.train_and_eval(
model_class_or_model_fn=ClassificationBowModel,
params=hparams,
run_config=run_config,
train_dataset=train_ds,
eval_dataset={'dev': dev_ds},
exporters=[best_exporter])
print('dev_acc3\t%.5f' % (best_exporter._best['dev']['acc']))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
from random import random
from functools import reduce, partial
import logging
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import log
import propeller.paddle as propeller
from propeller.paddle.data import Dataset
from propeller.service.client import InferenceClient
from optimization import optimization
import utils.data
log.setLevel(logging.DEBUG)
class ClassificationBowModel(propeller.train.Model):
"""propeller Model wraper for paddle-ERNIE """
def __init__(self, config, mode, run_config):
self.config = config
self.mode = mode
self.run_config = run_config
self._param_initializer = F.initializer.TruncatedNormal(
scale=config.initializer_range)
self._emb_dtype = "float32"
self._word_emb_name = "word_embedding"
def forward(self, features):
text_ids_a, = features
def bow(ids):
embed = L.embedding(
input=ids,
size=[self.config.vocab_size, self.config.emb_size],
dtype=self._emb_dtype,
param_attr=F.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
zero = L.fill_constant(shape=[1], dtype='int64', value=0)
pad = L.cast(L.logical_not(L.equal(ids, zero)), 'float32')
sumed = L.reduce_sum(embed * pad, dim=1)
sumed = L.softsign(sumed)
return sumed
sumed = bow(text_ids_a)
fced = L.fc(
input=sumed,
size=self.config.emb_size,
act='tanh',
param_attr=F.ParamAttr(
name="middle_fc.w_0", initializer=self._param_initializer),
bias_attr="middle_fc.b_0")
logits = L.fc(
input=fced,
size=self.config.num_label,
act=None,
param_attr=F.ParamAttr(
name="pooler_fc.w_0", initializer=self._param_initializer),
bias_attr="pooler_fc.b_0")
if self.mode is propeller.RunMode.PREDICT:
probs = L.softmax(logits)
return probs
else:
return logits
def loss(self, predictions, labels):
labels = L.softmax(labels)
loss = L.softmax_with_cross_entropy(predictions, labels, soft_label=True)
loss = L.mean(loss)
return loss
def backward(self, loss):
scheduled_lr, _ = optimization(
loss=loss,
warmup_steps=int(self.run_config.max_steps * self.config.warmup_proportion),
num_train_steps=self.run_config.max_steps,
learning_rate=self.config.learning_rate,
train_program=F.default_main_program(),
startup_prog=F.default_startup_program(),
weight_decay=self.config.weight_decay,
scheduler="linear_warmup_decay",)
propeller.summary.scalar('lr', scheduled_lr)
def metrics(self, predictions, labels):
predictions = L.argmax(predictions, axis=1)
labels = L.argmax(labels, axis=1)
#predictions = L.unsqueeze(predictions, axes=[1])
acc = propeller.metrics.Acc(labels, predictions)
#auc = propeller.metrics.Auc(labels, predictions)
return {'acc': acc}
if __name__ == '__main__':
parser = propeller.ArgumentParser('DAN model with Paddle')
parser.add_argument('--max_seqlen', type=int, default=128)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--teacher_vocab_file', type=str, required=True)
parser.add_argument('--teacher_max_seqlen', type=int, default=128)
parser.add_argument('--data_dir', type=str)
parser.add_argument('--server_batch_size', type=int, default=64)
parser.add_argument('--num_coroutine', type=int, default=1)
parser.add_argument('--teacher_host', type=str, required=True)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
teacher_vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.teacher_vocab_file, 'rb'))}
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb'))}
teacher_sep_id = teacher_vocab['[SEP]']
teacher_cls_id = teacher_vocab['[CLS]']
teacher_unk_id = teacher_vocab['[UNK]']
unk_id = vocab['[UNK]']
char_tokenizer = utils.data.CharTokenizer(vocab.keys())
space_tokenizer = utils.data.SpaceTokenizer(vocab.keys())
supervise_feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
propeller.data.LabelColumn('label'),
])
unsupervise_feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('text_a', unk_id=unk_id, vocab_dict=vocab, tokenizer=space_tokenizer),
propeller.data.TextColumn('teacher_text_a', unk_id=teacher_unk_id, vocab_dict=teacher_vocab, tokenizer=char_tokenizer),
])
def before(text_a, label):
sentence_a = text_a[: args.max_seqlen]
return sentence_a, label
def after(sentence_a, label):
batch_size = sentence_a.shape[0]
onehot_label = np.zeros([batch_size, hparams.num_label], dtype=np.float32)
onehot_label[np.arange(batch_size), label] = 9999.
sentence_a, = utils.data.expand_dims(sentence_a)
return sentence_a, onehot_label
train_ds = supervise_feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after) \
dev_ds = supervise_feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
def unsuperve_before(text_a, teacher_text_a):
teacher_sentence, teacher_segments = utils.data.build_1_pair(teacher_text_a, max_seqlen=args.teacher_max_seqlen, cls_id=teacher_cls_id, sep_id=teacher_sep_id)
sentence_a = text_a[: args.max_seqlen]
return sentence_a, teacher_sentence, teacher_segments
client = InferenceClient(args.teacher_host, batch_size=args.server_batch_size, num_coroutine=args.num_coroutine)
log.info('teacher host %s' % args.teacher_host)
def ask_teacher_for_label(sentence_a, teacher_sentence, teacher_segments):
sentence_a, teacher_sentence, teacher_segments = utils.data.expand_dims(sentence_a, teacher_sentence, teacher_segments)
teacher_label, = client(teacher_sentence, teacher_segments)
teacher_label = teacher_label[:, :]
return sentence_a, teacher_label
unsup_train_ds = unsupervise_feature_column.build_dataset('unsup_train', data_dir=os.path.join(args.data_dir, 'unsup_train_aug'), shuffle=True, repeat=True, use_gz=False) \
.buffered(100) \
.map(unsuperve_before) \
.padded_batch(hparams.batch_size, (0, 0, 0)) \
.map(ask_teacher_for_label)
train_ds = utils.data.interleave(train_ds, unsup_train_ds)
shapes = ([-1, args.max_seqlen, 1], [-1, hparams.num_label])
types = ('int64', 'float32')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
'''
from tqdm import tqdm
for slots in tqdm(train_ds):
pass
'''
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
propeller.train.train_and_eval(
model_class_or_model_fn=ClassificationBowModel,
params=hparams,
run_config=run_config,
train_dataset=train_ds,
eval_dataset={'dev': dev_ds},
exporters=[best_exporter])
print('dev_acc3\t%.5f' % (best_exporter._best['dev']['acc']))
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
import logging
from random import random
from functools import reduce, partial
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
from model.ernie import ErnieModel
from optimization import optimization
import utils.data
from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
class ClassificationErnieModel(propeller.train.Model):
"""propeller Model wraper for paddle-ERNIE """
def __init__(self, hparam, mode, run_config):
self.hparam = hparam
self.mode = mode
self.run_config = run_config
def forward(self, features):
src_ids, sent_ids = features
dtype = 'float16' if self.hparam['fp16'] else 'float32'
zero = L.fill_constant([1], dtype='int64', value=0)
input_mask = L.cast(L.logical_not(L.equal(src_ids, zero)), dtype) # assume pad id == 0
#input_mask = L.unsqueeze(input_mask, axes=[2])
d_shape = L.shape(src_ids)
seqlen = d_shape[1]
batch_size = d_shape[0]
pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])
pos_ids = L.expand(pos_ids, [batch_size, 1])
pos_ids = L.unsqueeze(pos_ids, axes=[2])
pos_ids = L.cast(pos_ids, 'int64')
pos_ids.stop_gradient = True
input_mask.stop_gradient = True
task_ids = L.zeros_like(src_ids) + self.hparam.task_id #this shit wont use at the moment
task_ids.stop_gradient = True
bert = ErnieModel(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
task_ids=task_ids,
input_mask=input_mask,
config=self.hparam,
use_fp16=self.hparam['fp16']
)
cls_feats = bert.get_pooled_output()
cls_feats = L.dropout(
x=cls_feats,
dropout_prob=0.1,
dropout_implementation="upscale_in_train"
)
logits = L.fc(
input=cls_feats,
size=self.hparam['num_label'],
param_attr=F.ParamAttr(
name="cls_out_w",
initializer=F.initializer.TruncatedNormal(scale=0.02)),
bias_attr=F.ParamAttr(
name="cls_out_b", initializer=F.initializer.Constant(0.))
)
propeller.summary.histogram('pred', logits)
if self.mode is propeller.RunMode.PREDICT:
probs = L.softmax(logits)
return probs
else:
return logits
def loss(self, predictions, labels):
ce_loss, probs = L.softmax_with_cross_entropy(
logits=predictions, label=labels, return_softmax=True)
#L.Print(ce_loss, message='per_example_loss')
loss = L.mean(x=ce_loss)
return loss
def backward(self, loss):
scheduled_lr, _ = optimization(
loss=loss,
warmup_steps=int(self.run_config.max_steps * self.hparam['warmup_proportion']),
num_train_steps=self.run_config.max_steps,
learning_rate=self.hparam['learning_rate'],
train_program=F.default_main_program(),
startup_prog=F.default_startup_program(),
weight_decay=self.hparam['weight_decay'],
scheduler="linear_warmup_decay",)
propeller.summary.scalar('lr', scheduled_lr)
def metrics(self, predictions, label):
predictions = L.argmax(predictions, axis=1)
predictions = L.unsqueeze(predictions, axes=[1])
acc = propeller.metrics.Acc(label, predictions)
#auc = propeller.metrics.Auc(label, predictions)
return {'acc': acc}
if __name__ == '__main__':
parser = propeller.ArgumentParser('DAN model with Paddle')
parser.add_argument('--max_seqlen', type=int, default=128)
parser.add_argument('--data_dir', type=str, required=True)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--do_predict', action='store_true')
parser.add_argument('--warm_start_from', type=str)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
vocab = {j.strip().split(b'\t')[0].decode('utf8'): i for i, j in enumerate(open(args.vocab_file, 'rb'))}
sep_id = vocab['[SEP]']
cls_id = vocab['[CLS]']
unk_id = vocab['[UNK]']
tokenizer = utils.data.CharTokenizer(vocab.keys())
def tokenizer_func(inputs):
'''avoid pickle error'''
ret = tokenizer(inputs)
return ret
if not args.do_predict:
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('title',unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),
propeller.data.LabelColumn('label'),
])
def before(seg_a, label):
sentence, segments = utils.data.build_1_pair(seg_a, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id)
return sentence, segments, label
def after(sentence, segments, label):
sentence, segments, label = utils.data.expand_dims(sentence, segments, label)
return sentence, segments, label
log.debug(os.path.join(args.data_dir, 'train'))
train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0, 0)) \
.map(after)
dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0, 0)) \
.map(after)
shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1], [-1, 1])
types = ('int64', 'int64', 'int64')
train_ds.data_shapes = shapes
train_ds.data_types = types
dev_ds.data_shapes = shapes
dev_ds.data_types = types
varname_to_warmstart = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
warm_start_dir = args.warm_start_from
ws = propeller.WarmStartSetting(
predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
from_dir=warm_start_dir
)
best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['eval']['acc'] > old['eval']['acc'])
propeller.train.train_and_eval(
model_class_or_model_fn=ClassificationErnieModel,
params=hparams,
run_config=run_config,
train_dataset=train_ds,
eval_dataset=dev_ds,
warm_start_setting=ws,
exporters=[best_exporter])
print('dev_acc\t%.5f' % (best_exporter._best['eval']['acc']))
else:
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('title',unk_id=unk_id, vocab_dict=vocab, tokenizer=tokenizer_func),
propeller.data.LabelColumn('label'),
])
def before(seg_a):
sentence, segments = utils.data.build_1_pair(seg_a, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id)
return sentence, segments
def after(sentence, segments):
sentence, segments = utils.data.expand_dims(sentence, segments)
return sentence, segments
predict_ds = feature_column.build_dataset_from_stdin('predict') \
.map(before) \
.padded_batch(hparams.batch_size, (0, 0)) \
.map(after)
shapes = ([-1, args.max_seqlen, 1], [-1, args.max_seqlen, 1])
types = ('int64', 'int64')
predict_ds.data_shapes = shapes
predict_ds.data_types = types
finetuned_model = propeller.Learner(ClassificationErnieModel, run_config, hparams)
for logits, in finetuned_model.predict(predict_ds, ckpt=-1): # ckpt=-1 means last step
print(np.argmax(logits))
set -x
export PYTHONPATH=.:$PYTHONPATH
output_dir=./output/distill
teacher_dir=${output_dir}/teacher
student_dir=${output_dir}/student
# 1. finetune teacher
CUDA_VISIBLE_DEVICES=0 \
python3 -u ./distill/finetune_chnsenticorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${teacher_dir}'",
"max_steps": '$((10 * 9600 / 32))',
"save_steps": 100,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam ${MODEL_PATH}/ernie_config.json \
--hparam '{ # model definition
"sent_type_vocab_size": None, # default term in official config
"use_task_id": False,
"task_id": 0,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"fp16": 0,
"learning_rate": 0.00005,
"num_label": 2,
"batch_size": 32
}'
(($?!=0)) && echo "Something goes wrong at Step 1, please check" && exit -1
# 2. start a prediction server
export CUDA_VISIBLE_DEVICES=0
cat ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 |awk -F"\t" '{print $2}' |python3 -u ./distill/finetune_chnsenticorp.py \
--do_predict \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${teacher_dir}'",
"log_steps": 10,
}' \
--hparam ${MODEL_PATH}/ernie_config.json \
--hparam '{ # model definition
"sent_type_vocab_size": None, # default term in official config
"use_task_id": False,
"task_id": 0,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"fp16": 0,
"learning_rate": 0.00005,
"num_label": 2,
"batch_size": 100
}' > prediction_label
(($?!=0)) && echo "Something goes wrong at Step 2, please check" && exit -1
mkdir prediction_output
paste ${TASK_DATA_PATH}/distill/chnsenticorp/student/unsup_train_aug/part.0 prediction_label |awk -F"\t" '{print $2"\t"$3}' > prediction_output/part.0
#. 3. learn from teacher
export CUDA_VISIBLE_DEVICES=0
python3 ./distill/distill_chnsentocorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--unsupervise_data_dir ./prediction_output/ \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${student_dir}'",
"max_steps": '$((100 * 9600 / 100))',
"save_steps": 1000,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam '{
"num_label": 2,
"vocab_size": 35000,
"emb_size": 128,
"initializer_range": 0.02,
}' \
--hparam '{ # lr shit
"warmup_proportion": 0.1,
"weight_decay": 0.00,
"fp16": 0,
"learning_rate": 1e-4,
"batch_size": 100
}'
(($?!=0)) && echo "Something goes wrong at Step 3, please check" && exit -1
set -x
export PYTHONPATH=.:$PYTHONPATH
output_dir=./output/distill
teacher_dir=${output_dir}/teacher
student_dir=${output_dir}/student
# 1. finetune teacher
CUDA_VISIBLE_DEVICES=0 \
python3 -u ./distill/finetune_chnsenticorp.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/teacher \
--warm_start_from ${MODEL_PATH}/params \
--vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--run_config '{
"model_dir": "'${teacher_dir}'",
"max_steps": '$((10 * 9600 / 32))',
"save_steps": 100,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam ${MODEL_PATH}/ernie_config.json \
--hparam '{ # model definition
"sent_type_vocab_size": None, # default term in official config
"use_task_id": False,
"task_id": 0,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"fp16": 0,
"learning_rate": 0.00005,
"num_label": 2,
"batch_size": 32
}'
(($?!=0)) && echo "Something goes wrong at Step 1, please check" && exit -1
# 2. start a prediction server
CUDA_VISIBLE_DEVICES=1 \
python3 -m propeller.tools.start_server -p 8113 -m ${teacher_dir}/best/inference/ &
echo $! > pid.server
sleep 10
#. 3. learn from teacher
export CUDA_VISIBLE_DEVICES=0
python3 ./distill/distill_chnsentocorp_with_propeller_server.py \
--data_dir ${TASK_DATA_PATH}/distill/chnsenticorp/student \
--vocab_file ${TASK_DATA_PATH}/distill/chnsenticorp/student/vocab.txt \
--teacher_vocab_file ${MODEL_PATH}/vocab.txt \
--max_seqlen 128 \
--teacher_max_seqlen 128 \
--server_batch_size 64 \
--teacher_host tcp://localhost:8113 \
--num_coroutine 10 \
--run_config '{
"model_dir": "'${student_dir}'",
"max_steps": '$((100 * 9600 / 100))',
"save_steps": 1000,
"log_steps": 10,
"max_ckpt": 1,
"skip_steps": 0,
"eval_steps": 100
}' \
--hparam '{ # model definition
"num_label": 2,
"vocab_size": 35000,
"emb_size": 128,
"initializer_range": 0.02,
}' \
--hparam '{ # learn
"warmup_proportion": 0.1,
"weight_decay": 0.00,
"fp16": 0,
"learning_rate": 1e-4,
"batch_size": 100
}'
(($?!=0)) && echo "Something goes wrong at Step 2, please check" && exit -1
ps -ef|grep 'propeller.tools.start_server' |awk '{print $2}'|xargs kill -9
......@@ -212,6 +212,5 @@ def predict(exe,
except fluid.core.EOFException:
test_pyreader.reset()
break
log.info(len(res))
return res
......@@ -78,7 +78,7 @@ data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("predict_batch_size", int, 8, "Total examples' number in batch for predict. see also --in_tokens.")
data_g.add_arg("predict_batch_size", int, None, "Total examples' number in batch for predict. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
......
......@@ -137,14 +137,20 @@ def start_procs(args):
cmd = [sys.executable, "-u",
args.training_script] + args.training_script_args
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
logdir = "%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id)
try:
os.mkdir(os.path.dirname(logdir))
except OSError:
pass
fn = open(logdir, "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
log.info('subprocess launched, check log at %s' % logdir)
else:
process = subprocess.Popen(cmd, env=current_env)
log.info('subprocess launched')
log.info('subprocess launched')
procs.append(process)
try:
......
......@@ -139,12 +139,19 @@ def start_procs(args):
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
logdir = "%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id)
try:
os.mkdir(os.path.dirname(logdir))
except OSError:
pass
fn = open(logdir, "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
log.info('subprocess launched, check log at %s' % logdir)
else:
process = subprocess.Popen(cmd, env=current_env)
log.info('subprocess launched')
log.info('subprocess launched')
procs.append(process)
try:
......
[简体中文](./README.md)|English
# Introducing Propeller
This doc introduces Propeller, a high-level Paddle API for general ML. Propeller encapsulates the following actions:
- training
- evaluation
- prediction
- export serving
Propeller provides the following benefits:
- You can run Propeller-based models on a local host or in a distributed multi-server environment without changing your model. Furthermore, you can run Propeller-based models on CPUs or GPUs without recoding your model.
- Propeller simplifies sharing implementations between model developers.
- Propeller does many things for you (logging, hot-start, ...).
- Propeller builds the Program and PyReader for you.
- Propeller provides a safe distributed training loop that controls how and when to:
- build the Program
- initialize variables
- create checkpoint files and recover from failures
- save visualizable results
## install
```script
cd propeller && pip install .
```
## Getting Started
```python
#Define model
class BowModel(propeller.Model):
def __init__(self, config, mode):
self.embedding = Embedding(config['emb_size'], config['vocab_size'])
self.fc1 = FC(config['hidden_size'])
self.fc2 = FC(config['hidden_size'])
def forward(self, features):
q, t = features
q_emb = softsign(self.embedding(q))
t_emb = softsign(self.embedding(t))
q_emb = self.fc1(q_emb)
t_emb = self.fc2(t_emb)
prediction = dot(q_emb, t_emb)
return prediction
def loss(self, predictions, label):
return sigmoid_cross_entropy_with_logits(predictions, label)
def backward(self, loss):
opt = AdamOptimizer(1.e-3)
opt.minimize(loss)
def metrics(self, predictions, label):
auc = propeller.metrics.Auc(predictions, label)
return {'auc': auc}
# hyper params come from config files / command line flags / environment variables
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
# Define data
# `FeatureColumns` helps you to organize training/evaluation files.
feature_column = propeller.data.FeatureColumns(columns=[
propeller.data.TextColumn('query', vocab='./vocab'),
propeller.data.TextColumn('title', vocab='./vocab'),
propeller.data.LabelColumn('label'),
])
train_ds = feature_column.build_dataset(data_dir='./data', shuffle=True, repeat=True)
eval_ds = feature_column.build_dataset(data_dir='./data', shuffle=False, repeat=False)
# Start training!
propeller.train_and_eval(BowModel, hparams, run_config, train_ds, eval_ds)
```
For more details, see example/toy/.
## Main Feature
1. train_and_eval
According to the user-specified `propeller.Model` class, this initializes the training model in 2 modes (1. TRAIN mode, 2. EVAL mode) and performs training together with evaluation.
2. FeatureColumns
`FeatureColumns` is used to organize training data. With customizable `Column` properties, it adapts to many ML tasks (NLP/CV...).
`FeatureColumns` also does the preprocessing for you (tokenization, vocab lookup, serialization, batching, etc.).
3. Dataset
`FeatureColumns` generates a `Dataset`, or you can call `propeller.Dataset.from_generator_func` to build your own `Dataset`.
4. Summary
To trace tensor histogram in training, simply:
```python
propeller.summary.histogram('loss', tensor)
```
## Contributing
1. This project is in alpha stage; any contribution is welcome. Feel free to create a PR.
简体中文|[English](./README.en.md)
# Introducing paddle-propeller
This document introduces Propeller, a high-level Paddle API that greatly simplifies machine learning programming. Propeller encapsulates the following actions:
- training
- evaluation
- prediction
- export for serving (deployment)
Propeller provides the following benefits:
- You can run Propeller-based models on a local host or in a distributed multi-server environment without changing your model. Furthermore, you can run Propeller-based models on CPUs or GPUs without recoding your model.
- Propeller simplifies sharing implementations between model developers.
- Focus only on the model implementation and data input; Propeller takes care of the auxiliary code (saving, hot-starting, logging, ...).
- Propeller builds the Program and PyReader for you.
- Propeller provides a safe distributed training loop that controls how and when to:
  - build the Program
  - initialize variables
  - handle exceptions
  - create checkpoint files and recover from failures
  - save visualizable summary results
## Install
`cd propeller && pip install .`
## Getting Started
```python
# Define the model
class BowModel(propeller.Model):
def __init__(self, config, mode):
self.embedding = Embedding(config['emb_size'], config['vocab_size'])
self.fc1 = FC(config['hidden_size'])
        self.fc2 = FC(config['hidden_size'])
def forward(self, features):
q, t = features
q_emb = softsign(self.embedding(q))
t_emb = softsign(self.embedding(t))
q_emb = self.fc1(q_emb)
        t_emb = self.fc2(t_emb)
        prediction = dot(q_emb, t_emb)
return prediction
def loss(self, predictions, label):
return sigmoid_cross_entropy_with_logits(predictions, label)
def backward(self, loss):
opt = AdamOptimizer(1.e-3)
        opt.minimize(loss)
def metrics(self, predictions, label):
        auc = propeller.metrics.Auc(predictions, label)
return {'auc': auc}
# Hyperparameters can come from config files, environment variables, or the command line
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
# Define data
# `FeatureColumns` manages training/prediction files and serializes them to binary automatically.
feature_column = propeller.data.FeatureColumns(columns=[
propeller.data.TextColumn('query', vocab='./vocab'),
propeller.data.TextColumn('title', vocab='./vocab'),
propeller.data.LabelColumn('label'),
])
train_ds = feature_column.build_dataset(data_dir='./data', shuffle=True, repeat=True)
eval_ds = feature_column.build_dataset(data_dir='./data', shuffle=False, repeat=False)
# Start training!
propeller.train_and_eval(BowModel, hparams, run_config, train_ds, eval_ds)
```
For more details, see example/toy/.
## Main Features
1. train_and_eval
Instantiates the training model from the user-provided `propeller.Model` class in two modes: 1. TRAIN mode, 2. EVAL mode.
Training then starts, with evaluation performed alongside it.
2. FeatureColumns
`FeatureColumns` manages training data. Customizable `Column` types adapt it to many ML tasks (NLP/CV...).
`FeatureColumns` automatically batch-preprocesses the provided training data (tokenization, vocab lookup, etc.), serializes it to binary, and generates the dataset used for training.
3. Dataset
`FeatureColumns` generates a `Dataset`, or you can call `propeller.Dataset.from_generator_func` to build your own `Dataset`, combining it with methods such as shuffle / interleave / padded_batch / repeat to meet customized needs.
4. Summary
To track certain tensors in the log during training, simply:
```python
propeller.summary.histogram('loss', tensor)
```
## Contributing
1. This project is in its early stage; contributions are welcome!
2. Functional programming style is welcome.
## TODO
1. Automatic inference of dataset output_types / output_shapes
2. Automatic hyperparameter search
3. propeller server
4. ...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
import six
from time import time
__version__ = '0.1'
log = logging.getLogger(__name__)
stream_hdl = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
try:
from colorlog import ColoredFormatter
fancy_formatter = ColoredFormatter(
fmt='%(log_color)s[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
)
stream_hdl.setFormatter(fancy_formatter)
except ImportError:
stream_hdl.setFormatter(formatter)
log.setLevel(logging.INFO)
log.addHandler(stream_hdl)
log.propagate = False
from propeller.types import *
from propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import logging
import os
import itertools
import random
import inspect
import multiprocessing
from contextlib import contextmanager
import gzip
import struct
import functools
import six
from six.moves import zip, map, filter
import numpy as np
from propeller.util import map_structure
log = logging.getLogger(__name__)
__all__ = ['Dataset']
@contextmanager
def open_file(filename, format=None):
if format is None:
fd = open(filename, 'rb')
elif format == 'GZIP':
fd = gzip.open(filename, 'rb')
else:
        raise ValueError('unknown file format %s' % format)
yield fd
fd.close()
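# Record files produced by `_make_gz` in feature_column.py are gzip streams of
# length-prefixed records: a native int (struct format 'i') holding the byte
# length, followed by that many bytes of a serialized `Example` proto.
# `open_record` below reads the records back in the same order.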
def open_record(filename):
def gen():
with open_file(filename, format='GZIP') as f:
while True:
data = f.read(struct.calcsize('i'))
if not len(data):
                    return
l, = struct.unpack('i', data)
data = f.read(l)
yield data
return gen
def shuffle_func(dataset, buffer_size):
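    # Approximate shuffle: fill a buffer of `buffer_size` elements, then yield
    # a randomly chosen buffered element each time a new one arrives; when the
    # source is exhausted, shuffle and flush whatever remains in the buffer.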
def gen():
buf = []
iterable = dataset()
try:
while len(buf) < buffer_size:
buf.append(next(iterable))
while 1:
i = random.randint(0, buffer_size - 1)
n = next(iterable)
yield buf[i]
buf[i] = n
except StopIteration:
if len(buf):
random.shuffle(buf)
for i in buf:
yield i
return gen
def interleave_func(iterable, map_fn, cycle_length, block_length):
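    # Tee the source into `cycle_length` strided slices, apply `map_fn` to each
    # element, flatten the mapped results, and yield from the slices in a
    # round-robin fashion (`block_length` is currently unused here).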
def gen():
ls = itertools.tee(iterable(), cycle_length)
buf = []
for i, j in enumerate(ls):
j = itertools.islice(j, i, None, cycle_length)
j = map(map_fn, j)
j = (jjj for jj in j for jjj in jj) #flatten
buf.append(j)
for tup in six.moves.zip_longest(*buf):
for ii in (i for i in tup if i is not None):
yield ii
return gen
def repeat_func(dataset, n):
def gen():
iterable = dataset()
if n >= 0:
ret = itertools.chain(*itertools.tee(iterable, n))
else:
ret = itertools.cycle(iterable)
for i in ret:
yield i
return gen
def filter_func(dataset, fn):
def gen():
for i in dataset():
if isinstance(i, tuple) or isinstance(i, list):
if fn(*i) is True:
yield i
else:
if fn(i) is True:
yield i
return gen
def map_func(dataset, fn):
def gen():
for i in dataset():
if isinstance(i, tuple) or isinstance(i, list):
yield fn(*i)
else:
yield fn(i)
return gen
def shard_func(dataset, num_shards, index):
def gen():
iterable = dataset()
ret = itertools.islice(iterable, index, None, num_shards)
for i in ret:
yield i
return gen
def take_func(dataset, count):
def gen():
iterable = dataset()
ret = itertools.islice(iterable, count)
for i in ret:
yield i
return gen
def buffered_func(dataset, size):
"""
Creates a buffered data reader.
The buffered data reader will read and save data entries into a
buffer. Reading from the buffered data reader will proceed as long
as the buffer is not empty.
:param reader: the data reader to read from.
:type reader: callable
:param size: max buffer size.
:type size: int
:returns: the buffered data reader.
"""
class EndSignal():
pass
end = EndSignal()
def read_worker(r, q):
for d in r:
q.put(d)
q.put(end)
def data_reader():
r = dataset()
q = multiprocessing.Queue(maxsize=size)
t = multiprocessing.Process(
target=read_worker, args=(
r,
q, ))
t.daemon = True
t.start()
e = q.get()
while e != end:
yield e
e = q.get()
return data_reader
def padded_batch_func(dataset, batch_size, pad_value=0, max_seqlen=None):
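    # Group elements into batches of `batch_size`; pad every non-scalar feature
    # to the longest length in the batch (or truncate/pad to `max_seqlen` when
    # given) using `pad_value`, then stack each feature into a numpy array.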
if not isinstance(batch_size, int):
raise ValueError('unknown batch_size: %s' % repr(batch_size))
def gen():
iterable = dataset()
pad_value_t = pad_value
while True:
buf = list(itertools.islice(iterable, batch_size))
if not len(buf):
                return
buf = list(zip(*buf)) # transpose
if type(pad_value_t) not in [list, tuple]:
pad_value_t = [pad_value_t] * len(buf)
padded = []
assert len(buf) == len(
pad_value_t), 'pad_value [%d] != element size[%d]' % (
len(pad_value_t), len(buf))
for e, pv in zip(buf, pad_value_t):
elem = e[0]
if (not np.isscalar(elem)) and elem.shape != ():
max_len = max(map(len,
e)) if max_seqlen is None else max_seqlen
e = map(lambda i: np.pad(i, [0, max_len - len(i)], 'constant', constant_values=pv) if max_len >= len(i) else i[: max_len], e)
padded.append(np.stack(list(e)))
yield padded
return gen
class Dataset(object):
@classmethod
def from_generator_func(cls, gen, data_shapes=None, data_types=None):
if not inspect.isgeneratorfunction(gen):
raise ValueError('expect generator function, got %s' % repr(gen))
def wrapper(): #compat to py3.7
try:
for item in gen():
yield item
except RuntimeError as e:
if str(e) != 'generator raised StopIteration':
raise e
ret = cls()
ret.generator = wrapper
ret.data_shapes = data_shapes
ret.data_types = data_types
return ret
@classmethod
def from_file(cls, filename, format=None):
if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename)
def gen():
with open_file(filename, format) as f:
for line in f:
yield line
ret = cls()
ret.generator = gen
ret.data_shapes = []
ret.data_types = str
return ret
@classmethod
def from_record_file(cls, filename):
if os.path.getsize(filename) == 0:
raise RuntimeError('%s is empty' % filename)
gen = open_record(filename)
ret = cls()
ret.generator = gen
ret.data_shapes = []
ret.data_types = str
return ret
@classmethod
def from_list(cls, ls):
if not isinstance(ls, list):
raise ValueError('expect list, got %s' % repr(ls))
def gen():
for i in ls:
yield i
ret = cls()
ret.generator = gen
ret.data_shapes = []
ret.data_types = str
return ret
def __init__(self):
self.name = None
self._data_shapes = None
self._data_types = None
self.generator = None
self.pyreader = None
def __repr__(self):
return 'Dataset: name: %s, data_shapes %s, data_types %s' % (
self.name, self._data_shapes, self._data_types)
def __eq__(self, other):
return self.name == other.name and \
self._data_shapes == other._data_shapes and \
self._data_types == other._data_types
def __iter__(self):
return self.generator()
#def __call__(self):
# return self.generator()
def _infer_shapes_and_types(self):
if self.generator is not None and self.name is not None:
log.info('Try to infer data shapes & types from generator')
first_value = next(self.generator())
shapes, types = [], []
for v in first_value:
if not isinstance(v, np.ndarray):
raise ValueError(
'dataset generator should use numpy elements, got %s' %
first_value)
shapes.append(v.shape)
types.append(v.dtype.name)
self._data_shapes = shapes
self._data_types = types
log.info('Dataset `%s` has data_shapes: %s data_types: %s' %
(self.name, repr(shapes), repr(types)))
else:
raise ValueError(
'Try to infer data shapes or types from incomplete Dataset')
@property
def data_shapes(self):
if self._data_shapes is None:
self._infer_shapes_and_types()
return self._data_shapes
else:
return self._data_shapes
@data_shapes.setter
def data_shapes(self, val):
self._data_shapes = val
@property
def data_types(self):
if self._data_types is None:
self._infer_shapes_and_types()
return self._data_types
else:
return self._data_types
@data_types.setter
def data_types(self, val):
self._data_types = val
def apply(self, transform_func):
#input_shapes = transform_func.input_shapes
#input_types = transform_func.input_types
#data_shapes = transform_func.data_shapes
#data_types = transform_func.data_types
#assert input_shapes == self._data_shapes
#assert input_types = self._data_types
ret_gen = transform_func(self.generator)
ret = type(self).from_generator_func(ret_gen)
if self.name is not None:
ret.name = self.name
#ret.data_shapes = data_shapes
#ret.data_types = data_types
return ret
def shuffle(self, buffer_size):
func = functools.partial(shuffle_func, buffer_size=buffer_size)
return self.apply(func)
def repeat(self, n=-1):
func = functools.partial(repeat_func, n=n)
return self.apply(func)
def map(self, fn):
func = functools.partial(map_func, fn=fn)
return self.apply(func)
def filter(self, fn):
func = functools.partial(filter_func, fn=fn)
return self.apply(func)
def shard(self, num_shards, index):
func = functools.partial(
shard_func, num_shards=num_shards, index=index)
return self.apply(func)
def interleave(self, map_fn, cycle_length, block_length):
func = functools.partial(
interleave_func,
map_fn=map_fn,
cycle_length=cycle_length,
block_length=block_length)
return self.apply(func)
def padded_batch(self, batch_size, pad_value=0, max_seqlen=None):
func = functools.partial(
padded_batch_func,
batch_size=batch_size,
pad_value=pad_value,
max_seqlen=max_seqlen)
return self.apply(func)
def take(self, count=1):
func = functools.partial(take_func, count=count)
return self.apply(func)
def buffered(self, size=10):
func = functools.partial(buffered_func, size=size)
return self.apply(func)
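# Usage sketch (illustrative only, not part of the library API): chaining the
# functional transforms defined above; each call wraps the underlying
# generator function and returns a new Dataset.
if __name__ == '__main__':
    _demo = Dataset.from_list(list(range(10))) \
        .map(lambda x: x * 2) \
        .filter(lambda x: x < 10) \
        .shuffle(buffer_size=4) \
        .take(3)
    for _i in _demo:
        log.info('demo element: %s', repr(_i))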
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import six
from propeller.types import *
from propeller.util import ArgumentParser, parse_hparam, parse_runconfig, parse_file
from propeller.paddle import data
from propeller.paddle import train
from propeller.paddle.train import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
_global_collection = None
class Key(object):
"""predefine collection keys"""
SUMMARY_SCALAR = 1
SUMMARY_HISTOGRAM = 2
SKIP_OPTIMIZE = 3
class Collections(object):
"""global collections to record everything"""
def __init__(self):
self.col = {}
def __enter__(self):
global _global_collection
_global_collection = self
return self
def __exit__(self, err_type, err_value, trace):
global _global_collection
_global_collection = None
def add(self, key, val):
self.col.setdefault(key, []).append(val)
def get(self, key):
return self.col.get(key, None)
def default_collection():
global _global_collection
if _global_collection is None:
_global_collection = Collections()
return _global_collection
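# Usage sketch (illustrative only): values can be recorded either inside a
# scoped `with Collections()` block or through the module-level default
# collection returned by `default_collection()`.
if __name__ == '__main__':
    with Collections() as col:
        col.add(Key.SUMMARY_SCALAR, ('loss', 0.5))
        assert default_collection() is col
    default_collection().add(Key.SKIP_OPTIMIZE, 'some_var_name')
    print(default_collection().get(Key.SKIP_OPTIMIZE))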
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
from propeller.paddle.data.functional import *
from propeller.paddle.data.feature_column import *
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Protocol messages for describing input data Examples for machine learning
// model training or inference.
syntax = "proto3";
import "propeller/paddle/data/feature.proto";
package propeller;
message Example {
Features features = 1;
};
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
};
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/example.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from propeller.paddle.data import feature_pb2 as propeller_dot_paddle_dot_data_dot_feature__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/example.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/example.proto\x12\tpropeller\x1a#propeller/paddle/data/feature.proto\"0\n\x07\x45xample\x12%\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x13.propeller.Features\"g\n\x0fSequenceExample\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.propeller.Features\x12.\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x17.propeller.FeatureListsb\x06proto3'
),
dependencies=[
propeller_dot_paddle_dot_data_dot_feature__pb2.DESCRIPTOR,
])
_EXAMPLE = _descriptor.Descriptor(
name='Example',
full_name='propeller.Example',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='features',
full_name='propeller.Example.features',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=87,
serialized_end=135, )
_SEQUENCEEXAMPLE = _descriptor.Descriptor(
name='SequenceExample',
full_name='propeller.SequenceExample',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='context',
full_name='propeller.SequenceExample.context',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='feature_lists',
full_name='propeller.SequenceExample.feature_lists',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=137,
serialized_end=240, )
_EXAMPLE.fields_by_name[
'features'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'context'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURES
_SEQUENCEEXAMPLE.fields_by_name[
'feature_lists'].message_type = propeller_dot_paddle_dot_data_dot_feature__pb2._FEATURELISTS
DESCRIPTOR.message_types_by_name['Example'] = _EXAMPLE
DESCRIPTOR.message_types_by_name['SequenceExample'] = _SEQUENCEEXAMPLE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Example = _reflection.GeneratedProtocolMessageType(
'Example',
(_message.Message, ),
dict(
DESCRIPTOR=_EXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.Example)
))
_sym_db.RegisterMessage(Example)
SequenceExample = _reflection.GeneratedProtocolMessageType(
'SequenceExample',
(_message.Message, ),
dict(
DESCRIPTOR=_SEQUENCEEXAMPLE,
__module__='propeller.paddle.data.example_pb2'
# @@protoc_insertion_point(class_scope:propeller.SequenceExample)
))
_sym_db.RegisterMessage(SequenceExample)
# @@protoc_insertion_point(module_scope)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package propeller;
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
map<string, Feature> feature = 1;
};
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
map<string, FeatureList> feature_list = 1;
};
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import struct
from six.moves import zip, map
import itertools
import gzip
from functools import partial
import multiprocessing
import six
import logging
import numpy as np
from glob import glob
from propeller.paddle.train import distribution
from propeller.data.functional import interleave_func
from propeller.paddle.data.functional import Dataset
from propeller.paddle.data import example_pb2, feature_pb2
log = logging.getLogger(__name__)
__all__ = [
'FeatureColumns', 'TextColumn', 'TextIDColumn', 'LabelColumn',
'basic_tokenizer', 'Column'
]
def basic_tokenizer(sen):
seg = sen.split(b' ')
seg = filter(lambda i: i != b' ', seg)
return seg
class Column():
def __init__(self, name):
pass
def raw_to_proto(self, raw):
return feature_pb2.Feature()
@property
def output_shapes(self):
pass
@property
def output_types(self):
pass
def proto_to_instance(self, proto):
raise NotImplementedError()
def raw_to_instance(self, raw):
raise NotImplementedError()
class LabelColumn(Column):
def __init__(self, name, vocab_dict=None, vocab_file=None):
self.name = name
self.vocab = None
if vocab_file:
self.vocab = {
j.strip(): i
for i, j in enumerate(open(vocab_file, 'rb').readlines())
}
if vocab_dict:
self.vocab = vocab_dict
@property
def output_shapes(self):
return [1]
@property
def output_types(self):
return 'int64'
def raw_to_proto(self, raw):
if self.vocab is None:
ids = [int(raw)]
else:
ids = [self.vocab[raw]]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
ret = np.array(feature.int64_list.value[0], dtype=np.int64)
return ret
def raw_to_instance(self, raw):
if self.vocab is None:
ids = int(raw)
else:
ids = self.vocab[raw]
return ids
class TextColumn(Column):
def __init__(self,
name,
unk_id,
vocab_file=None,
vocab_dict=None,
tokenizer=basic_tokenizer):
self.name = name
self.tokenizer = tokenizer
self.unk_id = unk_id
if not (vocab_file or vocab_dict):
raise ValueError('at least specify vocab_file or vocab_dict')
if vocab_file:
self.vocab = {
j.strip(): i
for i, j in enumerate(open(vocab_file, 'rb').readlines())
}
if vocab_dict:
self.vocab = vocab_dict
@property
def output_shapes(self):
return [-1]
@property
def output_types(self):
return 'int64'
def raw_to_proto(self, raw):
ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
ids = [self.vocab.get(s, self.unk_id) for s in self.tokenizer(raw)]
return np.array(ids, dtype=np.int64)
class TextIDColumn(Column):
def __init__(self, name):
self.name = name
@property
def output_shapes(self):
return [-1]
@property
def output_types(self):
return 'int64'
def raw_to_proto(self, raw):
ids = [int(s) for s in raw.split(b' ')]
fe = feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=ids))
return fe
def proto_to_instance(self, feature):
ret = np.array(feature.int64_list.value, dtype=np.int64)
return ret
def raw_to_instance(self, raw):
ret = np.array([int(i) for i in raw.split(b' ')], dtype=np.int64)
return ret
class FeatureColumns(object):
def __init__(self, columns, pad_id=0):
self._columns = columns
def raw_files(self, raw_dir):
return [os.path.join(raw_dir, p) for p in os.listdir(raw_dir)]
def gz_files(self, gz_dir):
return None if gz_dir is None else [
os.path.join(gz_dir, p) for p in os.listdir(gz_dir)
]
def _make_gz_dataset(self, raw_dir, gz_dir):
assert raw_dir or gz_dir, 'data_dir not specified when using gz mode'
if raw_dir is not None:
assert os.path.exists(raw_dir), 'raw_dir not exists: %s' % raw_dir
raw_file = os.listdir(raw_dir)
if gz_dir is None:
gz_dir = '%s_gz' % raw_dir.rstrip('/')
if not os.path.exists(gz_dir):
os.mkdir(gz_dir)
if raw_dir is not None:
if len(raw_file) != 0:
log.debug('try making gz')
pool = multiprocessing.Pool()
args = [(os.path.join(raw_dir, f), os.path.join(gz_dir, f),
self._columns, b'\t') for f in raw_file]
pool.map(_make_gz, args)
pool.terminate()
else:
assert len(
os.listdir(gz_dir)
) != 0, 'cant find gz file or raw-txt file at [%s] and [%s]' % (
raw_dir, gz_dir)
return gz_dir
def _read_gz_dataset(self,
gz_files,
shuffle=False,
repeat=True,
shard=False,
**kwargs):
if len(gz_files) == 0:
raise ValueError('reading gz from empty file list: %s' % gz_files)
log.info('reading gz from %s' % '\n'.join(gz_files))
dataset = Dataset.from_list(gz_files)
if repeat:
dataset = dataset.repeat()
if shard and distribution.status.mode == distribution.DistributionMode.NCCL:
log.info('Apply dataset sharding in distribution env')
            dataset = dataset.shard(distribution.status.num_replica,
                                    distribution.status.replica_id)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(gz_files))
fn = partial(
interleave_func,
map_fn=lambda filename: Dataset.from_record_file(filename),
cycle_length=len(gz_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_gz(record_str): # function that takes python_str as input
ex = example_pb2.Example()
ex.ParseFromString(record_str)
ret = []
fea_dict = ex.features.feature
for c in self._columns:
ins = c.proto_to_instance(fea_dict[c.name])
ret.append(ins)
return ret
dataset = dataset.map(_parse_gz)
return dataset
def _read_txt_dataset(self,
data_files,
shuffle=False,
repeat=True,
**kwargs):
log.info('reading raw files from %s' % '\n'.join(data_files))
dataset = Dataset.from_list(data_files)
if repeat:
dataset = dataset.repeat()
if shuffle:
dataset = dataset.shuffle(buffer_size=len(data_files))
fn = partial(
interleave_func,
map_fn=lambda filename: Dataset.from_file(filename),
cycle_length=len(data_files),
block_length=1)
dataset = dataset.apply(fn)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_txt_file(
record_str): # function that takes python_str as input
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_txt_file)
return dataset
def _read_stdin_dataset(self, encoding='utf8', shuffle=False, **kwargs):
log.info('reading raw files stdin')
def gen():
if six.PY3:
source = sys.stdin.buffer
else:
source = sys.stdin
while True:
line = source.readline()
if len(line) == 0:
break
yield line,
dataset = Dataset.from_generator_func(gen)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
def _parse_stdin(record_str):
'''function that takes python_str as input'''
features = record_str.strip(b'\n').split(b'\t')
ret = [
column.raw_to_instance(feature)
for feature, column in zip(features, self._columns)
]
return ret
dataset = dataset.map(_parse_stdin)
return dataset
def _prepare_dataset(self,
dataset,
map_func_before_batch=None,
map_func_after_batch=None,
shuffle_buffer_size=None,
batch_size=1,
pad_id=0,
prefetch=None,
**kwargs):
if map_func_before_batch is not None:
dataset = dataset.map(map_func_before_batch)
if batch_size:
dataset = dataset.padded_batch(batch_size, pad_id)
if map_func_after_batch is not None:
dataset = dataset.map(map_func_after_batch)
return dataset
def build_dataset(self,
name,
use_gz=True,
data_dir=None,
gz_dir=None,
data_file=None,
**kwargs):
if use_gz:
gz_dir = self._make_gz_dataset(data_dir, gz_dir)
gz_files = self.gz_files(gz_dir)
ds = self._read_gz_dataset(gz_files, **kwargs)
else:
if data_dir is not None:
data_files = self.raw_files(data_dir)
elif data_file is not None:
data_files = [data_file]
else:
raise ValueError('data_dir or data_files not specified')
ds = self._read_txt_dataset(data_files, **kwargs)
ds.name = name
return ds
def build_dataset_from_stdin(self, name, **kwargs):
ds = self._read_stdin_dataset(**kwargs)
ds.name = name
return ds
def _make_gz(args):
try:
from_file, to_file, columns, sep = args
if os.path.exists(to_file):
return
with open(from_file, 'rb') as fin, gzip.open(to_file, 'wb') as fout:
log.debug('making gz %s => %s' % (from_file, to_file))
for i, line in enumerate(fin):
line = line.strip(b'\n').split(sep)
#if i % 10000 == 0:
# log.debug('making gz %s => %s [%d]' % (from_file, to_file, i))
if len(line) != len(columns):
log.error('columns not match at %s, got %d, expect %d' %
(from_file, len(line), len(columns)))
continue
features = {}
for l, c in zip(line, columns):
features[c.name] = c.raw_to_proto(l)
example = example_pb2.Example(features=feature_pb2.Features(
feature=features))
serialized = example.SerializeToString()
l = len(serialized)
data = struct.pack('i%ds' % l, l, serialized)
fout.write(data)
log.debug('done making gz %s => %s' % (from_file, to_file))
except Exception as e:
log.exception(e)
raise e
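# Usage sketch (illustrative only; paths and unk_id values are hypothetical).
# Note that, per the signatures above, TextColumn needs an `unk_id` plus a
# `vocab_file` or `vocab_dict`, and `build_dataset` takes a dataset name as
# its first argument:
#
#   columns = FeatureColumns([
#       TextColumn('query', unk_id=0, vocab_file='./vocab'),
#       TextColumn('title', unk_id=0, vocab_file='./vocab'),
#       LabelColumn('label'),
#   ])
#   train_ds = columns.build_dataset('train', data_dir='./data', use_gz=False,
#                                    shuffle=True, repeat=True)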
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: propeller/paddle/data/feature.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (
lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='propeller/paddle/data/feature.proto',
package='propeller',
syntax='proto3',
serialized_options=None,
serialized_pb=_b(
'\n#propeller/paddle/data/feature.proto\x12\tpropeller\"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c\"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01\"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01\"\x95\x01\n\x07\x46\x65\x61ture\x12*\n\nbytes_list\x18\x01 \x01(\x0b\x32\x14.propeller.BytesListH\x00\x12*\n\nfloat_list\x18\x02 \x01(\x0b\x32\x14.propeller.FloatListH\x00\x12*\n\nint64_list\x18\x03 \x01(\x0b\x32\x14.propeller.Int64ListH\x00\x42\x06\n\x04kind\"\x81\x01\n\x08\x46\x65\x61tures\x12\x31\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32 .propeller.Features.FeatureEntry\x1a\x42\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.propeller.Feature:\x02\x38\x01\"2\n\x0b\x46\x65\x61tureList\x12#\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x12.propeller.Feature\"\x9a\x01\n\x0c\x46\x65\x61tureLists\x12>\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32(.propeller.FeatureLists.FeatureListEntry\x1aJ\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.propeller.FeatureList:\x02\x38\x01\x62\x06proto3'
))
_BYTESLIST = _descriptor.Descriptor(
name='BytesList',
full_name='propeller.BytesList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.BytesList.value',
index=0,
number=1,
type=12,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=50,
serialized_end=76, )
_FLOATLIST = _descriptor.Descriptor(
name='FloatList',
full_name='propeller.FloatList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FloatList.value',
index=0,
number=1,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=78,
serialized_end=108, )
_INT64LIST = _descriptor.Descriptor(
name='Int64List',
full_name='propeller.Int64List',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Int64List.value',
index=0,
number=1,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=_b('\020\001'),
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=110,
serialized_end=140, )
_FEATURE = _descriptor.Descriptor(
name='Feature',
full_name='propeller.Feature',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='bytes_list',
full_name='propeller.Feature.bytes_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='float_list',
full_name='propeller.Feature.float_list',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='int64_list',
full_name='propeller.Feature.int64_list',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='kind',
full_name='propeller.Feature.kind',
index=0,
containing_type=None,
fields=[]),
],
serialized_start=143,
serialized_end=292, )
_FEATURES_FEATUREENTRY = _descriptor.Descriptor(
name='FeatureEntry',
full_name='propeller.Features.FeatureEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.Features.FeatureEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.Features.FeatureEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=358,
serialized_end=424, )
_FEATURES = _descriptor.Descriptor(
name='Features',
full_name='propeller.Features',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.Features.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[_FEATURES_FEATUREENTRY, ],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=295,
serialized_end=424, )
_FEATURELIST = _descriptor.Descriptor(
name='FeatureList',
full_name='propeller.FeatureList',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature',
full_name='propeller.FeatureList.feature',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=426,
serialized_end=476, )
_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor(
name='FeatureListEntry',
full_name='propeller.FeatureLists.FeatureListEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='propeller.FeatureLists.FeatureListEntry.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value',
full_name='propeller.FeatureLists.FeatureListEntry.value',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[],
enum_types=[],
serialized_options=_b('8\001'),
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=559,
serialized_end=633, )
_FEATURELISTS = _descriptor.Descriptor(
name='FeatureLists',
full_name='propeller.FeatureLists',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='feature_list',
full_name='propeller.FeatureLists.feature_list',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
serialized_options=None,
file=DESCRIPTOR),
],
extensions=[],
nested_types=[_FEATURELISTS_FEATURELISTENTRY, ],
enum_types=[],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=479,
serialized_end=633, )
_FEATURE.fields_by_name['bytes_list'].message_type = _BYTESLIST
_FEATURE.fields_by_name['float_list'].message_type = _FLOATLIST
_FEATURE.fields_by_name['int64_list'].message_type = _INT64LIST
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'bytes_list'])
_FEATURE.fields_by_name[
'bytes_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'float_list'])
_FEATURE.fields_by_name[
'float_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURE.oneofs_by_name['kind'].fields.append(_FEATURE.fields_by_name[
'int64_list'])
_FEATURE.fields_by_name[
'int64_list'].containing_oneof = _FEATURE.oneofs_by_name['kind']
_FEATURES_FEATUREENTRY.fields_by_name['value'].message_type = _FEATURE
_FEATURES_FEATUREENTRY.containing_type = _FEATURES
_FEATURES.fields_by_name['feature'].message_type = _FEATURES_FEATUREENTRY
_FEATURELIST.fields_by_name['feature'].message_type = _FEATURE
_FEATURELISTS_FEATURELISTENTRY.fields_by_name[
'value'].message_type = _FEATURELIST
_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS
_FEATURELISTS.fields_by_name[
'feature_list'].message_type = _FEATURELISTS_FEATURELISTENTRY
DESCRIPTOR.message_types_by_name['BytesList'] = _BYTESLIST
DESCRIPTOR.message_types_by_name['FloatList'] = _FLOATLIST
DESCRIPTOR.message_types_by_name['Int64List'] = _INT64LIST
DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE
DESCRIPTOR.message_types_by_name['Features'] = _FEATURES
DESCRIPTOR.message_types_by_name['FeatureList'] = _FEATURELIST
DESCRIPTOR.message_types_by_name['FeatureLists'] = _FEATURELISTS
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
BytesList = _reflection.GeneratedProtocolMessageType(
'BytesList',
(_message.Message, ),
dict(
DESCRIPTOR=_BYTESLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.BytesList)
))
_sym_db.RegisterMessage(BytesList)
FloatList = _reflection.GeneratedProtocolMessageType(
'FloatList',
(_message.Message, ),
dict(
DESCRIPTOR=_FLOATLIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FloatList)
))
_sym_db.RegisterMessage(FloatList)
Int64List = _reflection.GeneratedProtocolMessageType(
'Int64List',
(_message.Message, ),
dict(
DESCRIPTOR=_INT64LIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Int64List)
))
_sym_db.RegisterMessage(Int64List)
Feature = _reflection.GeneratedProtocolMessageType(
'Feature',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURE,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Feature)
))
_sym_db.RegisterMessage(Feature)
Features = _reflection.GeneratedProtocolMessageType(
'Features',
(_message.Message, ),
dict(
FeatureEntry=_reflection.GeneratedProtocolMessageType(
'FeatureEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURES_FEATUREENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features.FeatureEntry)
)),
DESCRIPTOR=_FEATURES,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.Features)
))
_sym_db.RegisterMessage(Features)
_sym_db.RegisterMessage(Features.FeatureEntry)
FeatureList = _reflection.GeneratedProtocolMessageType(
'FeatureList',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELIST,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureList)
))
_sym_db.RegisterMessage(FeatureList)
FeatureLists = _reflection.GeneratedProtocolMessageType(
'FeatureLists',
(_message.Message, ),
dict(
FeatureListEntry=_reflection.GeneratedProtocolMessageType(
'FeatureListEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists.FeatureListEntry)
)),
DESCRIPTOR=_FEATURELISTS,
__module__='propeller.paddle.data.feature_pb2'
# @@protoc_insertion_point(class_scope:propeller.FeatureLists)
))
_sym_db.RegisterMessage(FeatureLists)
_sym_db.RegisterMessage(FeatureLists.FeatureListEntry)
_FLOATLIST.fields_by_name['value']._options = None
_INT64LIST.fields_by_name['value']._options = None
_FEATURES_FEATUREENTRY._options = None
_FEATURELISTS_FEATURELISTENTRY._options = None
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.data.functional import Dataset as DatasetBase
log = logging.getLogger(__name__)
class Dataset(DatasetBase):
def placeholders(self):
if self.name is None:
raise ValueError('can not get feature from unnamed Dataset')
ret = []
for i, (shape,
types) in enumerate(zip(self.data_shapes, self.data_types)):
ret.append(
L.data(
'%s_placeholder_%d' % (self.name, i),
shape=shape,
append_batch_size=False,
dtype=types))
return ret
def features(self):
'''start point of net building. call this in a program scope'''
if self.name is None:
raise ValueError('can not get feature from unnamed Dataset')
if len(self.data_shapes) != len(self.data_types):
raise ValueError(
'Dataset shapes and types not match: shape:%s types%s' %
(repr(self._data_shapes), repr(self._data_types)))
return self.placeholders()
def start(self, places=F.cuda_places()):
#assert self.pyreader is not None, 'use Dataset.features to build net first, then start dataset'
def gen():
try:
for idx, i in enumerate(self.generator()):
yield i
except Exception as e:
log.exception(e)
raise e
r = F.io.PyReader(
feed_list=self.placeholders(), capacity=50, iterable=True)
r.decorate_batch_generator(gen, places=places)
return r()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import logging
import six
import asyncio
import threading
import grpc
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import propeller.paddle.service.utils as serv_utils
from concurrent.futures import ThreadPoolExecutor
import paddle.fluid as F
from time import sleep, time
log = logging.getLogger(__name__)
def profile(msg):
def decfn(fn):
def retfn(*args, **kwargs):
start = time()
ret = fn(*args, **kwargs)
end = time()
log.debug('%s timecost: %.5f' % (msg, end - start))
return ret
return retfn
return decfn
def serve(model_dir, host, num_concurrent=None):
if six.PY2:
        raise RuntimeError('propeller service works in python3 only')
num_worker = len(F.cuda_places(
)) if num_concurrent is None else num_concurrent
pool = ThreadPoolExecutor(num_worker)
class Predictor(object):
def __init__(self, did):
log.debug('create predictor on card %d' % did)
config = F.core.AnalysisConfig(model_dir)
config.enable_use_gpu(5000, did)
self._predictor = F.core.create_paddle_predictor(config)
@profile('paddle')
def __call__(self, args):
for i, a in enumerate(args):
a.name = 'placeholder_%d' % i
res = self._predictor.run(args)
return res
predictor_context = {}
class InferenceService(interface_pb2_grpc.InferenceServicer):
@profile('service')
def Infer(self, request, context):
try:
slots = request.slots
current_thread = threading.current_thread()
log.debug('%d slots received dispatch to thread %s' %
(len(slots), current_thread))
if current_thread not in predictor_context:
did = list(pool._threads).index(current_thread)
log.debug('spawning worker thread %d' % did)
predictor = Predictor(did)
predictor_context[current_thread] = predictor
else:
predictor = predictor_context[current_thread]
slots = [serv_utils.slot_to_paddlearray(s) for s in slots]
ret = predictor(slots)
response = [serv_utils.paddlearray_to_slot(r) for r in ret]
except Exception as e:
log.exception(e)
raise e
return interface_pb2.Slots(slots=response)
server = grpc.server(pool)
interface_pb2_grpc.add_InferenceServicer_to_server(InferenceService(),
server)
server.add_insecure_port(host)
server.start()
log.info('server started on %s...' % host)
try:
while True:
sleep(100000)
except KeyboardInterrupt as e:
pass
    log.info('server stopped...')
if __name__ == '__main__':
from propeller import log
log.setLevel(logging.DEBUG)
serve(
'/home/work/chenxuyi/playground/grpc_play/ernie2.0/',
'10.255.138.19:8334',
num_concurrent=3)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import struct
from propeller.service import interface_pb2
from propeller.service import interface_pb2_grpc
import paddle.fluid.core as core
def slot_to_paddlearray(slot):
if slot.type == interface_pb2.Slot.FP32:
type_str = 'f'
dtype = core.PaddleDType.FLOAT32
elif slot.type == interface_pb2.Slot.INT32:
type_str = 'i'
dtype = core.PaddleDType.INT32
elif slot.type == interface_pb2.Slot.INT64:
type_str = 'q'
dtype = core.PaddleDType.INT64
else:
        raise RuntimeError('unknown type %s' % slot.type)
ret = core.PaddleTensor()
ret.shape = slot.dims
ret.dtype = dtype
num = len(slot.data) // struct.calcsize(type_str)
arr = struct.unpack('%d%s' % (num, type_str), slot.data)
ret.data = core.PaddleBuf(arr)
return ret
def paddlearray_to_slot(arr):
if arr.dtype == core.PaddleDType.FLOAT32:
dtype = interface_pb2.Slot.FP32
type_str = 'f'
arr_data = arr.data.float_data()
elif arr.dtype == core.PaddleDType.INT32:
dtype = interface_pb2.Slot.INT32
type_str = 'i'
arr_data = arr.data.int32_data()
elif arr.dtype == core.PaddleDType.INT64:
dtype = interface_pb2.Slot.INT64
type_str = 'q'
arr_data = arr.data.int64_data()
else:
        raise RuntimeError('unknown type %s' % arr.dtype)
data = struct.pack('%d%s' % (len(arr_data), type_str), *arr_data)
pb = interface_pb2.Slot(type=dtype, dims=list(arr.shape), data=data)
return pb
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import paddle.fluid as F
from propeller.paddle.collection import default_collection, Key
def scalar(name, tensor):
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True
default_collection().add(Key.SUMMARY_SCALAR, (name, tensor))
def histogram(name, tensor):
if not isinstance(tensor, F.framework.Variable):
raise ValueError('expect paddle Variable, got %s' % repr(tensor))
tensor.persistable = True
default_collection().add(Key.SUMMARY_HISTOGRAM, (name, tensor))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import sys
import logging
from time import time
log = logging.getLogger(__name__)
from propeller.paddle.train.monitored_executor import *
from propeller.paddle.train.trainer import *
from propeller.paddle.train.hooks import *
from propeller.train.model import Model
from propeller.paddle.train import exporter
from propeller.paddle.train import distribution
from propeller.paddle.train import metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import functools
import six
import logging
from time import sleep
import paddle.fluid as F
import paddle.fluid.layers as L
log = logging.getLogger(__name__)
import propeller.util
__all__ = ['init_distribuition_env', 'status']
status = None
class DistributionMode(object):
LOCAL = 0
NCCL = 1
class DistributionStatus(object):
def __init__(self, config):
if config is None:
self._mode = DistributionMode.LOCAL
self._env = None
self._this = None
else:
try:
self._mode = DistributionMode.NCCL
cluster = config['cluster']
task = config['task']['type']
idx = int(config['task']['index'])
self._this = cluster[task][idx]
self._env = cluster['chief'] + cluster['worker']
if len(set(self._env)) != len(self._env):
raise ValueError('duplicate host in dis_config %s' %
config)
except KeyError as e:
raise ValueError(
'PROPELLER_DISCONFIG wrong: %s not found in %s' %
                    (e, repr(config)))
@property
def mode(self):
return self._mode
@property
def num_replica(self):
if self._mode == DistributionMode.LOCAL:
return 1
elif self._mode == DistributionMode.NCCL:
return len(self._env)
else:
            raise ValueError('Got unknown distribution mode %s' %
                             repr(self._mode))
@property
def replica_id(self):
if self._mode == DistributionMode.LOCAL:
return 0
elif self._mode == DistributionMode.NCCL:
return self._env.index(self._this)
else:
            raise ValueError('Got unknown distribution mode %s' %
                             repr(self._mode))
@property
def is_master(self):
if self._mode == DistributionMode.LOCAL:
return True
elif self._mode == DistributionMode.NCCL:
return self.replica_id == 0
else:
            raise ValueError('got unknown distribution mode %s' %
                             repr(self._mode))
dis_config = propeller.util._get_dict_from_environ_or_json_or_file(
None, 'PROPELLER_DISCONFIG')
status = DistributionStatus(dis_config)
def run_on_master(func):
"""skip function in distribution env"""
@functools.wraps(func)
def f(*arg, **kwargs):
"""f"""
if status is None:
            raise ValueError('distribution mode unknown at this point')
if status.mode == DistributionMode.LOCAL:
r = func(*arg, **kwargs)
elif status.mode == DistributionMode.NCCL:
if status.is_master:
r = func(*arg, **kwargs)
else:
r = 0 # skip function
#MPI.COMM_WORLD.Barrier()
return r
return f
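# A usage sketch for run_on_master (the helper below is hypothetical): wrap any
# side-effecting step that should run exactly once per job, e.g. writing a final
# report; under NCCL distribution every non-master replica skips the body and
# receives the dummy return value 0.
#
#     @run_on_master
#     def write_report(path, result):
#         with open(path, 'w') as f:
#             f.write(repr(result))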
def init_distribuition_env(program):
if status.mode == DistributionMode.LOCAL:
log.info('Initializing local training')
elif status.mode == DistributionMode.NCCL:
config = F.DistributeTranspilerConfig()
config.mode = "nccl2"
F.DistributeTranspiler(config=config).transpile(
status.replica_id,
trainers=','.join(status._env),
current_endpoint=status._this,
program=program.train_program,
startup_program=program.startup_program)
log.info('Initializing distribution training with config %s' %
(repr(dis_config)))
if status.is_master:
sleep(30)
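# A sketch of the JSON expected in the PROPELLER_DISCONFIG environment variable
# (the endpoints below are made up). 'chief' + 'worker' together form the replica
# list used for NCCL initialization, and task.type / task.index select this
# process's own endpoint:
#
#     {
#         "cluster": {
#             "chief":  ["10.0.0.1:8000"],
#             "worker": ["10.0.0.2:8000", "10.0.0.3:8000"]
#         },
#         "task": {"type": "worker", "index": 1}
#     }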
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import os
import itertools
import six
import abc
import logging
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller.paddle.train import Saver
from propeller.types import InferenceSpec
log = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class Exporter():
@abc.abstractmethod
    def export(self, exe, program, eval_model_spec, eval_result, state):
raise NotImplementedError()
class BestExporter(Exporter):
def __init__(self, export_dir, cmp_fn):
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
        log.debug('New evaluation result: %s\nold: %s' %
                  (repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
eval_program = program.train_program
            # FIXME: all eval datasets have the same names/types/shapes for now, so every eval program is the same
saver = Saver(
self._export_dir,
exe,
program=eval_program,
max_ckpt_to_keep=1)
saver.save(state)
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
class BestInferenceModelExporter(Exporter):
def __init__(self, export_dir, cmp_fn):
self._export_dir = export_dir
self._best = None
self.cmp_fn = cmp_fn
def export(self, exe, program, eval_model_spec, eval_result, state):
        log.debug('New evaluation result: %s\nold: %s' %
                  (repr(eval_result), repr(self._best)))
if self._best is None or self.cmp_fn(old=self._best, new=eval_result):
log.debug('[Best Exporter]: export to %s' % self._export_dir)
if eval_model_spec.inference_spec is None:
                raise ValueError("model_fn didn't return InferenceSpec")
            inf_spec_dict = eval_model_spec.inference_spec
            if not isinstance(inf_spec_dict, dict):
                inf_spec_dict = {'inference': inf_spec_dict}
            for inf_spec_name, inf_spec in six.iteritems(inf_spec_dict):
                if not isinstance(inf_spec, InferenceSpec):
                    raise ValueError('unknown inference spec type: %s' %
                                     repr(inf_spec))
                save_dir = os.path.join(self._export_dir, inf_spec_name)
                log.debug('[Best Exporter]: save inference model: "%s" to %s' %
                          (inf_spec_name, save_dir))
                feed_var = [i.name for i in inf_spec.inputs]
                fetch_var = inf_spec.outputs
eval_program = program.train_program
startup_prog = F.Program()
F.io.save_inference_model(
save_dir,
feed_var,
fetch_var,
exe,
main_program=eval_program)
self._best = eval_result
else:
log.debug('[Best Exporter]: skip step %s' % state.gstep)
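# A usage sketch (dataset/metric names below are hypothetical), assuming eval
# results are nested dicts keyed by eval dataset and then by metric name; cmp_fn
# receives the previous best and the new result and returns True when the new one
# should replace it:
#
#     exporter = BestExporter(
#         export_dir='./best_model',
#         cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])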
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import sys
import six
import os
import itertools
import numpy as np
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
from propeller import util
from propeller.paddle.train import distribution
from propeller.paddle.train.metrics import Metrics
__all__ = [
'RunHook', 'TqdmProgressBarHook', 'TqdmNotebookProgressBarHook',
'CheckpointSaverHook', 'LoggingHook', 'StopAtStepHook', 'EvalHook'
]
log = logging.getLogger(__name__)
class RunHook(object):
def __init__(self):
pass
def before_train(self):
pass
def before_run(self, state):
return []
def after_run(self, res_list, state):
pass
def should_stop(self, state):
return False
def after_train(self):
pass
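# Intended call order for a RunHook, inferred from the hooks below: before_train()
# once, then for every step before_run(state) returns extra fetch targets and
# after_run(res_list, state) receives their fetched values, should_stop(state) may
# end the loop early, and after_train() runs once at the end.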
class TqdmProgressBarHook(RunHook):
def __init__(self, max_steps, desc=None):
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
        class TqdmLoggingHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except:
self.handleError(record)
        tqdm_hdl = TqdmLoggingHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
        self.tqdm = tqdm.tqdm(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class TqdmNotebookProgressBarHook(RunHook):
def __init__(self, max_steps, desc=None):
self.tqdm = None
import tqdm
from propeller import log as main_log
hdl = main_log.handlers[0]
        class TqdmLoggingHandler(logging.Handler):
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg, file=sys.stderr)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except:
self.handleError(record)
        tqdm_hdl = TqdmLoggingHandler()
tqdm_hdl.setFormatter(hdl.formatter)
main_log.removeHandler(hdl)
main_log.addHandler(tqdm_hdl)
        self.tqdm = tqdm.tqdm_notebook(total=max_steps, desc=desc)
def before_run(self, state):
self.tqdm.n = state.gstep
self.tqdm.refresh()
return []
def __del__(self):
if self.tqdm:
self.tqdm.close()
class LoggingHook(RunHook):
def __init__(self,
loss,
per_step=10,
skip_step=100,
summary_writer=None,
summary_record=None):
if per_step is None or skip_step is None:
            raise ValueError('wrong step argument, per_step: %s skip_step: %s'
                             % (per_step, skip_step))
self.loss = loss
self.per_step = per_step
self.skip_step = skip_step
self.summary_record = summary_record
self.writer = summary_writer
self.last_state = None
def before_train(self):
if self.summary_record:
if self.summary_record.scalar:
self.s_name, self.s_tolog = zip(*self.summary_record.scalar)
else:
self.s_name, self.s_tolog = [], []
if self.summary_record.histogram:
self.h_name, self.h_tolog = zip(*self.summary_record.histogram)
else:
self.h_name, self.h_tolog = [], []
def before_run(self, state):
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
ret = [self.loss]
if self.summary_record:
ret += self.s_tolog
ret += self.h_tolog
return ret
else:
return []
def after_run(self, res_list, state):
if state.gstep % self.per_step == 0 and state.step > self.skip_step:
if not self.summary_record:
return
loss = float(res_list[0])
s_np = res_list[1:1 + len(self.s_name)]
h_np = res_list[1 + len(self.s_name):1 + len(self.s_name) + len(
self.h_name)]
if self.last_state is not None:
speed = (state.gstep - self.last_state.gstep) / (
state.time - self.last_state.time)
else:
speed = -1.
self.last_state = state
# log to tensorboard
if self.writer is not None:
self.writer.add_scalar('loss', loss, state.gstep)
for name, t in zip(self.s_name, s_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_scalar(name, t, state.gstep)
for name, t in zip(self.h_name, h_np):
if np.isnan(t).any():
log.warning('Nan summary: %s, skip' % name)
else:
self.writer.add_histogram(name, t, state.gstep)
if speed > 0.:
self.writer.add_scalar('global_step', speed, state.gstep)
# log to stdout
log.debug('\t'.join([
'step: %d' % state.gstep,
'steps/sec: %.5f' % speed,
'loss: %.5f' % loss,
'' if self.summary_record is None else ' '.join(
map(lambda t: '%s:%s' % t, zip(self.s_name, s_np))),
]))
class StopAtStepHook(RunHook):
def __init__(self, stop_global_step, stop_step):
self._stop_gstep = stop_global_step
self._stop_step = stop_step
def should_stop(self, state):
if (self._stop_gstep and state.gstep >= self._stop_gstep) or \
(self._stop_step and state.step >= self._stop_step):
log.info('StopAtStepHook called stop')
return True
else:
return False
class EvalHook(RunHook):
"""hook this on a eval Executor"""
def __init__(self, metrics, summary_writer=None):
self.writer = summary_writer
self._result = None
if not isinstance(metrics, dict):
raise ValueError('metrics should be dict, got %s' % repr(metrics))
for k, m in six.iteritems(metrics):
if not isinstance(m, Metrics):
raise ValueError(
'metrics %s should be instance of propeller.Metrics, got %s'
% (k, repr(m)))
if len(metrics):
self.names = list(metrics.keys())
self.metrics = list(metrics.values())
else:
self.names, self.metrics = [], []
def before_train(self):
for m in self.metrics:
m.reset()
def before_run(self, state):
ls = [m.tensor for m in self.metrics]
for i in ls:
if not (isinstance(i, list) or isinstance(i, tuple)):
raise ValueError(
'metrics should return tuple or list of tensors, got %s' %
repr(i))
for ii in i:
if not isinstance(ii, F.framework.Variable):
                    raise ValueError(
                        'metrics tensor should be a paddle Variable, got %s of type %s'
                        % (repr(ii), type(ii)))
ls_flt, self.schema = util.flatten(ls)
#log.debug(ls_flt)
return ls_flt
def after_run(self, res_list, state):
res = util.unflatten(res_list, self.schema)
for r, m in zip(res, self.metrics):
m.update(r)
@property
def result(self):
return self._result
def after_train(self):
printable = []
self._result = {}
for n, m in zip(self.names, self.metrics):
val = m.eval()
self._result[n] = val
return self.result
class CheckpointSaverHook(RunHook):
def __init__(self, saver, per_step=10, skip_step=100):
self.saver = saver
self.per_step = per_step
self.skip_step = skip_step
def after_run(self, res_list, state):
if state.gstep % self.per_step == 0 and \
state.step > self.skip_step:
self.saver.save(state)
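# A minimal sketch of combining hooks for a training run (the `loss` variable and
# `saver` instance are assumed to exist; the actual wiring lives in trainer.py):
#
#     hooks = [
#         LoggingHook(loss, per_step=10, skip_step=100),
#         CheckpointSaverHook(saver, per_step=1000, skip_step=100),
#         StopAtStepHook(stop_global_step=100000, stop_step=None),
#     ]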
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import numpy as np
import itertools
import logging
import paddle.fluid as F
import paddle.fluid.layers as L
import sklearn.metrics
log = logging.getLogger(__name__)
__all__ = [
'Metrics', 'F1', 'Recall', 'Precision', 'Mrr', 'Mean', 'Acc', 'ChunkF1',
'RecallAtPrecision'
]
class Metrics(object):
def __init__(self):
self.saver = []
@property
def tensor(self):
pass
def update(self, *args):
pass
def eval(self):
pass
class Mean(Metrics):
def __init__(self, t):
self.t = t
self.reset()
def reset(self):
self.saver = np.array([])
@property
def tensor(self):
self.t.persistable = True
return self.t,
def update(self, args):
t, = args
t = t.reshape([-1])
self.saver = np.concatenate([self.saver, t])
def eval(self):
return self.saver.mean()
class Ppl(Mean):
def eval(self):
return np.exp(self.saver.mean())
class Acc(Mean):
def __init__(self, label, pred):
self.eq = L.equal(pred, label)
self.reset()
@property
def tensor(self):
self.eq.persistable = True
return self.eq,
class MSE(Mean):
def __init__(self, label, pred):
diff = pred - label
self.mse = diff * diff
self.reset()
@property
def tensor(self):
self.mse.persistable = True
return self.mse,
class Cosine(Mean):
def __init__(self, label, pred):
self.cos = L.cos_sim(label, pred)
self.reset()
@property
def tensor(self):
self.cos.persistable = True
return self.cos,
class Precision(Metrics):
def __init__(self, label, pred):
self.label = label
self.pred = pred
self.reset()
def reset(self):
self.label_saver = np.array([], dtype=np.bool)
self.pred_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
self.label.persistable = True
self.pred.persistable = True
return self.label, self.pred
def update(self, args):
label, pred = args
label = label.reshape([-1]).astype(np.bool)
pred = pred.reshape([-1]).astype(np.bool)
if label.shape != pred.shape:
raise ValueError(
                'Metrics precision: input shapes do not match: label:%s pred:%s' %
(label, pred))
self.label_saver = np.concatenate([self.label_saver, label])
self.pred_saver = np.concatenate([self.pred_saver, pred])
    def eval(self):
        # precision = true positives / predicted positives
        tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
        p = self.pred_saver.astype(np.int64).sum()
        return tp / p
class Recall(Precision):
    def eval(self):
        # recall = true positives / actual positives
        tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
        t = self.label_saver.astype(np.int64).sum()
        return tp / t
class F1(Precision):
    def eval(self):
        tp = (self.label_saver & self.pred_saver).astype(np.int64).sum()
        t = self.label_saver.astype(np.int64).sum()
        p = self.pred_saver.astype(np.int64).sum()
        precision = tp / (p + 1.e-6)
        recall = tp / (t + 1.e-6)
        return 2 * precision * recall / (precision + recall + 1.e-6)
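# A usage sketch, assuming `label` and `pred` are 0/1 paddle Variables produced in
# a model_fn; EvalHook consumes exactly this kind of name -> Metrics dict and
# reports one value per key after evaluation:
#
#     metrics = {
#         'precision': Precision(label, pred),
#         'recall': Recall(label, pred),
#         'f1': F1(label, pred),
#     }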
class Auc(Metrics):
def __init__(self, label, pred):
self.pred = pred
self.label = label
self.reset()
def reset(self):
self.pred_saver = np.array([], dtype=np.float32)
self.label_saver = np.array([], dtype=np.bool)
@property
def tensor(self):
self.pred.persistable = True
self.label.persistable = True
return [self.pred, self.label]
def update(self, args):
pred, label = args
pred = pred.reshape([-1]).astype(np.float32)
label = label.reshape([-1]).astype(np.bool)
self.pred_saver = np.concatenate([self.pred_saver, pred])
self.label_saver = np.concatenate([self.label_saver, label])
def eval(self):
fpr, tpr, thresholds = sklearn.metrics.roc_curve(
self.label_saver.astype(np.int64), self.pred_saver)
auc = sklearn.metrics.auc(fpr, tpr)
return auc
class RecallAtPrecision(Auc):
def __init__(self, label, pred, precision=0.9):
super(RecallAtPrecision, self).__init__(label, pred)
self.precision = precision
def eval(self):
self.pred_saver = self.pred_saver.reshape(
[self.label_saver.size, -1])[:, -1]
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
self.label_saver, self.pred_saver)
for p, r in zip(precision, recall):
if p > self.precision:
return r
class PrecisionAtThreshold(Auc):
def __init__(self, label, pred, threshold=0.5):
        super(PrecisionAtThreshold, self).__init__(label, pred)
self.threshold = threshold
def eval(self):
infered = self.pred_saver > self.threshold
correct_num = np.array(infered & self.label_saver).sum()
infer_num = infered.sum()
return correct_num / (infer_num + 1.e-6)
class Mrr(Metrics):
def __init__(self, qid, label, pred):
self.qid = qid
self.label = label
self.pred = pred
self.reset()
def reset(self):
self.qid_saver = np.array([], dtype=np.int64)
self.label_saver = np.array([], dtype=np.int64)
self.pred_saver = np.array([], dtype=np.float32)
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
raise ValueError(
                'Mrr dimension mismatch: qid[%s] label[%s], pred[%s]' %
(qid.shape, label.shape, pred.shape))
self.qid_saver = np.concatenate(
[self.qid_saver, qid.reshape([-1]).astype(np.int64)])
self.label_saver = np.concatenate(
[self.label_saver, label.reshape([-1]).astype(np.int64)])
self.pred_saver = np.concatenate(
[self.pred_saver, pred.reshape([-1]).astype(np.float32)])
def eval(self):
def key_func(tup):
return tup[0]
def calc_func(tup):
ranks = [
1. / (rank + 1.)
for rank, (_, l, p) in enumerate(
sorted(
tup, key=lambda t: t[2], reverse=True)) if l != 0
]
            ranks = ranks[0] if ranks else 0.
return ranks
mrr_for_qid = [
calc_func(tup)
for _, tup in itertools.groupby(
sorted(
zip(self.qid_saver, self.label_saver, self.pred_saver),
key=key_func),
key=key_func)
]
mrr = np.float32(sum(mrr_for_qid) / len(mrr_for_qid))
return mrr
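# A worked example of the MRR computed above: for a query with (label, score) pairs
# (0, 0.9), (1, 0.6), (0, 0.3), sorting by score descending puts the positive item
# at rank 2, so its reciprocal rank is 1/2; the returned value is the mean of these
# per-query reciprocal ranks.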
class ChunkF1(Metrics):
def __init__(self, label, pred, seqlen, num_label):
self.label = label
self.pred = pred
self.seqlen = seqlen
self.null_index = num_label - 1
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
def _extract_bio_chunk(self, seq):
chunks = []
cur_chunk = None
for index in range(len(seq)):
tag = seq[index]
tag_type = tag // 2
tag_pos = tag % 2
if tag == self.null_index:
if cur_chunk is not None:
chunks.append(cur_chunk)
cur_chunk = None
continue
if tag_pos == 0:
if cur_chunk is not None:
chunks.append(cur_chunk)
                cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
else:
if cur_chunk is None:
cur_chunk = {
"st": index,
"en": index + 1,
"type": tag_type
}
continue
if cur_chunk["type"] == tag_type:
cur_chunk["en"] = index + 1
else:
chunks.append(cur_chunk)
cur_chunk = {
"st": index,
"en": index + 1,
"type": tag_type
}
if cur_chunk is not None:
chunks.append(cur_chunk)
return chunks
def reset(self):
self.label_cnt = 0
self.pred_cnt = 0
self.correct_cnt = 0
@property
def tensor(self):
self.pred.persistable = True
self.label.persistable = True
self.seqlen.persistable = True
return [self.pred, self.label, self.seqlen]
def update(self, args):
pred, label, seqlen = args
pred = pred.reshape([-1]).astype(np.int32).tolist()
label = label.reshape([-1]).astype(np.int32).tolist()
seqlen = seqlen.reshape([-1]).astype(np.int32).tolist()
max_len = 0
for l in seqlen:
max_len = max(max_len, l)
for i in range(len(seqlen)):
seq_st = i * max_len + 1
seq_en = seq_st + (seqlen[i] - 2)
pred_chunks = self._extract_bio_chunk(pred[seq_st:seq_en])
label_chunks = self._extract_bio_chunk(label[seq_st:seq_en])
self.pred_cnt += len(pred_chunks)
self.label_cnt += len(label_chunks)
pred_index = 0
label_index = 0
while label_index < len(label_chunks) and pred_index < len(
pred_chunks):
if pred_chunks[pred_index]['st'] < label_chunks[label_index][
'st']:
pred_index += 1
elif pred_chunks[pred_index]['st'] > label_chunks[label_index][
'st']:
label_index += 1
else:
if pred_chunks[pred_index]['en'] == label_chunks[label_index]['en'] \
and pred_chunks[pred_index]['type'] == label_chunks[label_index]['type']:
self.correct_cnt += 1
pred_index += 1
label_index += 1
def eval(self):
if self.pred_cnt == 0:
precision = 0.0
else:
precision = 1.0 * self.correct_cnt / self.pred_cnt
if self.label_cnt == 0:
recall = 0.0
else:
recall = 1.0 * self.correct_cnt / self.label_cnt
if self.correct_cnt == 0:
f1 = 0.0
else:
f1 = 2 * precision * recall / (precision + recall)
return np.float32(f1)
class PNRatio(Metrics):
def __init__(self, qid, label, pred):
self.qid = qid
self.label = label
self.pred = pred
self.saver = {}
def reset(self):
self.saver = {}
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
                             % (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
p = 0
n = 0
for qid, outputs in self.saver.items():
for i in range(0, len(outputs)):
l1, p1 = outputs[i]
for j in range(i + 1, len(outputs)):
l2, p2 = outputs[j]
if l1 > l2:
if p1 > p2:
p += 1
elif p1 < p2:
n += 1
elif l1 < l2:
if p1 < p2:
p += 1
elif p1 > p2:
n += 1
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class BinaryPNRatio(PNRatio):
def __init__(self, qid, label, pred):
super(BinaryPNRatio, self).__init__(qid, label, pred)
def eval(self):
p = 0
n = 0
for qid, outputs in self.saver.items():
pos_set = []
neg_set = []
for label, score in outputs:
if label == 1:
pos_set.append(score)
else:
neg_set.append(score)
for ps in pos_set:
for ns in neg_set:
if ps > ns:
p += 1
elif ps < ns:
n += 1
else:
continue
pn = p / n if n > 0 else 0.0
return np.float32(pn)
class PrecisionAtK(Metrics):
def __init__(self, qid, label, pred, k=1):
self.qid = qid
self.label = label
self.pred = pred
self.k = k
self.saver = {}
def reset(self):
self.saver = {}
@property
def tensor(self):
self.qid.persistable = True
self.label.persistable = True
self.pred.persistable = True
return [self.qid, self.label, self.pred]
def update(self, args):
qid, label, pred = args
if not (qid.shape[0] == label.shape[0] == pred.shape[0]):
            raise ValueError('dimension mismatch: qid[%s] label[%s], pred[%s]'
                             % (qid.shape, label.shape, pred.shape))
qid = qid.reshape([-1]).tolist()
label = label.reshape([-1]).tolist()
pred = pred.reshape([-1]).tolist()
assert len(qid) == len(label) == len(pred)
for q, l, p in zip(qid, label, pred):
if q not in self.saver:
self.saver[q] = []
self.saver[q].append((l, p))
def eval(self):
right = 0
total = 0
for v in self.saver.values():
v = sorted(v, key=lambda x: x[1], reverse=True)
k = min(self.k, len(v))
for i in range(k):
if v[i][0] == 1:
right += 1
break
total += 1
return np.float32(1.0 * right / total)
#class SemanticRecallMetrics(Metrics):
# def __init__(self, qid, vec, type_id):
# self.qid = qid
# self.vec = vec
# self.type_id = type_id
# self.reset()
#
# def reset(self):
# self.saver = []
#
# @property
# def tensor(self):
# return [self.qid, self.vec, self.type_id]
#
# def update(self, args):
# qid, vec, type_id = args
# self.saver.append((qid, vec, type_id))
#
# def eval(self):
# dic = {}
# for qid, vec, type_id in self.saver():
# dic.setdefault(i, {}).setdefault(k, []).append(vec)
#
# for qid in dic:
# assert len(dic[qid]) == 3
# qvec = np.arrray(dic[qid][0])
# assert len(qvec) == 1
# ptvec = np.array(dic[qid][1])
# ntvec = np.array(dic[qid][2])
#
# np.matmul(qvec, np.transpose(ptvec))
# np.matmul(qvec, np.transpose(ntvec))
#
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import asyncio
import threading
import math
import zmq
import zmq.asyncio
import numpy as np
from propeller import log
import propeller.service.utils as serv_utils
class InferenceBaseClient(object):
def __init__(self, address):
self.context = zmq.Context()
self.address = address
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(address)
log.info("Connecting to server... %s" % address)
def __call__(self, *args):
for arg in args:
if not isinstance(arg, np.ndarray):
raise ValueError('expect ndarray slot data, got %s' %
repr(arg))
request = serv_utils.nparray_list_serialize(args)
self.socket.send(request)
reply = self.socket.recv()
ret = serv_utils.nparray_list_deserialize(reply)
return ret
class InferenceClient(InferenceBaseClient):
def __init__(self, address, batch_size=128, num_coroutine=10, timeout=10.):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
context = zmq.asyncio.Context()
self.socket_pool = [
context.socket(zmq.REQ) for _ in range(num_coroutine)
]
log.info("Connecting to server... %s" % address)
for socket in self.socket_pool:
socket.connect(address)
self.num_coroutine = num_coroutine
self.batch_size = batch_size
self.timeout = int(timeout * 1000)
#yapf: disable
def __call__(self, *args):
for arg in args:
if not isinstance(arg, np.ndarray):
raise ValueError('expect ndarray slot data, got %s' %
repr(arg))
num_tasks = math.ceil(1. * args[0].shape[0] / self.batch_size)
rets = [None] * num_tasks
async def get(coroutine_idx=0, num_coroutine=1):
socket = self.socket_pool[coroutine_idx]
while coroutine_idx < num_tasks:
begin = coroutine_idx * self.batch_size
end = (coroutine_idx + 1) * self.batch_size
arr_list = [arg[begin:end] for arg in args]
request = serv_utils.nparray_list_serialize(arr_list)
try:
await socket.send(request)
await socket.poll(self.timeout, zmq.POLLIN)
reply = await socket.recv(zmq.NOBLOCK)
ret = serv_utils.nparray_list_deserialize(reply)
except Exception as e:
log.exception(e)
ret = None
rets[coroutine_idx] = ret
coroutine_idx += num_coroutine
futures = [
get(i, self.num_coroutine) for i in range(self.num_coroutine)
]
self.loop.run_until_complete(asyncio.wait(futures))
for r in rets:
if r is None:
raise RuntimeError('Client call failed')
return [np.concatenate(col, 0) for col in zip(*rets)]
#yapf: enable
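# A usage sketch (the endpoint and slot layout are hypothetical), assuming a
# propeller inference server is listening on the ZMQ address and the exported
# model takes two int64 slots; the client slices the inputs into batches of
# `batch_size`, fans them out across `num_coroutine` sockets, and concatenates
# the replies along axis 0:
#
#     client = InferenceClient('tcp://127.0.0.1:8888', batch_size=32, num_coroutine=4)
#     logits, = client(src_ids, sent_ids)   # both np.int64 arrays of shape [n, seq_len]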