Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
book
提交
cb0d224e
B
book
项目概览
PaddlePaddle
/
book
通知
17
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
40
列表
看板
标记
里程碑
合并请求
37
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
B
book
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
40
Issue
40
列表
看板
标记
里程碑
合并请求
37
合并请求
37
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cb0d224e
编写于
3月 06, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
V2 API train for understand_sentiment
上级
b2a2fbbc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
323 addition
and
310 deletion
+323
-310
understand_sentiment/README.md
understand_sentiment/README.md
+160
-310
understand_sentiment/train.py
understand_sentiment/train.py
+163
-0
未找到文件。
understand_sentiment/README.md
浏览文件 @
cb0d224e
...
...
@@ -93,264 +93,175 @@ $$ h_t=Recrurent(x_t,h_{t-1})$$
<img
src=
"image/stacked_lstm.jpg"
width=
450
><br/>
图4. 栈式双向LSTM用于文本分类
</p>
## 数据准备
### 数据介绍与下载
我们以
[
IMDB情感分析数据集
](
http://ai.stanford.edu/%7Eamaas/data/sentiment/
)
为例进行介绍。IMDB数据集的训练集和测试集分别包含25000个已标注过的电影评论。其中,负面评论的得分小于等于4,正面评论的得分大于等于7,满分10分。您可以使用下面的脚本下载 IMDB 数椐集和
[
Moses
](
http://www.statmt.org/moses/
)
工具:
```
bash
./data/get_imdb.sh
## 示例程序
### 数据集介绍
我们以
[
IMDB情感分析数据集
](
http://ai.stanford.edu/%7Eamaas/data/sentiment/
)
为例进行介绍。IMDB数据集的训练集和测试集分别包含25000个已标注过的电影评论。其中,负面评论的得分小于等于4,正面评论的得分大于等于7,满分10分。
```
text
aclImdb
|- test
|-- neg
|-- pos
|- train
|-- neg
|-- pos
```
如果数椐获取成功,您将在目录
```data```
中看到下面的文件:
Paddle在
`dataset/imdb.py`
中提实现了imdb数据集的自动下载和读取,并提供了读取字典、训练数据、测试数据等API。
```
aclImdb get_imdb.sh imdb mosesdecoder-master
import sys
import paddle.trainer_config_helpers.attrs as attrs
from paddle.trainer_config_helpers.poolings import MaxPooling
import paddle.v2 as paddle
```
*
aclImdb: 从外部网站上下载的原始数椐集。
*
imdb: 仅包含训练和测试数椐集。
*
mosesdecoder-master: Moses 工具。
### 数据预处理
我们使用的预处理脚本为
`preprocess.py`
。该脚本会调用Moses工具中的
`tokenizer.perl`
脚本来切分单词和标点符号,并会将训练集随机打乱排序再构建字典。注意:我们只使用已标注的训练集和测试集。执行下面的命令就可以预处理数椐:
```
data_dir="./data/imdb"
python preprocess.py -i $data_dir
```
运行成功后目录
`./data/pre-imdb`
结构如下:
```
dict.txt labels.list test.list test_part_000 train.list train_part_000
```
*
test
\_
part
\_
000 和 train
\_
part
\_
000: 所有标记的测试集和训练集,训练集已经随机打乱。
*
train.list 和 test.list: 训练集和测试集文件列表。
*
dict.txt: 利用训练集生成的字典。
*
labels.list: 类别标签列表,标签0表示负面评论,标签1表示正面评论。
### 提供数据给PaddlePaddle
PaddlePaddle可以读取Python写的传输数据脚本,下面
`dataprovider.py`
文件给出了完整例子,主要包括两部分:
*
hook: 定义文本信息、类别Id的数据类型。文本被定义为整数序列
`integer_value_sequence`
,类别被定义为整数
`integer_value`
。
*
process: 按行读取以
`'\t\t'`
分隔的类别ID和文本信息,并用yield关键字返回。
```
python
from
paddle.trainer.PyDataProvider2
import
*
def
hook
(
settings
,
dictionary
,
**
kwargs
):
settings
.
word_dict
=
dictionary
settings
.
input_types
=
{
'word'
:
integer_value_sequence
(
len
(
settings
.
word_dict
)),
'label'
:
integer_value
(
2
)
}
settings
.
logger
.
info
(
'dict len : %d'
%
(
len
(
settings
.
word_dict
)))
@
provider
(
init_hook
=
hook
)
def
process
(
settings
,
file_name
):
with
open
(
file_name
,
'r'
)
as
fdata
:
for
line_count
,
line
in
enumerate
(
fdata
):
label
,
comment
=
line
.
strip
().
split
(
'
\t\t
'
)
label
=
int
(
label
)
words
=
comment
.
split
()
word_slot
=
[
settings
.
word_dict
[
w
]
for
w
in
words
if
w
in
settings
.
word_dict
]
yield
{
'word'
:
word_slot
,
'label'
:
label
}
```
## 模型配置说明
`trainer_config.py`
是一个配置文件的例子。
### 数据定义
```
python
from
os.path
import
join
as
join_path
from
paddle.trainer_config_helpers
import
*
# 是否是测试模式
is_test
=
get_config_arg
(
'is_test'
,
bool
,
False
)
# 是否是预测模式
is_predict
=
get_config_arg
(
'is_predict'
,
bool
,
False
)
# 数据路径
data_dir
=
"./data/pre-imdb"
# 文件名
train_list
=
"train.list"
test_list
=
"test.list"
dict_file
=
"dict.txt"
# 字典大小
dict_dim
=
len
(
open
(
join_path
(
data_dir
,
"dict.txt"
)).
readlines
())
# 类别个数
class_dim
=
len
(
open
(
join_path
(
data_dir
,
'labels.list'
)).
readlines
())
if
not
is_predict
:
train_list
=
join_path
(
data_dir
,
train_list
)
test_list
=
join_path
(
data_dir
,
test_list
)
dict_file
=
join_path
(
data_dir
,
dict_file
)
train_list
=
train_list
if
not
is_test
else
None
# 构造字典
word_dict
=
dict
()
with
open
(
dict_file
,
'r'
)
as
f
:
for
i
,
line
in
enumerate
(
open
(
dict_file
,
'r'
)):
word_dict
[
line
.
split
(
'
\t
'
)[
0
]]
=
i
# 通过define_py_data_sources2函数从dataprovider.py中读取数据
define_py_data_sources2
(
train_list
,
test_list
,
module
=
"dataprovider"
,
obj
=
"process"
,
# 指定生成数据的函数。
args
=
{
'dictionary'
:
word_dict
})
# 额外的参数,这里指定词典。
## 配置模型
在该示例中,我们实现了两种文本分类算法,分别基于上文所述的
[
文本卷积神经网络
](
#文本卷积神经网络(CNN)
)
和
[
栈式双向LSTM
](
#栈式双向LSTM(Stacked
Bidirectional LSTM))。
### 文本卷积神经网络
```
### 算法配置
```
python
settings
(
batch_size
=
128
,
learning_rate
=
2e-3
,
learning_method
=
AdamOptimizer
(),
regularization
=
L2Regularization
(
8e-4
),
gradient_clipping_threshold
=
25
)
```
*
设置batch size大小为128。
*
设置全局学习率。
*
使用adam优化。
*
设置L2正则。
*
设置梯度截断(clipping)阈值。
### 模型结构
我们用PaddlePaddle实现了两种文本分类算法,分别基于上文所述的
[
文本卷积神经网络
](
#文本卷积神经网络(CNN)
)
和
[
栈式双向LSTM
](
#栈式双向LSTM(Stacked
Bidirectional LSTM))。
#### 文本卷积神经网络的实现
```
python
def convolution_net(input_dim,
class_dim=2,
emb_dim=128,
hid_dim
=
128
,
is_predict
=
False
):
# 网络输入:id表示的词序列,词典大小为input_dim
data
=
data_layer
(
"word"
,
input_dim
)
# 将id表示的词序列映射为embedding序列
emb
=
embedding_layer
(
input
=
data
,
size
=
emb_dim
)
# 卷积及最大化池操作,卷积核窗口大小为3
conv_3
=
sequence_conv_pool
(
input
=
emb
,
context_len
=
3
,
hidden_size
=
hid_dim
)
# 卷积及最大化池操作,卷积核窗口大小为4
conv_4
=
sequence_conv_pool
(
input
=
emb
,
context_len
=
4
,
hidden_size
=
hid_dim
)
# 将conv_3和conv_4拼接起来输入给softmax分类,类别数为class_dim
output
=
fc_layer
(
input
=
[
conv_3
,
conv_4
],
size
=
class_dim
,
act
=
SoftmaxActivation
())
if
not
is_predict
:
lbl
=
data_layer
(
"label"
,
1
)
#网络输入:类别标签
outputs
(
classification_cost
(
input
=
output
,
label
=
lbl
))
else
:
outputs
(
output
)
hid_dim=128):
data = paddle.layer.data("word",
paddle.data_type.integer_value_sequence(input_dim))
emb = paddle.layer.embedding(input=data, size=emb_dim)
conv_3 = paddle.networks.sequence_conv_pool(
input=emb, context_len=3, hidden_size=hid_dim)
conv_4 = paddle.networks.sequence_conv_pool(
input=emb, context_len=4, hidden_size=hid_dim)
output = paddle.layer.fc(input=[conv_3, conv_4],
size=class_dim,
act=paddle.activation.Softmax())
lbl = paddle.layer.data("label", paddle.data_type.integer_value(2))
cost = paddle.layer.classification_cost(input=output, label=lbl)
return cost
```
网络的输入
`input_dim`
表示的是词典的大小,
`class_dim`
表示类别数。这里,我们使用
[
`sequence_conv_pool`
](
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py
)
API实现了卷积和池化操作。
### 栈式双向LSTM
```
其中,我们仅用一个
[
`sequence_conv_pool`
](
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py
)
方法就实现了卷积和池化操作,卷积核的数量为hidden_size参数。
#### 栈式双向LSTM的实现
```
python
def stacked_lstm_net(input_dim,
class_dim=2,
emb_dim=128,
hid_dim=512,
stacked_num
=
3
,
is_predict
=
False
):
# LSTM的层数stacked_num为奇数,确保最高层LSTM正向
stacked_num=3):
"""
A Wrapper for sentiment classification task.
This network uses bi-directional recurrent network,
consisting three LSTM layers. This configure is referred to
the paper as following url, but use fewer layrs.
http://www.aclweb.org/anthology/P15-1109
input_dim: here is word dictionary dimension.
class_dim: number of categories.
emb_dim: dimension of word embedding.
hid_dim: dimension of hidden layer.
stacked_num: number of stacked lstm-hidden layer.
"""
assert stacked_num % 2 == 1
# 设置神经网络层的属性
layer_attr
=
ExtraLayerAttribute
(
drop_rate
=
0.5
)
# 设置参数的属性
fc_para_attr
=
ParameterAttribute
(
learning_rate
=
1e-3
)
lstm_para_attr
=
ParameterAttribute
(
initial_std
=
0.
,
learning_rate
=
1.
)
para_attr
=
[
fc_para_attr
,
lstm_para_attr
]
bias_attr
=
ParameterAttribute
(
initial_std
=
0.
,
l2_rate
=
0.
)
# 激活函数
relu
=
ReluActivation
()
linear
=
LinearActivation
()
# 网络输入:id表示的词序列,词典大小为input_dim
data
=
data_layer
(
"word"
,
input_dim
)
# 将id表示的词序列映射为embedding序列
emb
=
embedding_layer
(
input
=
data
,
size
=
emb_dim
)
fc1
=
fc_layer
(
input
=
emb
,
size
=
hid_dim
,
act
=
linear
,
bias_attr
=
bias_attr
)
# 基于LSTM的循环神经网络
lstm1
=
lstmemory
(
layer_attr = attrs.ExtraLayerAttribute(drop_rate=0.5)
fc_para_attr = attrs.ParameterAttribute(learning_rate=1e-3)
lstm_para_attr = attrs.ParameterAttribute(initial_std=0., learning_rate=1.)
para_attr = [fc_para_attr, lstm_para_attr]
bias_attr = attrs.ParameterAttribute(initial_std=0., l2_rate=0.)
relu = paddle.activation.Relu()
linear = paddle.activation.Linear()
data = paddle.layer.data("word",
paddle.data_type.integer_value_sequence(input_dim))
emb = paddle.layer.embedding(input=data, size=emb_dim)
fc1 = paddle.layer.fc(input=emb,
size=hid_dim,
act=linear,
bias_attr=bias_attr)
lstm1 = paddle.layer.lstmemory(
input=fc1, act=relu, bias_attr=bias_attr, layer_attr=layer_attr)
# 由fc_layer和lstmemory构建深度为stacked_num的栈式双向LSTM
inputs = [fc1, lstm1]
for i in range(2, stacked_num + 1):
fc
=
fc_layer
(
input
=
inputs
,
size
=
hid_dim
,
act
=
linear
,
param_attr
=
para_attr
,
bias_attr
=
bias_attr
)
lstm
=
lstmemory
(
fc = paddle.layer.fc(input=inputs,
size=hid_dim,
act=linear,
param_attr=para_attr,
bias_attr=bias_attr)
lstm = paddle.layer.lstmemory(
input=fc,
# 奇数层正向,偶数层反向。
reverse=(i % 2) == 0,
act=relu,
bias_attr=bias_attr,
layer_attr=layer_attr)
inputs = [fc, lstm]
# 对最后一层fc_layer使用时间维度上的最大池化得到定长向量
fc_last
=
pooling_layer
(
input
=
inputs
[
0
],
pooling_type
=
MaxPooling
())
# 对最后一层lstmemory使用时间维度上的最大池化得到定长向量
lstm_last
=
pooling_layer
(
input
=
inputs
[
1
],
pooling_type
=
MaxPooling
())
# 将fc_last和lstm_last拼接起来输入给softmax分类,类别数为class_dim
output
=
fc_layer
(
input
=
[
fc_last
,
lstm_last
],
size
=
class_dim
,
act
=
SoftmaxActivation
(),
bias_attr
=
bias_attr
,
param_attr
=
para_attr
)
if
is_predict
:
outputs
(
output
)
else
:
outputs
(
classification_cost
(
input
=
output
,
label
=
data_layer
(
'label'
,
1
)))
```
我们的模型配置
`trainer_config.py`
默认使用
`stacked_lstm_net`
网络,如果要使用
`convolution_net`
,注释相应的行即可。
```
python
stacked_lstm_net
(
dict_dim
,
class_dim
=
class_dim
,
stacked_num
=
3
,
is_predict
=
is_predict
)
# convolution_net(dict_dim, class_dim=class_dim, is_predict=is_predict)
fc_last = paddle.layer.pooling(input=inputs[0], pooling_type=MaxPooling())
lstm_last = paddle.layer.pooling(input=inputs[1], pooling_type=MaxPooling())
output = paddle.layer.fc(input=[fc_last, lstm_last],
size=class_dim,
act=paddle.activation.Softmax(),
bias_attr=bias_attr,
param_attr=para_attr)
lbl = paddle.layer.data("label", paddle.data_type.integer_value(2))
cost = paddle.layer.classification_cost(input=output, label=lbl)
return cost
```
网络的输入
`stacked_num`
表示的是LSTM的层数,需要是奇数,确保最高层LSTM正向。Paddle里面是通过一个fc和一个lstmemory来实现基于LSTM的循环神经网络。
## 训练模型
通过
`paddle.trainer.SGD`
构造一个sgd trainer,并调用
`trainer.train`
训练获得模型。
```
python
paddle
.
init
(
use_gpu
=
True
,
trainer_count
=
4
)
# create trainer
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
adam_optimizer
)
trainer
.
train
(
reader
=
paddle
.
reader
.
batched
(
paddle
.
reader
.
shuffle
(
data_reader
(
train_file
,
dict_file
),
buf_size
=
4096
),
batch_size
=
128
),
event_handler
=
event_handler
,
reader_dict
=
{
'word'
:
0
,
'label'
:
1
},
num_passes
=
10
)
```
可以通过给train函数传递一个
`event_handler`
来获取每个batch或者每个pass结束的状态。比如构造如下一个
`event_handler`
可以在每100个batch结束后输出cost和error;再每个pass结束后调用
`trainer.test`
计算一遍测试集并获得当前模型在测试集上的error。
```
python
if __name__ == '__main__':
# init
paddle.init(use_gpu=False)
```
启动paddle程序,use_gpu=False表示用CPU训练,如果系统支持GPU也可以修改成True使用GPU训练。
### 训练数据
使用Paddle提供的数据集
`dataset.imdb`
中的API来读取训练数据。
```
print 'load dictionary...'
word_dict = paddle.dataset.imdb.word_dict()
dict_dim = len(word_dict)
class_dim = 2
```
加载数据字典,这里通过
`word_dict()`
API可以直接构造字典。
`class_dim`
是指样本类别数,该示例中样本只有正负两类。
```
train_reader = paddle.reader.batched(
paddle.reader.shuffle(
lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=100)
test_reader = paddle.reader.batched(
lambda: paddle.dataset.imdb.test(word_dict),
batch_size=100)
```
这里,
`dataset.imdb.train()`
和
`dataset.imdb.test()`
分别是
`dataset.imdb`
中的训练数据和测试数据API。
`train_reader`
在训练时使用,意义是将读取的训练数据进行shuffle后,组成一个batch数据。同理,
`test_reader`
是在测试的时候使用,将读取的测试数据组成一个batch。
```
reader_dict={'word': 0, 'label': 1}
```
`reader_dict`
用来指定
`train_reader`
和
`test_reader`
返回的数据与模型配置中data_layer的对应关系。这里表示reader返回的第0列数据对应
`word`
层,第1列数据对应
`label`
层。
### 构造模型
```
# Please choose the way to build the network
# by uncommenting the corresponding line.
cost = convolution_net(dict_dim, class_dim=class_dim)
# cost = stacked_lstm_net(dict_dim, class_dim=class_dim, stacked_num=3)
```
该示例中默认使用
`convolution_net`
网络,如果使用
`stacked_lstm_net`
网络,注释相应的行即可。其中cost是网络的优化目标,同时cost包含了整个网络的拓扑信息。
### 网络参数
```
# create parameters
parameters = paddle.parameters.create(cost)
```
根据网络的拓扑构造网络参数。这里parameters是整个网络的参数集。
### 优化算法
```
# create optimizer
adam_optimizer = paddle.optimizer.Adam(
learning_rate=2e-3,
regularization=paddle.optimizer.L2Regularization(rate=8e-4),
model_average=paddle.optimizer.ModelAverage(average_window=0.5))
```
Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。
### 训练
可以通过
`paddle.trainer.SGD`
构造一个sgd trainer,并调用
`trainer.train`
来训练模型。
```
# End batch and end pass event handler
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
...
...
@@ -363,12 +274,31 @@ stacked_lstm_net(
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
data_reader
(
test_file
,
dict_file
),
batch_size
=
128
),
lambda: paddle.dataset.imdb.test(word_dict),
batch_size=128),
reader_dict={'word': 0,
'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
```
程序运行之后的输入如下。
可以通过给train函数传递一个
`event_handler`
来获取每个batch和每个pass结束的状态。比如构造如下一个
`event_handler`
可以在每100个batch结束后输出cost和error;在每个pass结束后调用
`trainer.test`
计算一遍测试集并获得当前模型在测试集上的error。
```
# create trainer
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=adam_optimizer)
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(
lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=100),
event_handler=event_handler,
reader_dict={'word': 0,
'label': 1},
num_passes=10)
```
程序运行之后的输出如下。
```
Pass 0, Batch 0, Cost 0.693721, {'classification_error_evaluator': 0.5546875}
...................................................................................................
...
...
@@ -377,86 +307,6 @@ Pass 0, Batch 100, Cost 0.294321, {'classification_error_evaluator': 0.1015625}
Test with Pass 0, {'classification_error_evaluator': 0.11432000249624252}
```
## 应用模型
### 测试
测试是指使用训练出的模型评估已标记的数据集。
```
./test.sh
```
测试脚本
`test.sh`
的内容如下,其中函数
`get_best_pass`
通过对分类错误率进行排序来获得最佳模型:
```
bash
function
get_best_pass
()
{
cat
$1
|
grep
-Pzo
'Test .*\n.*pass-.*'
|
\
sed
-r
'N;s/Test.* error=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g'
|
\
sort
|
head
-n
1
}
log
=
train.log
LOG
=
`
get_best_pass
$log
`
LOG
=(
${
LOG
}
)
evaluate_pass
=
"model_output/pass-
${
LOG
[1]
}
"
echo
'evaluating from pass '
$evaluate_pass
model_list
=
./model.list
touch
$model_list
|
echo
$evaluate_pass
>
$model_list
net_conf
=
trainer_config.py
paddle train
--config
=
$net_conf
\
--model_list
=
$model_list
\
--job
=
test
\
--use_gpu
=
false
\
--trainer_count
=
4
\
--config_args
=
is_test
=
1
\
2>&1 |
tee
'test.log'
```
与训练不同,测试时需要指定
`--job = test`
和模型路径
`--model_list = $model_list`
。如果测试成功,日志将保存在
`test.log`
中。 在我们的测试中,最好的模型是
`model_output/pass-00002`
,分类错误率是0.115645:
```
Pass=0 samples=24999 AvgCost=0.280471 Eval: classification_error_evaluator=0.115645
```
### 预测
`predict.py`
脚本提供了一个预测接口。预测IMDB中未标记评论的示例如下:
```
./predict.sh
```
predict.sh的内容如下(注意应该确保默认模型路径
`model_output/pass-00002`
存在或更改为其它模型路径):
```
bash
model
=
model_output/pass-00002/
config
=
trainer_config.py
label
=
data/pre-imdb/labels.list
cat
./data/aclImdb/test/pos/10007_10.txt | python predict.py
\
--tconf
=
$config
\
--model
=
$model
\
--label
=
$label
\
--dict
=
./data/pre-imdb/dict.txt
\
--batch_size
=
1
```
*
`cat ./data/aclImdb/test/pos/10007_10.txt`
: 输入预测样本。
*
`predict.py`
: 预测接口脚本。
*
`--tconf=$config`
: 设置网络配置。
*
`--model=$model`
: 设置模型路径。
*
`--label=$label`
: 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
*
`--dict=data/pre-imdb/dict.txt`
: 设置文本数据字典文件。
*
`--batch_size=1`
: 预测时的batch size大小。
本示例的预测结果:
```
Loading parameters from model_output/pass-00002/
predicting label is pos
```
`10007_10.txt`
在路径
`./data/aclImdb/test/pos`
下面,而这里预测的标签也是pos,说明预测正确。
## 总结
本章我们以情感分析为例,介绍了使用深度学习的方法进行端对端的短文本分类,并且使用PaddlePaddle完成了全部相关实验。同时,我们简要介绍了两种文本处理模型:卷积神经网络和循环神经网络。在后续的章节中我们会看到这两种基本的深度学习模型在其它任务上的应用。
## 参考文献
...
...
understand_sentiment/train.py
0 → 100644
浏览文件 @
cb0d224e
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
paddle.trainer_config_helpers.attrs
as
attrs
from
paddle.trainer_config_helpers.poolings
import
MaxPooling
import
paddle.v2
as
paddle
def
convolution_net
(
input_dim
,
class_dim
=
2
,
emb_dim
=
128
,
hid_dim
=
128
):
data
=
paddle
.
layer
.
data
(
"word"
,
paddle
.
data_type
.
integer_value_sequence
(
input_dim
))
emb
=
paddle
.
layer
.
embedding
(
input
=
data
,
size
=
emb_dim
)
conv_3
=
paddle
.
networks
.
sequence_conv_pool
(
input
=
emb
,
context_len
=
3
,
hidden_size
=
hid_dim
)
conv_4
=
paddle
.
networks
.
sequence_conv_pool
(
input
=
emb
,
context_len
=
4
,
hidden_size
=
hid_dim
)
output
=
paddle
.
layer
.
fc
(
input
=
[
conv_3
,
conv_4
],
size
=
class_dim
,
act
=
paddle
.
activation
.
Softmax
())
lbl
=
paddle
.
layer
.
data
(
"label"
,
paddle
.
data_type
.
integer_value
(
2
))
cost
=
paddle
.
layer
.
classification_cost
(
input
=
output
,
label
=
lbl
)
return
cost
def
stacked_lstm_net
(
input_dim
,
class_dim
=
2
,
emb_dim
=
128
,
hid_dim
=
512
,
stacked_num
=
3
):
"""
A Wrapper for sentiment classification task.
This network uses bi-directional recurrent network,
consisting three LSTM layers. This configure is referred to
the paper as following url, but use fewer layrs.
http://www.aclweb.org/anthology/P15-1109
input_dim: here is word dictionary dimension.
class_dim: number of categories.
emb_dim: dimension of word embedding.
hid_dim: dimension of hidden layer.
stacked_num: number of stacked lstm-hidden layer.
"""
assert
stacked_num
%
2
==
1
layer_attr
=
attrs
.
ExtraLayerAttribute
(
drop_rate
=
0.5
)
fc_para_attr
=
attrs
.
ParameterAttribute
(
learning_rate
=
1e-3
)
lstm_para_attr
=
attrs
.
ParameterAttribute
(
initial_std
=
0.
,
learning_rate
=
1.
)
para_attr
=
[
fc_para_attr
,
lstm_para_attr
]
bias_attr
=
attrs
.
ParameterAttribute
(
initial_std
=
0.
,
l2_rate
=
0.
)
relu
=
paddle
.
activation
.
Relu
()
linear
=
paddle
.
activation
.
Linear
()
data
=
paddle
.
layer
.
data
(
"word"
,
paddle
.
data_type
.
integer_value_sequence
(
input_dim
))
emb
=
paddle
.
layer
.
embedding
(
input
=
data
,
size
=
emb_dim
)
fc1
=
paddle
.
layer
.
fc
(
input
=
emb
,
size
=
hid_dim
,
act
=
linear
,
bias_attr
=
bias_attr
)
lstm1
=
paddle
.
layer
.
lstmemory
(
input
=
fc1
,
act
=
relu
,
bias_attr
=
bias_attr
,
layer_attr
=
layer_attr
)
inputs
=
[
fc1
,
lstm1
]
for
i
in
range
(
2
,
stacked_num
+
1
):
fc
=
paddle
.
layer
.
fc
(
input
=
inputs
,
size
=
hid_dim
,
act
=
linear
,
param_attr
=
para_attr
,
bias_attr
=
bias_attr
)
lstm
=
paddle
.
layer
.
lstmemory
(
input
=
fc
,
reverse
=
(
i
%
2
)
==
0
,
act
=
relu
,
bias_attr
=
bias_attr
,
layer_attr
=
layer_attr
)
inputs
=
[
fc
,
lstm
]
fc_last
=
paddle
.
layer
.
pooling
(
input
=
inputs
[
0
],
pooling_type
=
MaxPooling
())
lstm_last
=
paddle
.
layer
.
pooling
(
input
=
inputs
[
1
],
pooling_type
=
MaxPooling
())
output
=
paddle
.
layer
.
fc
(
input
=
[
fc_last
,
lstm_last
],
size
=
class_dim
,
act
=
paddle
.
activation
.
Softmax
(),
bias_attr
=
bias_attr
,
param_attr
=
para_attr
)
lbl
=
paddle
.
layer
.
data
(
"label"
,
paddle
.
data_type
.
integer_value
(
2
))
cost
=
paddle
.
layer
.
classification_cost
(
input
=
output
,
label
=
lbl
)
return
cost
if
__name__
==
'__main__'
:
# init
paddle
.
init
(
use_gpu
=
False
)
#data
print
'load dictionary...'
word_dict
=
paddle
.
dataset
.
imdb
.
word_dict
()
dict_dim
=
len
(
word_dict
)
class_dim
=
2
train_reader
=
paddle
.
reader
.
batched
(
paddle
.
reader
.
shuffle
(
lambda
:
paddle
.
dataset
.
imdb
.
train
(
word_dict
),
buf_size
=
1000
),
batch_size
=
100
)
test_reader
=
paddle
.
reader
.
batched
(
lambda
:
paddle
.
dataset
.
imdb
.
test
(
word_dict
),
batch_size
=
100
)
reader_dict
=
{
'word'
:
0
,
'label'
:
1
}
# network config
# Please choose the way to build the network
# by uncommenting the corresponding line.
cost
=
convolution_net
(
dict_dim
,
class_dim
=
class_dim
)
# cost = stacked_lstm_net(dict_dim, class_dim=class_dim, stacked_num=3)
# create parameters
parameters
=
paddle
.
parameters
.
create
(
cost
)
# create optimizer
adam_optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
2e-3
,
regularization
=
paddle
.
optimizer
.
L2Regularization
(
rate
=
8e-4
),
model_average
=
paddle
.
optimizer
.
ModelAverage
(
average_window
=
0.5
))
# End batch and end pass event handler
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
100
==
0
:
print
"
\n
Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
test_reader
,
reader_dict
=
reader_dict
)
print
"
\n
Test with Pass %d, %s"
%
(
event
.
pass_id
,
result
.
metrics
)
# create trainer
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
adam_optimizer
)
trainer
.
train
(
reader
=
train_reader
,
event_handler
=
event_handler
,
reader_dict
=
reader_dict
,
num_passes
=
2
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录