Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
344a9a43
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
344a9a43
编写于
5月 12, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add section of 'self-define data reader' into README.md
上级
18a4bd97
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
34 addition
and
2 deletion
+34
-2
text_classification/README.md
text_classification/README.md
+32
-0
text_classification/text_classification_cnn.py
text_classification/text_classification_cnn.py
+1
-1
text_classification/text_classification_dnn.py
text_classification/text_classification_dnn.py
+1
-1
未找到文件。
text_classification/README.md
浏览文件 @
344a9a43
...
...
@@ -126,6 +126,38 @@ def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128):
该CNN网络的输入数据类型和前面介绍过的DNN一致。
`paddle.networks.sequence_conv_pool`
为Paddle中已经封装好的带有池化的文本序列卷积模块,该模块的
`context_len`
参数用于指定卷积核在同一时间覆盖的文本长度,也即图2中的卷积核的高度;
`hidden_size`
用于指定该类型的卷积核的数量。可以看到,上述代码定义的结构中使用了128个大小为3的卷积核和128个大小为4的卷积核,这些卷积的结果经过最大池化和结果并置后产生一个256维的向量,向量经过一个全连接层输出最终预测结果。
## 自定义数据
上面的代码样例中使用的都是PaddlePaddle自带的样例数据,如果用户希望使用其他数据进行测试,需要自行编写数据读取接口。
编写数据读取接口的关键在于实现一个Python生成器,生成器负责解析数据文件中的每一行内容,并组合成适当的数据形式传送给网络中的data layer。例如在本样例中,data layer需要的数据类型为
`paddle.data_type.integer_value_sequence`
,这本质上是一个Python list。因此我们的生成器需要完成的主要就是“从文件中读取数据”和“转换成适当形式的Python list”这两件事。
假设我们的数据的内容形式为:
```
PaddlePaddle is good 1
What a terrible weather 0
```
每一行为一条样本,样本包括了原始语料和标签,语料内部的单词空格分隔,语料和标签之间用
`\t`
分隔。对于这样的数据我们可以如下编写数据读取接口:
```
python
def
encode_word
(
word
,
word_dict
):
if
word_dict
.
has_key
(
word
):
return
word_dict
[
word
]
else
:
return
word_dict
[
'<unk>'
]
def
data_reader
(
file_name
,
word_dict
):
def
data_reader
():
with
open
(
file_name
,
"r"
)
as
f
:
for
line
in
f
:
ins
,
label
=
line
.
strip
(
'
\n
'
).
split
(
'
\t
'
)
ins_data
=
[
int
(
encode_word
(
w
,
word_dict
))
for
w
in
ins
.
split
(
' '
)]
yield
ins_data
,
int
(
label
)
return
data_reader
```
其中
`word_dict`
为事先准备好的将单词映射为id的词表。该
`data_reader`
可以替换代码中原先的
`Paddle.dataset.imdb.train`
用以数据提供。
## 运行与输出
本部分以上文介绍的DNN网络为例,介绍如何利用样例中的
`text_classification_dnn.py`
脚本进行DNN网络的训练和对新样本的预测。
...
...
text_classification/text_classification_cnn.py
浏览文件 @
344a9a43
...
...
@@ -112,7 +112,7 @@ def cnn_infer(file_name):
if
__name__
==
"__main__"
:
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
4
)
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
num_pass
=
5
train_cnn_model
(
num_pass
=
num_pass
)
param_file_name
=
"cnn_params_pass"
+
str
(
num_pass
-
1
)
+
".tar.gz"
...
...
text_classification/text_classification_dnn.py
浏览文件 @
344a9a43
...
...
@@ -129,7 +129,7 @@ def dnn_infer(file_name):
if
__name__
==
"__main__"
:
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
4
)
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
num_pass
=
5
train_dnn_model
(
num_pass
=
num_pass
)
param_file_name
=
"dnn_params_pass"
+
str
(
num_pass
-
1
)
+
".tar.gz"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录