From 344a9a43c8e086a1fdf4a229da7ccb8c7d13dedb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 12 May 2017 11:19:50 +0800 Subject: [PATCH] add section of 'self-define data reader' into README.md --- text_classification/README.md | 32 +++++++++++++++++++ .../text_classification_cnn.py | 2 +- .../text_classification_dnn.py | 2 +- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/text_classification/README.md b/text_classification/README.md index 7a3f8c34..0f5d4923 100644 --- a/text_classification/README.md +++ b/text_classification/README.md @@ -126,6 +126,38 @@ def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128): 该CNN网络的输入数据类型和前面介绍过的DNN一致。`paddle.networks.sequence_conv_pool`为Paddle中已经封装好的带有池化的文本序列卷积模块,该模块的`context_len`参数用于指定卷积核在同一时间覆盖的文本长度,也即图2中的卷积核的高度;`hidden_size`用于指定该类型的卷积核的数量。可以看到,上述代码定义的结构中使用了128个大小为3的卷积核和128个大小为4的卷积核,这些卷积的结果经过最大池化和结果并置后产生一个256维的向量,向量经过一个全连接层输出最终预测结果。 +## 自定义数据 +上面的代码样例中使用的都是PaddlePaddle自带的样例数据,如果用户希望使用其他数据进行测试,需要自行编写数据读取接口。 + +编写数据读取接口的关键在于实现一个Python生成器,生成器负责解析数据文件中的每一行内容,并组合成适当的数据形式传送给网络中的data layer。例如在本样例中,data layer需要的数据类型为`paddle.data_type.integer_value_sequence`,这本质上是一个Python list。因此我们的生成器需要完成的主要就是“从文件中读取数据”和“转换成适当形式的Python list”这两件事。 + +假设我们的数据的内容形式为: + +``` +PaddlePaddle is good 1 +What a terrible weather 0 +``` +每一行为一条样本,样本包括了原始语料和标签,语料内部的单词空格分隔,语料和标签之间用`\t`分隔。对于这样的数据我们可以如下编写数据读取接口: + +```python +def encode_word(word, word_dict): + if word_dict.has_key(word): + return word_dict[word] + else: + return word_dict[''] + +def data_reader(file_name, word_dict): + def data_reader(): + with open(file_name, "r") as f: + for line in f: + ins, label = line.strip('\n').split('\t') + ins_data = [int(encode_word(w, word_dict)) for w in ins.split(' ')] + yield ins_data, int(label) + return data_reader +``` + +其中`word_dict`为事先准备好的将单词映射为id的词表。该`data_reader`可以替换代码中原先的`Paddle.dataset.imdb.train`用以数据提供。 + ## 运行与输出 本部分以上文介绍的DNN网络为例,介绍如何利用样例中的`text_classification_dnn.py`脚本进行DNN网络的训练和对新样本的预测。 diff --git a/text_classification/text_classification_cnn.py b/text_classification/text_classification_cnn.py index 43d6a6b0..564720b3 100644 --- a/text_classification/text_classification_cnn.py +++ b/text_classification/text_classification_cnn.py @@ -112,7 +112,7 @@ def cnn_infer(file_name): if __name__ == "__main__": - paddle.init(use_gpu=False, trainer_count=4) + paddle.init(use_gpu=False, trainer_count=1) num_pass = 5 train_cnn_model(num_pass=num_pass) param_file_name = "cnn_params_pass" + str(num_pass - 1) + ".tar.gz" diff --git a/text_classification/text_classification_dnn.py b/text_classification/text_classification_dnn.py index c3f76fa8..4bb6445c 100644 --- a/text_classification/text_classification_dnn.py +++ b/text_classification/text_classification_dnn.py @@ -129,7 +129,7 @@ def dnn_infer(file_name): if __name__ == "__main__": - paddle.init(use_gpu=False, trainer_count=4) + paddle.init(use_gpu=False, trainer_count=1) num_pass = 5 train_dnn_model(num_pass=num_pass) param_file_name = "dnn_params_pass" + str(num_pass - 1) + ".tar.gz" -- GitLab