You need to sign in or sign up before continuing.
提交 763615ea 编写于 作者: F fengjiayi

finish README.md

上级 947e70d3
...@@ -68,9 +68,9 @@ def fc_net(input_dim, class_dim=2, emb_dim=256): ...@@ -68,9 +68,9 @@ def fc_net(input_dim, class_dim=2, emb_dim=256):
需要注意的是,该模型的输入数据为整数序列,而不是原始的英文单词序列。事实上,为了处理方便我们一般会事先将单词根据词频顺序进行id化,即将单词用整数替代。这一步一般在DNN模型之外完成。 需要注意的是,该模型的输入数据为整数序列,而不是原始的英文单词序列。事实上,为了处理方便我们一般会事先将单词根据词频顺序进行id化,即将单词用整数替代。这一步一般在DNN模型之外完成。
##CNN模型 ## CNN模型
####CNN的模型结构如下图所示: #### CNN的模型结构如下图所示:
<p align="center"> <p align="center">
<img src="images/cnn_net.png" width = "90%" align="center"/><br/> <img src="images/cnn_net.png" width = "90%" align="center"/><br/>
...@@ -117,3 +117,31 @@ def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128): ...@@ -117,3 +117,31 @@ def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128):
``` ```
该CNN网络的输入数据类型和前面介绍过的DNN一致。`paddle.networks.sequence_conv_pool`为Paddle中已经封装好的带有pooling的文本序列卷积模块,该模块的`context_len`参数用于指定卷积核在同一时间覆盖的文本长度,也即图2中的卷积核的高度;`hidden_size`用于指定该类型的卷积核的数量。可以看到,上述代码定义的结构中使用了128个大小为3的卷积核和128个大小为4的卷积核,这些卷积的结果经过max pooling和结果并置后产生一个256维的向量,向量经过一个全连接层输出最终预测结果。 该CNN网络的输入数据类型和前面介绍过的DNN一致。`paddle.networks.sequence_conv_pool`为Paddle中已经封装好的带有pooling的文本序列卷积模块,该模块的`context_len`参数用于指定卷积核在同一时间覆盖的文本长度,也即图2中的卷积核的高度;`hidden_size`用于指定该类型的卷积核的数量。可以看到,上述代码定义的结构中使用了128个大小为3的卷积核和128个大小为4的卷积核,这些卷积的结果经过max pooling和结果并置后产生一个256维的向量,向量经过一个全连接层输出最终预测结果。
## 运行与输出
本部分以上文介绍的DNN网络为例,介绍如何利用样例中的`text_classification_dnn.py`脚本进行DNN网络的训练和对新样本的预测。
`text_classification_dnn.py`中的代码分为四部分:
- **fc_net函数**:定义dnn网络结构,上文已经有说明。
- **train\_dnn\_model函数**:模型训练函数。定义优化方式、训练输出等内容,并组织训练流程。该函数运行完成前会将训练得到的参数保保存至硬盘上的`dnn_params.tar.gz`文件中。本函数接受一个整数类型的参数,表示训练pass的轮数。
- **dnn_infer函数**:载入已有模型并对新样本进行预测。函数开始运行后会从当前路径下寻找并读取`dnn_params.tar.gz`文件,加载其中的模型,并对test数据集中的前100条样本进行预测。
- **main函数**:主函数
要运行本样例,直接在`text_classification_dnn.py`所在路径下执行`python ./text_classification_dnn.py`即可,样例会自动依次执行数据读取、模型训练和保存、模型读取、新样本预测等步骤。
预测的输出形式为:
```
[ 0.99892634 0.00107362] 0
[ 0.00107638 0.9989236 ] 1
[ 0.98185927 0.01814074] 0
[ 0.31667888 0.68332112] 1
[ 0.98853314 0.01146684] 0
```
每一行表示一条样本的预测结果。前两列表示该样本属于正负这两个类别的预测概率,最后一列表示样本的实际label。
...@@ -126,6 +126,6 @@ def cnn_infer(): ...@@ -126,6 +126,6 @@ def cnn_infer():
if __name__ == "__main__": if __name__ == "__main__":
paddle.init(use_gpu=False, trainer_count=10) paddle.init(use_gpu=False, trainer_count=4)
train_cnn_model(num_pass=10) train_cnn_model(num_pass=5)
cnn_infer() cnn_infer()
...@@ -139,5 +139,5 @@ def dnn_infer(): ...@@ -139,5 +139,5 @@ def dnn_infer():
if __name__ == "__main__": if __name__ == "__main__":
paddle.init(use_gpu=False, trainer_count=4) paddle.init(use_gpu=False, trainer_count=4)
train_dnn_model(2) train_dnn_model(num_pass=5)
dnn_infer() dnn_infer()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册