提交 136a60d7 编写于 作者: C caoying03

update readme.

上级 501ce212
# 文本分类
以下是本例目录包含的文件以及对应说明`images` 文件夹以及 `index.html` 与使用无关可不关心):
以下是本例目录包含的文件以及对应说明:
```text
.
├── images
├── images # 文档中的图片
│   ├── cnn_net.png
│   └── dnn_net.png
├── index.html
├── infer.py # 预测任务脚本
├── network_conf.py # 本例中涉及的各种网络结构均定义在此文件中,希望进一步修改模型结构,请修改此文件
├── reader.py # 读取数据接口,若使用自定义格式的数据,可直接修改此文件
├── index.html # 文档
├── infer.py # 预测脚本
├── network_conf.py # 本例中涉及的各种网络结构均定义在此文件中,若进一步修改模型结构,请查看此文件
├── reader.py # 读取数据接口,若使用自定义格式的数据,请查看此文件
├── README.md # 文档
├── run.sh # 运行此脚本,可以以默认参数直接开始训练任务
├── train.py # 训练任务脚本
├── run.sh # 训练任务运行脚本,直接运行此脚本,将以默认参数开始训练任务
├── train.py # 训练脚本
└── utils.py # 定义通用的函数,例如:打印日志、解析命令行参数、构建字典、加载字典等
```
## 简介
文本分类任务根据给定一条文本的内容,判断该文本所属的类别,是自然语言处理领域的一项重要的基础任务。[PaddleBook](https://github.com/PaddlePaddle/book) 中的[情感分类](https://github.com/PaddlePaddle/book/blob/develop/06.understand_sentiment/README.cn.md)一课,正是一个典型的文本分类任务,任务流程如下:
1. 收集电影评论网站的用户评论数据。
......@@ -35,7 +34,7 @@
"No Free Lunch (NFL)" 是机器学习任务基本原则之一:没有任何一种模型是天生优于其他模型的。模型的设计和选择建立在了解不同模型特性的基础之上,但同时也是一个多次实验评估的过程。在本例中,我们继续向大家介绍几种最常用的文本分类模型,它们的能力和复杂程度不同,帮助大家对比学习这些模型学习效果之间的差异,针对不同的场景选择使用。
### DNN 模型与 CNN 模型
## 模型详解
`network_conf.py` 中包括以下模型:
......@@ -53,7 +52,6 @@
2. DNN 刻画的往往是频繁词特征,潜在会受到分词错误的影响,但对一些依赖关键词特征也能做的不错的任务:如 Spam 短信检测,依然是一个有效的模型。
3. 在大多数需要一定语义理解(例如,借助上下文消除语义中的歧义)的文本分类任务上,以 CNN / RNN 为代表的序列模型的效果往往好于 DNN 模型。
## 模型详解
### 1. DNN 模型
**DNN 模型结构入下图所示:**
......@@ -75,7 +73,7 @@
该 DNN 模型默认对输入的语料进行二分类(`class_dim=2`),embedding(词向量)维度默认为28(`emd_dim=28`),两个隐层均使用Tanh激活函数(`act=paddle.activation.Tanh()`)。需要注意的是,该模型的输入数据为整数序列,而不是原始的单词序列。事实上,为了处理方便,我们一般会事先将单词根据词频顺序进行 id 化,即将词语转化成在字典中的序号。
## 2. CNN 模型
### 2. CNN 模型
**CNN 模型结构如下图所示:**
......@@ -96,54 +94,53 @@
CNN 网络的输入数据类型和 DNN 一致。PaddlePaddle 中已经封装好的带有池化的文本序列卷积模块:`paddle.networks.sequence_conv_pool`,可直接调用。该模块的 `context_len` 参数用于指定卷积核在同一时间覆盖的文本长度,即图 2 中的卷积核的高度。`hidden_size` 用于指定该类型的卷积核的数量。本例代码默认使用了 128 个大小为 3 的卷积核和 128 个大小为 4 的卷积核,这些卷积的结果经过最大池化和结果拼接后产生一个 256 维的向量,向量经过一个全连接层输出最终的预测结果。
## 运行
### 使用 PaddlePaddle 内置的情感分类数据
## 使用 PaddlePaddle 内置数据运行
- 运行`sh run.sh` 将以 PaddlePaddle 内置的情感分类数据集:`paddle.dataset.imdb` 运行本例
- 运行 `python infer.py` 脚本加载训练好的模型进行预测。通过修改 `infer.py` 脚本中 `__main__` 函数中以下变量修改使用的模型和指定测试数据。脚本默认对 `paddle.dataset.imdb` 数据集中的测试数据进行测试。
### 如何训练
```python
model_path = "dnn_params_pass_00000.tar.gz" # 指定模型所在的路径
test_dir = None # 指定测试文件所在的目录,请注意,若不指定将默认使用paddle.dataset.imdb
word_dict = None # 指定字典所在的路径,请注意,若不指定将默认使用paddle.dataset.imdb
nn_type = "dnn" # 指定测试使用的模型
```
在终端中执行 `sh run.sh` 以下命令, 将以 PaddlePaddle 内置的情感分类数据集:`paddle.dataset.imdb` 直接运行本例,会看到如下输入:
### 使用自定义数据运行
```text
Pass 0, Batch 0, Cost 0.696031, {'__auc_evaluator_0__': 0.47360000014305115, 'classification_error_evaluator': 0.5}
Pass 0, Batch 100, Cost 0.544438, {'__auc_evaluator_0__': 0.839249312877655, 'classification_error_evaluator': 0.30000001192092896}
Pass 0, Batch 200, Cost 0.406581, {'__auc_evaluator_0__': 0.9030032753944397, 'classification_error_evaluator': 0.2199999988079071}
Test at Pass 0, {'__auc_evaluator_0__': 0.9289745092391968, 'classification_error_evaluator': 0.14927999675273895}
```
日志每隔 100 个 batch 输出一次,输出信息包括:(1)Pass 序号;(2)Batch 序号;(3)依次输出当前 Batch 上评估指标的评估结果。评估指标在配置网络拓扑结构时指定,在上面的输出中,输出了训练样本集之的 AUC 以及错误率指标。
#### step1. 编写自定义的数据读取接口
### 如何预测
例如有如下格式的数据:每一行为一条样本,以 `\t` 分隔,第一列是类别标签,第二列是输入文本的内容。文本内容中的词语以空格分隔。以下是两条示例数据:
训练结束后模型默认存储在当前工作目录下,在终端中执行 `python infer.py` ,预测脚本会加载训练好的模型进行预测。
```
negative PaddlePaddle is good
positive What a terrible weather
- 默认加载使用 `paddle.data.imdb.train` 训练一个 Pass 产出的 DNN 模型对 `paddle.dataset.imdb.test` 进行测试
会看到如下输出:
```text
positive 0.9275 0.0725 previous reviewer <unk> <unk> gave a much better <unk> of the films plot details than i could what i recall mostly is that it was just so beautiful in every sense emotionally visually <unk> just <unk> br if you like movies that are wonderful to look at and also have emotional content to which that beauty is relevant i think you will be glad to have seen this extraordinary and unusual work of <unk> br on a scale of 1 to 10 id give it about an <unk> the only reason i shy away from 9 is that it is a mood piece if you are in the mood for a really artistic very romantic film then its a 10 i definitely think its a mustsee but none of us can be in that mood all the time so overall <unk>
negative 0.0300 0.9700 i love scifi and am willing to put up with a lot scifi <unk> are usually <unk> <unk> and <unk> i tried to like this i really did but it is to good tv scifi as <unk> 5 is to star trek the original silly <unk> cheap cardboard sets stilted dialogues cg that doesnt match the background and painfully onedimensional characters cannot be overcome with a scifi setting im sure there are those of you out there who think <unk> 5 is good scifi tv its not its clichéd and <unk> while us viewers might like emotion and character development scifi is a genre that does not take itself seriously <unk> star trek it may treat important issues yet not as a serious philosophy its really difficult to care about the characters here as they are not simply <unk> just missing a <unk> of life their actions and reactions are wooden and predictable often painful to watch the makers of earth know its rubbish as they have to always say gene <unk> earth otherwise people would not continue watching <unk> <unk> must be turning in their <unk> as this dull cheap poorly edited watching it without <unk> breaks really brings this home <unk> <unk> of a show <unk> into space spoiler so kill off a main character and then bring him back as another actor <unk> <unk> all over again
```
编写自定义的数据读取接口关键在实现一个 Python 生成器完成**从原始输入文本中解析一条训练样本的逻辑**
输出日志每一行是对一条样本预测的结果,以 `\t` 分隔,共 3 列,分别是:(1)预测类别标签;(2)样本分别属于每一类的概率,内部以空格分隔;(3)输入文本
以下代码片段实现了:读取以上格式数据返回类型为: `paddle.data_type.integer_value_sequence`(词语在字典的序号)和 `paddle.data_type.integer_value`(类别标签)的 2 个输入给网络中中定义的 2 个 `data_layer`(见 `fc_net``convolution_net`)。
## 使用自定义数据训练和预测
关于 PaddlePaddle 中 `data_layer` 接受输入数据的类型,以及读取数据接口应该返回数据的格式,请参考 [input-types](http://www.paddlepaddle.org/release_doc/0.9.0/doc_cn/ui/data_provider/pydataprovider2.html#input-types) 一节。
### 如何训练
- `data_dir` 测试数据所在路径
- `word_dict` 词语的字典,用来将原始字符串表示的词语转化为字典中的序号
- `label_dict` 类别标签的字典,用于将字符串的类别标签,转换成整数类型的序号
1. 数据组织
```python
def train_reader(data_dir, word_dict, label_dict):
"""
Reader interface for training data
假设有如下格式的训练数据:每一行为一条样本,以 `\t` 分隔,第一列是类别标签,第二列是输入文本的内容,文本内容中的词语以空格分隔。以下是两条示例数据:
:param data_dir: data directory
:type data_dir: str
:param word_dict: path of word dictionary,
the dictionary must has a "UNK" in it.
:type word_dict: Python dict
:param label_dict: path of label dictionary
:type label_dict: Python dict
"""
```
positive PaddlePaddle is good
negative What a terrible weather
```
2. 编写数据读取接口
自定义数据读取接口只需编写一个 Python 生成器实现**从原始输入文本中解析一条训练样本**的逻辑。以下代码片段实现了读取原始数据返回类型为: `paddle.data_type.integer_value_sequence`(词语在字典的序号)和 `paddle.data_type.integer_value`(类别标签)的 2 个输入给网络中定义的 2 个 `data_layer` 的功能。
```python
def train_reader(data_dir, word_dict, label_dict):
def reader():
UNK_ID = word_dict["<UNK>"]
word_col = 0
......@@ -160,32 +157,42 @@ def train_reader(data_dir, word_dict, label_dict):
yield word_ids, label_dict[line_split[lbl_col]]
return reader
```
```
本例目录下的 `reader.py` 含有读取训练和测试数据的全部代码。
- 关于 PaddlePaddle 中 `data_layer` 接受输入数据的类型,以及数据读取接口对应该返回数据的格式,请参考 [input-types](http://www.paddlepaddle.org/release_doc/0.9.0/doc_cn/ui/data_provider/pydataprovider2.html#input-types) 一节。
- 以上代码片段详见本例目录下的 `reader.py` 脚本,`reader.py` 同时提供了读取测试数据的全部代码。
接下来,只需要将数据读取函数 `train_reader` 作为参数传递给 `train.py` 脚本中的 `paddle.batch` 接口即可使用自定义数据接口读取数据,调用方式如下:
接下来,只需要将数据读取函数 `train_reader` 作为参数传递给 `train.py` 脚本中的 `paddle.batch` 接口即可使用自定义数据接口读取数据,调用方式如下:
```python
train_reader = paddle.batch(
```python
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.train_reader(train_data_dir, word_dict, lbl_dict),
buf_size=1000),
batch_size=batch_size)
```
```
#### step 2. 修改命令行参数
3. 修改命令行参数
执行 `python train.py --help` 可以获取`train.py` 脚本各项启动参数的详细说明。通过修改 `train.py` 脚本的启动参数,指定自定义数据的路径。
- 如果将数据组织成示例数据的同样的格式,只需在 `run.sh` 脚本中修改 `train.py` 启动参数,指定 `train_data_dir` 参数,可以直接运行本例,无需修改数据读取接口 `reader.py`。
- 执行 `python train.py --help` 可以获取`train.py` 脚本各项启动参数的详细说明,主要参数如下:
- `nn_type`:选择要使用的模型,目前支持两种:“dnn” 或者 “cnn”。
- `train_data_dir`:指定训练数据所在的文件夹,使用自定义数据训练,必须指定此参数,否则使用`paddle.dataset.imdb`训练,同时忽略`test_data_dir`,`word_dict`,和 `label_dict` 参数。
- `test_data_dir`:指定测试数据所在的文件夹,若不指定将不进行测试。
- `word_dict`:字典文件所在的路径,若不指定,将从训练数据根据词频统计,自动建立字典。
- `label_dict`:类别标签字典,用于将字符串类型的类别标签,映射为整数类型的序号。
- `batch_size`:指定多少条样本后进行一次神经网络的前向运行及反向更新。
- `num_passes`:指定训练多少个轮次。
主要参数如下:
### 如何预测
- `nn_type`:选择要使用的模型,目前支持两种:“dnn” 或者 “cnn”。
- `train_data_dir`:指定训练数据所在的文件夹,使用自定义数据训练,必须指定此参数,否则使用`paddle.dataset.imdb`训练,同时忽略`test_data_dir``word_dict`,和 `label_dict` 参数。
- `test_data_dir`:指定测试数据所在的文件夹,若不指定将不进行测试。
- `word_dict`:字典文件所在的路径,若不指定,将从训练数据根据词频统计,自动建立字典。
- `label_dict`:类别标签字典,用于将字符串类型的类别标签,映射为整数类型的序号。
- `batch_size`:指定多少条样本后进行一次神经网络的前向运行及反向更新。
- `num_passes`:指定训练多少个轮次。
1. 修改 `infer.py` 中以下变量,指定使用的模型、指定测试数据。
如果将数据组织成上一节示例数据的格式,只需在 `run.sh` 脚本中指定 `train_data_dir` 参数,可以直接运行本例,无需修改数据读取接口 `reader.py`
```python
model_path = "dnn_params_pass_00000.tar.gz" # 指定模型所在的路径
nn_type = "dnn" # 指定测试使用的模型
test_dir = "./data/test" # 指定测试文件所在的目录
word_dict = "./data/dict/word_dict.txt" # 指定字典所在的路径
label_dict = "./data/dict/label_dict.txt" # 指定类别标签字典的路径
```
2. 在终端中执行 `python infer.py`
......@@ -42,25 +42,24 @@
<div id="markdown" style='display:none'>
# 文本分类
以下是本例目录包含的文件以及对应说明(`images` 文件夹以及 `index.html` 与使用无关可不关心):
以下是本例目录包含的文件以及对应说明:
```text
.
├── images
├── images # 文档中的图片
│   ├── cnn_net.png
│   └── dnn_net.png
├── index.html
├── infer.py # 预测任务脚本
├── network_conf.py # 本例中涉及的各种网络结构均定义在此文件中,希望进一步修改模型结构,请修改此文件
├── reader.py # 读取数据接口,若使用自定义格式的数据,可直接修改此文件
├── index.html # 文档
├── infer.py # 预测脚本
├── network_conf.py # 本例中涉及的各种网络结构均定义在此文件中,若进一步修改模型结构,请查看此文件
├── reader.py # 读取数据接口,若使用自定义格式的数据,请查看此文件
├── README.md # 文档
├── run.sh # 运行此脚本,可以以默认参数直接开始训练任务
├── train.py # 训练任务脚本
├── run.sh # 训练任务运行脚本,直接运行此脚本,将以默认参数开始训练任务
├── train.py # 训练脚本
└── utils.py # 定义通用的函数,例如:打印日志、解析命令行参数、构建字典、加载字典等
```
## 简介
文本分类任务根据给定一条文本的内容,判断该文本所属的类别,是自然语言处理领域的一项重要的基础任务。[PaddleBook](https://github.com/PaddlePaddle/book) 中的[情感分类](https://github.com/PaddlePaddle/book/blob/develop/06.understand_sentiment/README.cn.md)一课,正是一个典型的文本分类任务,任务流程如下:
1. 收集电影评论网站的用户评论数据。
......@@ -77,7 +76,7 @@
"No Free Lunch (NFL)" 是机器学习任务基本原则之一:没有任何一种模型是天生优于其他模型的。模型的设计和选择建立在了解不同模型特性的基础之上,但同时也是一个多次实验评估的过程。在本例中,我们继续向大家介绍几种最常用的文本分类模型,它们的能力和复杂程度不同,帮助大家对比学习这些模型学习效果之间的差异,针对不同的场景选择使用。
### DNN 模型与 CNN 模型
## 模型详解
`network_conf.py` 中包括以下模型:
......@@ -95,7 +94,6 @@
2. DNN 刻画的往往是频繁词特征,潜在会受到分词错误的影响,但对一些依赖关键词特征也能做的不错的任务:如 Spam 短信检测,依然是一个有效的模型。
3. 在大多数需要一定语义理解(例如,借助上下文消除语义中的歧义)的文本分类任务上,以 CNN / RNN 为代表的序列模型的效果往往好于 DNN 模型。
## 模型详解
### 1. DNN 模型
**DNN 模型结构入下图所示:**
......@@ -117,7 +115,7 @@
该 DNN 模型默认对输入的语料进行二分类(`class_dim=2`),embedding(词向量)维度默认为28(`emd_dim=28`),两个隐层均使用Tanh激活函数(`act=paddle.activation.Tanh()`)。需要注意的是,该模型的输入数据为整数序列,而不是原始的单词序列。事实上,为了处理方便,我们一般会事先将单词根据词频顺序进行 id 化,即将词语转化成在字典中的序号。
## 2. CNN 模型
### 2. CNN 模型
**CNN 模型结构如下图所示:**
......@@ -138,54 +136,53 @@
CNN 网络的输入数据类型和 DNN 一致。PaddlePaddle 中已经封装好的带有池化的文本序列卷积模块:`paddle.networks.sequence_conv_pool`,可直接调用。该模块的 `context_len` 参数用于指定卷积核在同一时间覆盖的文本长度,即图 2 中的卷积核的高度。`hidden_size` 用于指定该类型的卷积核的数量。本例代码默认使用了 128 个大小为 3 的卷积核和 128 个大小为 4 的卷积核,这些卷积的结果经过最大池化和结果拼接后产生一个 256 维的向量,向量经过一个全连接层输出最终的预测结果。
## 运行
### 使用 PaddlePaddle 内置的情感分类数据
## 使用 PaddlePaddle 内置数据运行
- 运行`sh run.sh` 将以 PaddlePaddle 内置的情感分类数据集:`paddle.dataset.imdb` 运行本例
- 运行 `python infer.py` 脚本加载训练好的模型进行预测。通过修改 `infer.py` 脚本中 `__main__` 函数中以下变量修改使用的模型和指定测试数据。脚本默认对 `paddle.dataset.imdb` 数据集中的测试数据进行测试。
### 如何训练
```python
model_path = "dnn_params_pass_00000.tar.gz" # 指定模型所在的路径
test_dir = None # 指定测试文件所在的目录,请注意,若不指定将默认使用paddle.dataset.imdb
word_dict = None # 指定字典所在的路径,请注意,若不指定将默认使用paddle.dataset.imdb
nn_type = "dnn" # 指定测试使用的模型
```
在终端中执行 `sh run.sh` 以下命令, 将以 PaddlePaddle 内置的情感分类数据集:`paddle.dataset.imdb` 直接运行本例,会看到如下输入:
### 使用自定义数据运行
```text
Pass 0, Batch 0, Cost 0.696031, {'__auc_evaluator_0__': 0.47360000014305115, 'classification_error_evaluator': 0.5}
Pass 0, Batch 100, Cost 0.544438, {'__auc_evaluator_0__': 0.839249312877655, 'classification_error_evaluator': 0.30000001192092896}
Pass 0, Batch 200, Cost 0.406581, {'__auc_evaluator_0__': 0.9030032753944397, 'classification_error_evaluator': 0.2199999988079071}
Test at Pass 0, {'__auc_evaluator_0__': 0.9289745092391968, 'classification_error_evaluator': 0.14927999675273895}
```
日志每隔 100 个 batch 输出一次,输出信息包括:(1)Pass 序号;(2)Batch 序号;(3)依次输出当前 Batch 上评估指标的评估结果。评估指标在配置网络拓扑结构时指定,在上面的输出中,输出了训练样本集之的 AUC 以及错误率指标。
#### step1. 编写自定义的数据读取接口
### 如何预测
例如有如下格式的数据:每一行为一条样本,以 `\t` 分隔,第一列是类别标签,第二列是输入文本的内容。文本内容中的词语以空格分隔。以下是两条示例数据:
训练结束后模型默认存储在当前工作目录下,在终端中执行 `python infer.py` ,预测脚本会加载训练好的模型进行预测。
```
negative PaddlePaddle is good
positive What a terrible weather
- 默认加载使用 `paddle.data.imdb.train` 训练一个 Pass 产出的 DNN 模型对 `paddle.dataset.imdb.test` 进行测试
会看到如下输出:
```text
positive 0.9275 0.0725 previous reviewer <unk> <unk> gave a much better <unk> of the films plot details than i could what i recall mostly is that it was just so beautiful in every sense emotionally visually <unk> just <unk> br if you like movies that are wonderful to look at and also have emotional content to which that beauty is relevant i think you will be glad to have seen this extraordinary and unusual work of <unk> br on a scale of 1 to 10 id give it about an <unk> the only reason i shy away from 9 is that it is a mood piece if you are in the mood for a really artistic very romantic film then its a 10 i definitely think its a mustsee but none of us can be in that mood all the time so overall <unk>
negative 0.0300 0.9700 i love scifi and am willing to put up with a lot scifi <unk> are usually <unk> <unk> and <unk> i tried to like this i really did but it is to good tv scifi as <unk> 5 is to star trek the original silly <unk> cheap cardboard sets stilted dialogues cg that doesnt match the background and painfully onedimensional characters cannot be overcome with a scifi setting im sure there are those of you out there who think <unk> 5 is good scifi tv its not its clichéd and <unk> while us viewers might like emotion and character development scifi is a genre that does not take itself seriously <unk> star trek it may treat important issues yet not as a serious philosophy its really difficult to care about the characters here as they are not simply <unk> just missing a <unk> of life their actions and reactions are wooden and predictable often painful to watch the makers of earth know its rubbish as they have to always say gene <unk> earth otherwise people would not continue watching <unk> <unk> must be turning in their <unk> as this dull cheap poorly edited watching it without <unk> breaks really brings this home <unk> <unk> of a show <unk> into space spoiler so kill off a main character and then bring him back as another actor <unk> <unk> all over again
```
编写自定义的数据读取接口关键在实现一个 Python 生成器完成**从原始输入文本中解析一条训练样本的逻辑**
输出日志每一行是对一条样本预测的结果,以 `\t` 分隔,共 3 列,分别是:(1)预测类别标签;(2)样本分别属于每一类的概率,内部以空格分隔;(3)输入文本
以下代码片段实现了:读取以上格式数据返回类型为: `paddle.data_type.integer_value_sequence`(词语在字典的序号)和 `paddle.data_type.integer_value`(类别标签)的 2 个输入给网络中中定义的 2 个 `data_layer`(见 `fc_net` 或 `convolution_net`)。
## 使用自定义数据训练和预测
关于 PaddlePaddle 中 `data_layer` 接受输入数据的类型,以及读取数据接口应该返回数据的格式,请参考 [input-types](http://www.paddlepaddle.org/release_doc/0.9.0/doc_cn/ui/data_provider/pydataprovider2.html#input-types) 一节。
### 如何训练
- `data_dir` 测试数据所在路径
- `word_dict` 词语的字典,用来将原始字符串表示的词语转化为字典中的序号
- `label_dict` 类别标签的字典,用于将字符串的类别标签,转换成整数类型的序号
1. 数据组织
```python
def train_reader(data_dir, word_dict, label_dict):
"""
Reader interface for training data
假设有如下格式的训练数据:每一行为一条样本,以 `\t` 分隔,第一列是类别标签,第二列是输入文本的内容,文本内容中的词语以空格分隔。以下是两条示例数据:
:param data_dir: data directory
:type data_dir: str
:param word_dict: path of word dictionary,
the dictionary must has a "UNK" in it.
:type word_dict: Python dict
:param label_dict: path of label dictionary
:type label_dict: Python dict
"""
```
positive PaddlePaddle is good
negative What a terrible weather
```
2. 编写数据读取接口
自定义数据读取接口只需编写一个 Python 生成器实现**从原始输入文本中解析一条训练样本**的逻辑。以下代码片段实现了读取原始数据返回类型为: `paddle.data_type.integer_value_sequence`(词语在字典的序号)和 `paddle.data_type.integer_value`(类别标签)的 2 个输入给网络中定义的 2 个 `data_layer` 的功能。
```python
def train_reader(data_dir, word_dict, label_dict):
def reader():
UNK_ID = word_dict["<UNK>"]
word_col = 0
......@@ -202,35 +199,45 @@ def train_reader(data_dir, word_dict, label_dict):
yield word_ids, label_dict[line_split[lbl_col]]
return reader
```
```
本例目录下的 `reader.py` 含有读取训练和测试数据的全部代码。
- 关于 PaddlePaddle 中 `data_layer` 接受输入数据的类型,以及数据读取接口对应该返回数据的格式,请参考 [input-types](http://www.paddlepaddle.org/release_doc/0.9.0/doc_cn/ui/data_provider/pydataprovider2.html#input-types) 一节。
- 以上代码片段详见本例目录下的 `reader.py` 脚本,`reader.py` 同时提供了读取测试数据的全部代码。
接下来,只需要将数据读取函数 `train_reader` 作为参数传递给 `train.py` 脚本中的 `paddle.batch` 接口即可使用自定义数据接口读取数据,调用方式如下:
接下来,只需要将数据读取函数 `train_reader` 作为参数传递给 `train.py` 脚本中的 `paddle.batch` 接口即可使用自定义数据接口读取数据,调用方式如下:
```python
train_reader = paddle.batch(
```python
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.train_reader(train_data_dir, word_dict, lbl_dict),
buf_size=1000),
batch_size=batch_size)
```
```
#### step 2. 修改命令行参数
3. 修改命令行参数
执行 `python train.py --help` 可以获取`train.py` 脚本各项启动参数的详细说明。通过修改 `train.py` 脚本的启动参数,指定自定义数据的路径。
- 如果将数据组织成示例数据的同样的格式,只需在 `run.sh` 脚本中修改 `train.py` 启动参数,指定 `train_data_dir` 参数,可以直接运行本例,无需修改数据读取接口 `reader.py`。
- 执行 `python train.py --help` 可以获取`train.py` 脚本各项启动参数的详细说明,主要参数如下:
- `nn_type`:选择要使用的模型,目前支持两种:“dnn” 或者 “cnn”。
- `train_data_dir`:指定训练数据所在的文件夹,使用自定义数据训练,必须指定此参数,否则使用`paddle.dataset.imdb`训练,同时忽略`test_data_dir`,`word_dict`,和 `label_dict` 参数。
- `test_data_dir`:指定测试数据所在的文件夹,若不指定将不进行测试。
- `word_dict`:字典文件所在的路径,若不指定,将从训练数据根据词频统计,自动建立字典。
- `label_dict`:类别标签字典,用于将字符串类型的类别标签,映射为整数类型的序号。
- `batch_size`:指定多少条样本后进行一次神经网络的前向运行及反向更新。
- `num_passes`:指定训练多少个轮次。
主要参数如下:
### 如何预测
- `nn_type`:选择要使用的模型,目前支持两种:“dnn” 或者 “cnn”。
- `train_data_dir`:指定训练数据所在的文件夹,使用自定义数据训练,必须指定此参数,否则使用`paddle.dataset.imdb`训练,同时忽略`test_data_dir`,`word_dict`,和 `label_dict` 参数。
- `test_data_dir`:指定测试数据所在的文件夹,若不指定将不进行测试。
- `word_dict`:字典文件所在的路径,若不指定,将从训练数据根据词频统计,自动建立字典。
- `label_dict`:类别标签字典,用于将字符串类型的类别标签,映射为整数类型的序号。
- `batch_size`:指定多少条样本后进行一次神经网络的前向运行及反向更新。
- `num_passes`:指定训练多少个轮次。
1. 修改 `infer.py` 中以下变量,指定使用的模型、指定测试数据。
如果将数据组织成上一节示例数据的格式,只需在 `run.sh` 脚本中指定 `train_data_dir` 参数,可以直接运行本例,无需修改数据读取接口 `reader.py`。
```python
model_path = "dnn_params_pass_00000.tar.gz" # 指定模型所在的路径
nn_type = "dnn" # 指定测试使用的模型
test_dir = "./data/test" # 指定测试文件所在的目录
word_dict = "./data/dict/word_dict.txt" # 指定字典所在的路径
label_dict = "./data/dict/label_dict.txt" # 指定类别标签字典的路径
```
2. 在终端中执行 `python infer.py`。
</div>
<!-- You can change the lines below now. -->
......
......@@ -13,20 +13,22 @@ from utils import *
def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
batch_size):
def _infer_a_batch(inferer, test_batch):
def _infer_a_batch(inferer, test_batch, ids_2_word, ids_2_label):
probs = inferer.infer(input=test_batch, field=['value'])
assert len(probs) == len(test_batch)
for prob in probs:
lab = prob.argmax()
print("%d\t%s\t%s" %
(lab, label_reverse_dict[lab],
"\t".join(["{:0.4f}".format(p) for p in prob])))
for word_ids, prob in zip(test_batch, probs):
word_text = " ".join([ids_2_word[id] for id in word_ids[0]])
print("%s\t%s\t%s" % (ids_2_label[prob.argmax()],
" ".join(["{:0.4f}".format(p)
for p in prob]), word_text))
logger.info('begin to predict...')
use_default_data = (data_dir is None)
if use_default_data:
word_dict = paddle.dataset.imdb.word_dict()
word_reverse_dict = dict((value, key)
for key, value in word_dict.iteritems())
label_reverse_dict = {0: "positive", 1: "negative"}
test_reader = paddle.dataset.imdb.test(word_dict)
else:
......@@ -34,7 +36,9 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
word_dict_path), 'the word dictionary file does not exist'
assert os.path.exists(
label_dict_path), 'the label dictionary file does not exist'
word_dict = load_dict(word_dict_path)
word_reverse_dict = load_reverse_dict(word_dict_path)
label_reverse_dict = load_reverse_dict(label_dict_path)
test_reader = reader.test_reader(data_dir, word_dict)()
......@@ -56,10 +60,13 @@ def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
for idx, item in enumerate(test_reader):
test_batch.append([item[0]])
if len(test_batch) == batch_size:
_infer_a_batch(inferer, test_batch)
_infer_a_batch(inferer, test_batch, word_reverse_dict,
label_reverse_dict)
test_batch = []
_infer_a_batch(inferer, test_batch)
if len(test_batch):
_infer_a_batch(inferer, test_batch, word_reverse_dict,
label_reverse_dict)
test_batch = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册