From 763615ea4722fc6173f63f988783986586439d63 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 9 May 2017 14:21:49 +0800 Subject: [PATCH] finish README.md --- text_classification/README.md | 32 +++++++++++++++++-- .../text_classification_cnn.py | 4 +-- .../text_classification_dnn.py | 2 +- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/text_classification/README.md b/text_classification/README.md index 4cdc457f..e23692fd 100644 --- a/text_classification/README.md +++ b/text_classification/README.md @@ -68,9 +68,9 @@ def fc_net(input_dim, class_dim=2, emb_dim=256): 需要注意的是,该模型的输入数据为整数序列,而不是原始的英文单词序列。事实上,为了处理方便我们一般会事先将单词根据词频顺序进行id化,即将单词用整数替代。这一步一般在DNN模型之外完成。 -##CNN模型 +## CNN模型 -####CNN的模型结构如下图所示: +#### CNN的模型结构如下图所示:


@@ -117,3 +117,31 @@ def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128): ``` 该CNN网络的输入数据类型和前面介绍过的DNN一致。`paddle.networks.sequence_conv_pool`为Paddle中已经封装好的带有pooling的文本序列卷积模块,该模块的`context_len`参数用于指定卷积核在同一时间覆盖的文本长度,也即图2中的卷积核的高度;`hidden_size`用于指定该类型的卷积核的数量。可以看到,上述代码定义的结构中使用了128个大小为3的卷积核和128个大小为4的卷积核,这些卷积的结果经过max pooling和结果并置后产生一个256维的向量,向量经过一个全连接层输出最终预测结果。 + +## 运行与输出 + +本部分以上文介绍的DNN网络为例,介绍如何利用样例中的`text_classification_dnn.py`脚本进行DNN网络的训练和对新样本的预测。 + +`text_classification_dnn.py`中的代码分为四部分: + +- **fc_net函数**:定义dnn网络结构,上文已经有说明。 + +- **train\_dnn\_model函数**:模型训练函数。定义优化方式、训练输出等内容,并组织训练流程。该函数运行完成前会将训练得到的参数保保存至硬盘上的`dnn_params.tar.gz`文件中。本函数接受一个整数类型的参数,表示训练pass的轮数。 + +- **dnn_infer函数**:载入已有模型并对新样本进行预测。函数开始运行后会从当前路径下寻找并读取`dnn_params.tar.gz`文件,加载其中的模型,并对test数据集中的前100条样本进行预测。 + +- **main函数**:主函数 + +要运行本样例,直接在`text_classification_dnn.py`所在路径下执行`python ./text_classification_dnn.py`即可,样例会自动依次执行数据读取、模型训练和保存、模型读取、新样本预测等步骤。 + +预测的输出形式为: + +``` +[ 0.99892634 0.00107362] 0 +[ 0.00107638 0.9989236 ] 1 +[ 0.98185927 0.01814074] 0 +[ 0.31667888 0.68332112] 1 +[ 0.98853314 0.01146684] 0 +``` + +每一行表示一条样本的预测结果。前两列表示该样本属于正负这两个类别的预测概率,最后一列表示样本的实际label。 diff --git a/text_classification/text_classification_cnn.py b/text_classification/text_classification_cnn.py index ba46c8c6..7dea7774 100644 --- a/text_classification/text_classification_cnn.py +++ b/text_classification/text_classification_cnn.py @@ -126,6 +126,6 @@ def cnn_infer(): if __name__ == "__main__": - paddle.init(use_gpu=False, trainer_count=10) - train_cnn_model(num_pass=10) + paddle.init(use_gpu=False, trainer_count=4) + train_cnn_model(num_pass=5) cnn_infer() diff --git a/text_classification/text_classification_dnn.py b/text_classification/text_classification_dnn.py index cb390041..96d1fed2 100644 --- a/text_classification/text_classification_dnn.py +++ b/text_classification/text_classification_dnn.py @@ -139,5 +139,5 @@ def dnn_infer(): if __name__ == "__main__": paddle.init(use_gpu=False, trainer_count=4) - train_dnn_model(2) + train_dnn_model(num_pass=5) dnn_infer() -- GitLab