diff --git a/image_classification/README.md b/image_classification/README.md
index 4966b7e51111e04e6f79dc8afc087c52ae2d5170..47205147df37ccedae4025f37717fbd9fc4164cf 100644
--- a/image_classification/README.md
+++ b/image_classification/README.md
@@ -96,7 +96,7 @@ NIN模型主要有两个特点:1) 引入了多层感知卷积网络(Multi-Laye
Inception模块如下图7所示,图(a)是最简单的设计,输出是3个卷积层和一个池化层的特征拼接,这样设计的缺点是池化层不会改变特征通道数,拼接后会导致特征的通道数较大,经过几层这样的模块堆积会导致通道数越来越大,参数和计算量随之增大。为了改善这个缺点,图(b)引入3个1x1卷积层进行降维,所谓的降维就是减少通道数,同时如NIN模型中提到的1x1卷积也可以修正线性特征。
-
+
图7. Inception模块
diff --git a/understand_sentiment/.gitignore b/understand_sentiment/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..667762d327cb160376a4119fa9df9db41b6443b2
--- /dev/null
+++ b/understand_sentiment/.gitignore
@@ -0,0 +1,10 @@
+data/aclImdb
+data/imdb
+data/pre-imdb
+data/mosesdecoder-master
+*.log
+model_output
+dataprovider_copy_1.py
+model.list
+*.pyc
+.DS_Store
diff --git a/understand_sentiment/README.md b/understand_sentiment/README.md
index e420eb556ab35772b0f11c5dda9a7f2538c0c969..f3527427a042d93b27f00b8bee94fcfd9d5b4345 100644
--- a/understand_sentiment/README.md
+++ b/understand_sentiment/README.md
@@ -1 +1,470 @@
-TODO: Write about https://github.com/PaddlePaddle/Paddle/tree/develop/demo/sentiment
+# 情感分析
+## 背景介绍
+在自然语言处理中,情感分析一般是指判断一段文本所表达的情绪状态。其中,一段文本可以是一个句子,一个段落或一个文档。情绪状态可以是两类,如(正面,负面),(高兴,悲伤);也可以是三类,如(积极,消极,中性)等等。情感分析的应用场景十分广泛,如把用户在购物网站(亚马逊、天猫、淘宝等)、旅游网站、电影评论网站上发表的评论分成正面评论和负面评论;或为了分析用户对于某一产品的整体使用感受,抓取产品的用户评论并进行情感分析等等。表格1展示了对电影评论进行情感分析的例子:
+
+| 电影评论 | 类别 |
+| -------- | ----- |
+| 在冯小刚这几年的电影里,算最好的一部的了| 正面 |
+| 很不好看,好像一个地方台的电视剧 | 负面 |
+| 圆方镜头全程炫技,色调背景美则美矣,但剧情拖沓,口音不伦不类,一直努力却始终无法入戏| 负面|
+|剧情四星。但是圆镜视角加上婺源的风景整个非常有中国写意山水画的感觉,看得实在太舒服了。。|正面|
+
+表格 1 电影评论情感分析
+
+在自然语言处理中,情感分析属于典型的**文本分类**问题,即把需要进行情感分析的文本划分为其所属类别。文本分类涉及文本表示和分类方法两个问题。在深度学习的方法出现之前,主流的文本表示方法为词袋模型BOW(bag of words),话题模型等等;分类方法有SVM(support vector machine), LR(logistic regression)等等。
+
+对于一段文本,BOW表示会忽略其词顺序、语法和句法,将这段文本仅仅看做是一个词集合,因此BOW方法并不能充分表示文本的语义信息。例如,句子“这部电影糟糕透了”和“一个乏味,空洞,没有内涵的作品”在情感分析中具有很高的语义相似度,但是它们的BOW表示的相似度为0。又如,句子“一个空洞,没有内涵的作品”和“一个不空洞而且有内涵的作品”的BOW相似度很高,但实际上它们的意思很不一样。
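+
+下面给出一个极简的示意代码(并非本章模型的一部分,分词结果为假设的示例,仅用于说明上述BOW表示的缺陷):用词集合的Jaccard相似度近似BOW相似度,可以看到语义相近的两句相似度为0,而语义相反的两句相似度反而很高。
+
+```python
+# 假设已经分好词的三个例句(仅为示意)
+s1 = ["这部", "电影", "糟糕", "透了"]
+s2 = ["一个", "乏味", "空洞", "没有", "内涵", "的", "作品"]
+s3 = ["一个", "不", "空洞", "而且", "有", "内涵", "的", "作品"]
+
+def bow_similarity(a, b):
+    # 忽略词序,仅以词集合的重叠程度(Jaccard相似度)作为BOW相似度
+    a, b = set(a), set(b)
+    return float(len(a & b)) / len(a | b)
+
+print(bow_similarity(s1, s2))  # 0.0:语义相近,但没有共同的词
+print(bow_similarity(s2, s3))  # 0.5:语义相反,但大部分词相同
+```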
+
+本章我们所要介绍的深度学习模型克服了BOW表示的上述缺陷,它在考虑词顺序的基础上把文本映射到低维度的语义空间,并且以端对端(end to end)的方式进行文本表示及分类,其性能相对于传统方法有显著的提升\[[1](#参考文献)\]。
+## 模型概览
+本章所使用的文本表示模型为卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)及其扩展。下面依次介绍这几个模型。
+### 文本卷积神经网络(CNN)
+卷积神经网络经常用来处理具有类似网格拓扑结构(grid-like topology)的数据。例如,图像可以视为二维网格的像素点,自然语言可以视为一维的词序列。卷积神经网络可以提取多种局部特征,并对其进行组合抽象得到更高级的特征表示。实验表明,卷积神经网络能高效地对图像及文本问题进行建模处理。
+
+卷积神经网络主要由卷积(convolution)和池化(pooling)操作构成,其应用及组合方式灵活多变,种类繁多。本小节我们以一种简单的文本分类卷积神经网络为例进行讲解\[[1](#参考文献)\],如图1所示:
+
+
+图1. 卷积神经网络文本分类模型
+
+假设待处理句子的长度为$n$,其中第$i$个词的词向量(word embedding)为$x_i\in\mathbb{R}^k$,$k$为维度大小。
+
+首先,进行词向量的拼接操作:将每$h$个词拼接起来形成一个大小为$h$的词窗口,记为$x_{i:i+h-1}$,它表示词序列$x_{i},x_{i+1},\ldots,x_{i+h-1}$的拼接,其中,$i$表示词窗口中第一个词在整个句子中的位置,取值范围从$1$到$n-h+1$,$x_{i:i+h-1}\in\mathbb{R}^{hk}$。
+
+其次,进行卷积操作:把卷积核(kernel)$w\in\mathbb{R}^{hk}$应用于包含$h$个词的窗口$x_{i:i+h-1}$,得到特征$c_i=f(w\cdot x_{i:i+h-1}+b)$,其中$b\in\mathbb{R}$为偏置项(bias),$f$为非线性激活函数,如$sigmoid$。将卷积核应用于句子中所有的词窗口$\{x_{1:h},x_{2:h+1},\ldots,x_{n-h+1:n}\}$,产生一个特征图(feature map):
+
+$$c=[c_1,c_2,\ldots,c_{n-h+1}], c \in \mathbb{R}^{n-h+1}$$
+
+接下来,对特征图采用时间维度上的最大池化(max pooling over time)操作得到此卷积核对应的整句话的特征$\hat c$,它是特征图中所有元素的最大值:
+
+$$\hat c=\max(c)$$
+
+在实际应用中,我们会使用多个卷积核来处理句子,窗口大小相同的卷积核堆叠起来形成一个矩阵(上文中的单个卷积核参数$w$相当于矩阵的某一行),这样可以更高效地完成运算。另外,我们也可使用窗口大小不同的卷积核来处理句子(图1作为示意画了四个卷积核,不同颜色表示不同大小的卷积核操作)。
+
+最后,将所有卷积核得到的特征拼接起来即为文本的定长向量表示,对于文本分类问题,将其连接至softmax即构建出完整的模型。
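+
+为帮助理解上述公式,下面用NumPy给出单个卷积核的卷积与最大池化、以及多个卷积核批量计算的简化示意(句长、词向量维度、卷积核个数等均为假设值,并非PaddlePaddle的实际实现):
+
+```python
+import numpy as np
+
+n, k, h = 7, 5, 3                 # 假设:句长n、词向量维度k、词窗口大小h
+x = np.random.randn(n, k)         # 句子的词向量矩阵,第i行为词向量x_i
+w = np.random.randn(h * k)        # 单个卷积核,对应公式中的 w ∈ R^{hk}
+b = 0.1                           # 偏置项
+sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
+
+# 对每个词窗口 x_{i:i+h-1} 计算 c_i = f(w·x_{i:i+h-1} + b),得到特征图c
+c = np.array([sigmoid(np.dot(w, x[i:i + h].reshape(-1)) + b)
+              for i in range(n - h + 1)])
+c_hat = c.max()                   # 时间维度上的最大池化
+
+# 多个窗口大小相同的卷积核堆叠成矩阵W,一次矩阵乘法即可得到所有特征图
+num_kernels = 4
+W = np.random.randn(num_kernels, h * k)
+windows = np.array([x[i:i + h].reshape(-1) for i in range(n - h + 1)])
+C = sigmoid(np.dot(windows, W.T) + b)       # 形状为 (n-h+1, num_kernels)
+sent_vec = C.max(axis=0)                    # 每个卷积核各取最大值,拼成定长的句子表示
+```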
+
+对于一般的短文本分类问题,上文所述的简单的文本卷积网络即可达到很高的正确率\[[1](#参考文献)\]。若想得到更抽象更高级的文本特征表示,可以构建深层文本卷积神经网络\[[2](#参考文献),[3](#参考文献)\]。
+### 循环神经网络(RNN)
+循环神经网络是一种能对序列数据进行精确建模的有力工具。实际上,循环神经网络的理论计算能力是图灵完备的\[[4](#参考文献)\]。自然语言是一种典型的序列数据(词序列),近年来,循环神经网络及其变体(如long short term memory\[[5](#参考文献)\]等)在自然语言处理的多个领域,如语言模型、句法解析、语义角色标注(或一般的序列标注)、语义表示、图文生成、对话、机器翻译等任务上均表现优异甚至成为目前效果最好的方法。
+
+
+图2. 循环神经网络按时间展开的示意图
+
+循环神经网络按时间展开后如图2所示:在第$t$时刻,网络读入第$t$个输入$x_t$(向量表示)及前一时刻隐层的状态值$h_{t-1}$(向量表示,$h_0$一般初始化为$0$向量),计算得出本时刻隐层的状态值$h_t$,重复这一步骤直至读完所有输入。如果将循环神经网络所表示的函数记为$f$,则其公式可表示为:
+
+$$h_t=f(x_t,h_{t-1})=\sigma(W_{xh}x_t+W_{hh}h_{t-1}+b_h)$$
+
+其中$W_{xh}$是输入到隐层的矩阵参数,$W_{hh}$是隐层到隐层的矩阵参数,$b_h$为隐层的偏置向量(bias)参数,$\sigma$为$sigmoid$函数。
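+
+下面是按上述公式用NumPy写的一个简化的循环神经网络前向计算示意(参数随机初始化,各维度均为假设值,仅用于说明公式含义):
+
+```python
+import numpy as np
+
+k, d, T = 5, 8, 6                         # 假设:输入维度k、隐层维度d、序列长度T
+xs = np.random.randn(T, k)                # 输入序列 x_1, ..., x_T(如词向量)
+W_xh, W_hh, b_h = np.random.randn(d, k), np.random.randn(d, d), np.zeros(d)
+sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
+
+h = np.zeros(d)                           # h_0 初始化为0向量
+for x_t in xs:
+    # h_t = σ(W_xh·x_t + W_hh·h_{t-1} + b_h)
+    h = sigmoid(np.dot(W_xh, x_t) + np.dot(W_hh, h) + b_h)
+# 循环结束后的h即最后时刻的隐层状态,可作为句子表示接入分类层
+```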
+
+在处理自然语言时,一般会先将词(one-hot表示)映射为其词向量(word embedding)表示,然后再作为循环神经网络每一时刻的输入$x_t$。此外,可以根据实际需要的不同在循环神经网络的隐层上连接其它层。如,可以把一个循环神经网络的隐层输出连接至下一个循环神经网络的输入构建深层(deep or stacked)循环神经网络,或者提取最后一个时刻的隐层状态作为句子表示进而使用分类模型等等。
+
+### 长短期记忆网络(LSTM)
+对于较长的序列数据,循环神经网络的训练过程中容易出现梯度消失或爆炸现象\[[6](#参考文献)\]。为了解决这一问题,Hochreiter S, Schmidhuber J. (1997)提出了LSTM(long short term memory\[[5](#参考文献)\])。
+
+相比于简单的循环神经网络,LSTM增加了记忆单元$c$、输入门$i$、遗忘门$f$及输出门$o$。这些门及记忆单元组合起来大大提升了循环神经网络处理长序列数据的能力。若将基于LSTM的循环神经网络表示的函数记为$F$,则其公式为:
+
+$$ h_t=F(x_t,h_{t-1})$$
+
+$F$由下列公式组合而成\[[7](#参考文献)\]:
+\begin{align}
+i_t & = \sigma(W_{xi}x_t+W_{hi}h_{t-1}+W_{ci}c_{t-1}+b_i)\\\\
+f_t & = \sigma(W_{xf}x_t+W_{hf}h_{t-1}+W_{cf}c_{t-1}+b_f)\\\\
+c_t & = f_t\odot c_{t-1}+i_t\odot tanh(W_{xc}x_t+W_{hc}h_{t-1}+b_c)\\\\
+o_t & = \sigma(W_{xo}x_t+W_{ho}h_{t-1}+W_{co}c_{t}+b_o)\\\\
+h_t & = o_t\odot tanh(c_t)\\\\
+\end{align}
+其中,$i_t, f_t, c_t, o_t$分别表示输入门,遗忘门,记忆单元及输出门的向量值,带角标的$W$及$b$为模型参数,$tanh$为双曲正切函数,$\odot$表示逐元素(elementwise)的乘法操作。输入门控制着新输入进入记忆单元$c$的强度,遗忘门控制着记忆单元维持上一时刻值的强度,输出门控制着输出记忆单元的强度。三种门的计算方式类似,但有着完全不同的参数,它们各自以不同的方式控制着记忆单元$c$,如图3所示:
+
+
+图3. 时刻$t$的LSTM\[[7](#参考文献)\]
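+
+结合上面的公式,时刻$t$的LSTM一步计算可以用NumPy简化示意如下(参数随机初始化,维度为假设值,仅用于说明各个门的计算方式,并非实际实现):
+
+```python
+import numpy as np
+
+k, d = 5, 8                                  # 假设:输入维度k、记忆单元/隐层维度d
+rnd = lambda *s: 0.1 * np.random.randn(*s)
+sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
+# 输入门、遗忘门、记忆单元、输出门各自的参数(W_ci、W_cf、W_co按逐元素方式作用于记忆单元)
+W_xi, W_hi, W_ci, b_i = rnd(d, k), rnd(d, d), rnd(d), np.zeros(d)
+W_xf, W_hf, W_cf, b_f = rnd(d, k), rnd(d, d), rnd(d), np.zeros(d)
+W_xc, W_hc, b_c = rnd(d, k), rnd(d, d), np.zeros(d)
+W_xo, W_ho, W_co, b_o = rnd(d, k), rnd(d, d), rnd(d), np.zeros(d)
+
+def lstm_step(x_t, h_prev, c_prev):
+    i_t = sigmoid(W_xi.dot(x_t) + W_hi.dot(h_prev) + W_ci * c_prev + b_i)        # 输入门
+    f_t = sigmoid(W_xf.dot(x_t) + W_hf.dot(h_prev) + W_cf * c_prev + b_f)        # 遗忘门
+    c_t = f_t * c_prev + i_t * np.tanh(W_xc.dot(x_t) + W_hc.dot(h_prev) + b_c)   # 记忆单元
+    o_t = sigmoid(W_xo.dot(x_t) + W_ho.dot(h_prev) + W_co * c_t + b_o)           # 输出门
+    h_t = o_t * np.tanh(c_t)
+    return h_t, c_t
+
+h, c = np.zeros(d), np.zeros(d)
+for x_t in np.random.randn(6, k):            # 对一个长度为6的随机输入序列逐步计算
+    h, c = lstm_step(x_t, h, c)
+```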
+
+LSTM通过给简单的循环神经网络增加记忆及控制门的方式,增强了其处理远距离依赖问题的能力。类似原理的改进还有Gated Recurrent Unit (GRU)\[[8](#参考文献)\],其设计更为简洁一些。**这些改进虽然各有不同,但是它们的宏观描述却与简单的循环神经网络一样(如图2所示),即隐状态依据当前输入及前一时刻的隐状态来改变,不断地循环这一过程直至输入处理完毕:**
+
+$$ h_t=Recurrent(x_t,h_{t-1})$$
+
+其中,$Recurrent$可以表示简单的循环神经网络、GRU或LSTM。
+### 栈式双向LSTM(Stacked Bidirectional LSTM)
+对于正常顺序的循环神经网络,$h_t$包含了$t$时刻之前的输入信息,也就是上文信息。同样,为了得到下文信息,我们可以使用反方向(将输入逆序处理)的循环神经网络。结合构建深层循环神经网络的方法(深层神经网络往往能得到更抽象和高级的特征表示),我们可以通过构建更加强有力的基于LSTM的栈式双向循环神经网络\[[9](#参考文献)\],来对时序数据进行建模。
+
+如图4所示(以三层为例),奇数层LSTM正向,偶数层LSTM反向,高一层的LSTM使用低一层LSTM及之前所有层的信息作为输入,对最高层LSTM序列使用时间维度上的最大池化即可得到文本的定长向量表示(这一表示充分融合了文本的上下文信息,并且对文本进行了深层次抽象),最后我们将文本表示连接至softmax构建分类模型。
+
+
+图4. 栈式双向LSTM用于文本分类
+
+## 数据准备
+### 数据介绍与下载
+我们以[IMDB情感分析数据集](http://ai.stanford.edu/%7Eamaas/data/sentiment/)为例进行介绍。IMDB数据集的训练集和测试集分别包含25000个已标注过的电影评论。其中,负面评论的得分小于等于4,正面评论的得分大于等于7,满分10分。您可以使用下面的脚本下载 IMDB 数据集和[Moses](http://www.statmt.org/moses/)工具:
+
+```bash
+./data/get_imdb.sh
+```
+如果数据获取成功,您将在目录`data`中看到下面的文件:
+
+```
+aclImdb get_imdb.sh imdb mosesdecoder-master
+```
+
+* aclImdb: 从外部网站上下载的原始数据集。
+* imdb: 仅包含训练和测试数据集。
+* mosesdecoder-master: Moses 工具。
+
+### 数据预处理
+我们使用的预处理脚本为`preprocess.py`。该脚本会调用Moses工具中的`tokenizer.perl`脚本来切分单词和标点符号,并会将训练集随机打乱后再构建字典。注意:我们只使用已标注的训练集和测试集。执行下面的命令就可以预处理数据:
+
+```
+data_dir="./data/imdb"
+python preprocess.py -i $data_dir
+```
+
+运行成功后目录`./data/pre-imdb` 结构如下:
+
+```
+dict.txt labels.list test.list test_part_000 train.list train_part_000
+```
+
+* test\_part\_000 和 train\_part\_000: 所有标记的测试集和训练集,训练集已经随机打乱(文件格式见下面的示例)。
+* train.list 和 test.list: 训练集和测试集文件列表。
+* dict.txt: 利用训练集生成的字典。
+* labels.list: 类别标签列表,标签0表示负面评论,标签1表示正面评论。
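+
+其中,test\_part\_000 和 train\_part\_000 的每一行是一个样本,格式为"类别ID + 两个制表符 + 分词后的评论文本"(下面用`\t\t`表示两个制表符)。这里给出一行假设的示例,0表示负面评论:
+
+```
+0\t\tthis movie is terrible , the plot is boring and the acting is bad ...
+```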
+
+### 提供数据给PaddlePaddle
+PaddlePaddle可以读取用Python编写的数据提供(data provider)脚本,下面的`dataprovider.py`文件给出了完整例子,主要包括两部分:
+
+* hook: 定义文本信息、类别ID的数据类型。文本被定义为整数序列`integer_value_sequence`,类别被定义为整数`integer_value`。
+* process: 按行读取以`'\t\t'`分隔的类别ID和文本信息,并用yield关键字返回。
+
+```python
+from paddle.trainer.PyDataProvider2 import *
+
+def hook(settings, dictionary, **kwargs):
+ settings.word_dict = dictionary
+ settings.input_types = {
+ 'word': integer_value_sequence(len(settings.word_dict)),
+ 'label': integer_value(2)
+ }
+ settings.logger.info('dict len : %d' % (len(settings.word_dict)))
+
+
+@provider(init_hook=hook)
+def process(settings, file_name):
+ with open(file_name, 'r') as fdata:
+ for line_count, line in enumerate(fdata):
+ label, comment = line.strip().split('\t\t')
+ label = int(label)
+ words = comment.split()
+ word_slot = [
+ settings.word_dict[w] for w in words if w in settings.word_dict
+ ]
+ yield {
+ 'word': word_slot,
+ 'label': label
+ }
+```
+
+## 模型配置说明
+`trainer_config.py` 是一个配置文件的例子。
+### 数据定义
+```python
+from os.path import join as join_path
+from paddle.trainer_config_helpers import *
+# 是否是测试模式
+is_test = get_config_arg('is_test', bool, False)
+# 是否是预测模式
+is_predict = get_config_arg('is_predict', bool, False)
+
+# 数据路径
+data_dir = "./data/pre-imdb"
+# 文件名
+train_list = "train.list"
+test_list = "test.list"
+dict_file = "dict.txt"
+
+# 字典大小
+dict_dim = len(open(join_path(data_dir, "dict.txt")).readlines())
+# 类别个数
+class_dim = len(open(join_path(data_dir, 'labels.list')).readlines())
+
+if not is_predict:
+ train_list = join_path(data_dir, train_list)
+ test_list = join_path(data_dir, test_list)
+ dict_file = join_path(data_dir, dict_file)
+ train_list = train_list if not is_test else None
+ # 构造字典
+ word_dict = dict()
+ with open(dict_file, 'r') as f:
+ for i, line in enumerate(open(dict_file, 'r')):
+ word_dict[line.split('\t')[0]] = i
+ # 通过define_py_data_sources2函数从dataprovider.py中读取数据
+ define_py_data_sources2(
+ train_list,
+ test_list,
+ module="dataprovider",
+ obj="process", # 指定生成数据的函数。
+ args={'dictionary': word_dict}) # 额外的参数,这里指定词典。
+```
+
+### 算法配置
+
+```python
+settings(
+ batch_size=128,
+ learning_rate=2e-3,
+ learning_method=AdamOptimizer(),
+ regularization=L2Regularization(8e-4),
+ gradient_clipping_threshold=25)
+```
+
+* 设置batch size为128。
+* 设置全局学习率为2e-3。
+* 使用Adam优化方法。
+* 设置L2正则系数为8e-4。
+* 设置梯度截断(clipping)阈值为25。
+
+### 模型结构
+我们用PaddlePaddle实现了两种文本分类算法,分别基于上文所述的[文本卷积神经网络](#文本卷积神经网络(CNN))和[栈式双向LSTM](#栈式双向LSTM(Stacked Bidirectional LSTM))。
+#### 文本卷积神经网络的实现
+```python
+def convolution_net(input_dim,
+ class_dim=2,
+ emb_dim=128,
+ hid_dim=128,
+ is_predict=False):
+ # 网络输入:id表示的词序列,词典大小为input_dim
+ data = data_layer("word", input_dim)
+ # 将id表示的词序列映射为embedding序列
+ emb = embedding_layer(input=data, size=emb_dim)
+    # 卷积及最大池化操作,卷积核窗口大小为3
+    conv_3 = sequence_conv_pool(input=emb, context_len=3, hidden_size=hid_dim)
+    # 卷积及最大池化操作,卷积核窗口大小为4
+    conv_4 = sequence_conv_pool(input=emb, context_len=4, hidden_size=hid_dim)
+ # 将conv_3和conv_4拼接起来输入给softmax分类,类别数为class_dim
+ output = fc_layer(
+ input=[conv_3, conv_4], size=class_dim, act=SoftmaxActivation())
+
+ if not is_predict:
+ lbl = data_layer("label", 1) #网络输入:类别标签
+ outputs(classification_cost(input=output, label=lbl))
+ else:
+ outputs(output)
+```
+
+其中,我们仅用一个[`sequence_conv_pool`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py)方法就实现了卷积和池化操作,卷积核的数量由`hidden_size`参数指定。
+#### 栈式双向LSTM的实现
+
+```python
+def stacked_lstm_net(input_dim,
+ class_dim=2,
+ emb_dim=128,
+ hid_dim=512,
+ stacked_num=3,
+ is_predict=False):
+
+ # LSTM的层数stacked_num为奇数,确保最高层LSTM正向
+ assert stacked_num % 2 == 1
+ # 设置神经网络层的属性
+ layer_attr = ExtraLayerAttribute(drop_rate=0.5)
+ # 设置参数的属性
+ fc_para_attr = ParameterAttribute(learning_rate=1e-3)
+ lstm_para_attr = ParameterAttribute(initial_std=0., learning_rate=1.)
+ para_attr = [fc_para_attr, lstm_para_attr]
+ bias_attr = ParameterAttribute(initial_std=0., l2_rate=0.)
+ # 激活函数
+ relu = ReluActivation()
+ linear = LinearActivation()
+
+
+ # 网络输入:id表示的词序列,词典大小为input_dim
+ data = data_layer("word", input_dim)
+ # 将id表示的词序列映射为embedding序列
+ emb = embedding_layer(input=data, size=emb_dim)
+
+ fc1 = fc_layer(input=emb, size=hid_dim, act=linear, bias_attr=bias_attr)
+ # 基于LSTM的循环神经网络
+ lstm1 = lstmemory(
+ input=fc1, act=relu, bias_attr=bias_attr, layer_attr=layer_attr)
+
+ # 由fc_layer和lstmemory构建深度为stacked_num的栈式双向LSTM
+ inputs = [fc1, lstm1]
+ for i in range(2, stacked_num + 1):
+ fc = fc_layer(
+ input=inputs,
+ size=hid_dim,
+ act=linear,
+ param_attr=para_attr,
+ bias_attr=bias_attr)
+ lstm = lstmemory(
+ input=fc,
+ # 奇数层正向,偶数层反向。
+ reverse=(i % 2) == 0,
+ act=relu,
+ bias_attr=bias_attr,
+ layer_attr=layer_attr)
+ inputs = [fc, lstm]
+
+ # 对最后一层fc_layer使用时间维度上的最大池化得到定长向量
+ fc_last = pooling_layer(input=inputs[0], pooling_type=MaxPooling())
+ # 对最后一层lstmemory使用时间维度上的最大池化得到定长向量
+ lstm_last = pooling_layer(input=inputs[1], pooling_type=MaxPooling())
+ # 将fc_last和lstm_last拼接起来输入给softmax分类,类别数为class_dim
+ output = fc_layer(
+ input=[fc_last, lstm_last],
+ size=class_dim,
+ act=SoftmaxActivation(),
+ bias_attr=bias_attr,
+ param_attr=para_attr)
+
+ if is_predict:
+ outputs(output)
+ else:
+ outputs(classification_cost(input=output, label=data_layer('label', 1)))
+```
+
+我们的模型配置`trainer_config.py`默认使用`stacked_lstm_net`网络,如果要使用`convolution_net`,将调用`stacked_lstm_net`的一行注释掉,并取消`convolution_net`一行的注释即可。
+```python
+stacked_lstm_net(
+ dict_dim, class_dim=class_dim, stacked_num=3, is_predict=is_predict)
+# convolution_net(dict_dim, class_dim=class_dim, is_predict=is_predict)
+```
+
+## 训练模型
+使用`train.sh`脚本可以开启本地的训练:
+
+```
+./train.sh
+```
+
+train.sh内容如下:
+
+```bash
+paddle train --config=trainer_config.py \
+ --save_dir=./model_output \
+ --job=train \
+ --use_gpu=false \
+ --trainer_count=4 \
+ --num_passes=10 \
+ --log_period=20 \
+ --dot_period=20 \
+ --show_parameter_stats_period=100 \
+ --test_all_data_in_one_period=1 \
+ 2>&1 | tee 'train.log'
+```
+
+* \--config=trainer_config.py: 设置模型配置。
+* \--save\_dir=./model_output: 设置输出路径以保存训练完成的模型。
+* \--job=train: 设置工作模式为训练。
+* \--use\_gpu=false: 使用CPU训练,如果您安装GPU版本的PaddlePaddle,并想使用GPU来训练可将此设置为true。
+* \--trainer\_count=4:设置线程数(或GPU个数)。
+* \--num\_passes=10: 设置训练轮数(pass),PaddlePaddle中的一个pass意味着对数据集中的所有样本进行一次训练。
+* \--log\_period=20: 每20个batch打印一次日志。
+* \--show\_parameter\_stats\_period=100: 每100个batch打印一次统计信息。
+* \--test\_all_data\_in\_one\_period=1: 每次测试都测试所有数据。
+
+如果运行成功,输出日志保存在 `train.log`中,模型保存在目录`model_output/`中。 输出日志说明如下:
+
+```
+Batch=20 samples=2560 AvgCost=0.681644 CurrentCost=0.681644 Eval: classification_error_evaluator=0.36875 CurrentEval: classification_error_evaluator=0.36875
+...
+Pass=0 Batch=196 samples=25000 AvgCost=0.418964 Eval: classification_error_evaluator=0.1922
+Test samples=24999 cost=0.39297 Eval: classification_error_evaluator=0.149406
+```
+
+* Batch=xx: 表示训练了xx个Batch。
+* samples=xx: 表示训练了xx个样本。
+* AvgCost=xx: 从第0个batch到当前batch的平均损失。
+* CurrentCost=xx: 最新log_period个batch的损失。
+* Eval: classification\_error\_evaluator=xx: 表示第0个batch到当前batch的分类错误率。
+* CurrentEval: classification\_error\_evaluator: 最新log_period个batch的分类错误率。
+* Pass=0: 遍历所有训练集一次称为一个Pass。0表示第一次遍历训练集。
+
+
+## 应用模型
+### 测试
+
+测试是指用训练出的模型在已标注的数据集上评估分类效果。
+
+```
+./test.sh
+```
+
+测试脚本`test.sh`的内容如下,其中函数`get_best_pass`通过对分类错误率进行排序来获得最佳模型:
+
+```bash
+function get_best_pass() {
+ cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \
+        sed -r 'N;s/Test.* classification_error_evaluator=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' | \
+        sort -n | head -n 1
+}
+
+log=train.log
+LOG=`get_best_pass $log`
+LOG=(${LOG})
+evaluate_pass="model_output/pass-${LOG[1]}"
+
+echo 'evaluating from pass '$evaluate_pass
+
+model_list=./model.list
+touch $model_list | echo $evaluate_pass > $model_list
+net_conf=trainer_config.py
+paddle train --config=$net_conf \
+ --model_list=$model_list \
+ --job=test \
+ --use_gpu=false \
+ --trainer_count=4 \
+ --config_args=is_test=1 \
+ 2>&1 | tee 'test.log'
+```
+
+与训练不同,测试时需要指定`--job=test`和模型路径`--model_list=$model_list`。如果测试成功,日志将保存在`test.log`中。在我们的测试中,最好的模型是`model_output/pass-00002`,分类错误率是0.115645:
+
+```
+Pass=0 samples=24999 AvgCost=0.280471 Eval: classification_error_evaluator=0.115645
+```
+
+### 预测
+`predict.py`脚本提供了一个预测接口。预测IMDB中未标记评论的示例如下:
+
+```
+./predict.sh
+```
+predict.sh的内容如下(注意应该确保默认模型路径`model_output/pass-00002`存在或更改为其它模型路径):
+
+```bash
+model=model_output/pass-00002/
+config=trainer_config.py
+label=data/pre-imdb/labels.list
+cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
+ --tconf=$config \
+ --model=$model \
+ --label=$label \
+ --dict=./data/pre-imdb/dict.txt \
+ --batch_size=1
+```
+
+* `cat ./data/aclImdb/test/pos/10007_10.txt` : 输入预测样本。
+* `predict.py` : 预测接口脚本。
+* `--tconf=$config` : 设置网络配置。
+* `--model=$model` : 设置模型路径。
+* `--label=$label` : 设置标签类别字典,这个字典是整数标签和字符串标签的一个对应。
+* `--dict=./data/pre-imdb/dict.txt` : 设置文本数据字典文件。
+* `--batch_size=1` : 预测时的batch size大小。
+
+
+本示例的预测结果:
+
+```
+Loading parameters from model_output/pass-00002/
+predicting label is pos
+```
+
+`10007_10.txt`在路径`./data/aclImdb/test/pos`下面,而这里预测的标签也是pos,说明预测正确。
+## 总结
+本章我们以情感分析为例,介绍了使用深度学习的方法进行端对端的短文本分类,并且使用PaddlePaddle完成了全部相关实验。同时,我们简要介绍了两种文本处理模型:卷积神经网络和循环神经网络。在后续的章节中我们会看到这两种基本的深度学习模型在其它任务上的应用。
+## 参考文献
+1. Kim Y. [Convolutional neural networks for sentence classification](http://arxiv.org/pdf/1408.5882)[J]. arXiv preprint arXiv:1408.5882, 2014.
+2. Kalchbrenner N, Grefenstette E, Blunsom P. [A convolutional neural network for modelling sentences](http://arxiv.org/pdf/1404.2188.pdf?utm_medium=App.net&utm_source=PourOver)[J]. arXiv preprint arXiv:1404.2188, 2014.
+3. Yann N. Dauphin, et al. [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083v1.pdf)[J] arXiv preprint arXiv:1612.08083, 2016.
+4. Siegelmann H T, Sontag E D. [On the computational power of neural nets](http://research.cs.queensu.ca/home/akl/cisc879/papers/SELECTED_PAPERS_FROM_VARIOUS_SOURCES/05070215382317071.pdf)[C]//Proceedings of the fifth annual workshop on Computational learning theory. ACM, 1992: 440-449.
+5. Hochreiter S, Schmidhuber J. [Long short-term memory](http://web.eecs.utk.edu/~itamar/courses/ECE-692/Bobby_paper1.pdf)[J]. Neural computation, 1997, 9(8): 1735-1780.
+6. Bengio Y, Simard P, Frasconi P. [Learning long-term dependencies with gradient descent is difficult](http://www-dsi.ing.unifi.it/~paolo/ps/tnn-94-gradient.pdf)[J]. IEEE transactions on neural networks, 1994, 5(2): 157-166.
+7. Graves A. [Generating sequences with recurrent neural networks](http://arxiv.org/pdf/1308.0850)[J]. arXiv preprint arXiv:1308.0850, 2013.
+8. Cho K, Van Merriënboer B, Gulcehre C, et al. [Learning phrase representations using RNN encoder-decoder for statistical machine translation](http://arxiv.org/pdf/1406.1078)[J]. arXiv preprint arXiv:1406.1078, 2014.
+9. Zhou J, Xu W. [End-to-end learning of semantic role labeling using recurrent neural networks](http://www.aclweb.org/anthology/P/P15/P15-1109.pdf)[C]//Proceedings of the Annual Meeting of the Association for Computational Linguistics. 2015.
diff --git a/understand_sentiment/data/get_imdb.sh b/understand_sentiment/data/get_imdb.sh
new file mode 100755
index 0000000000000000000000000000000000000000..7600af6fbb900ee845702f1297779c1f0ed9bf84
--- /dev/null
+++ b/understand_sentiment/data/get_imdb.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+set -x
+
+DIR="$( cd "$(dirname "$0")" ; pwd -P )"
+cd $DIR
+
+#download the dataset
+echo "Downloading aclImdb..."
+#http://ai.stanford.edu/%7Eamaas/data/sentiment/
+wget http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz
+
+echo "Downloading mosesdecoder..."
+#https://github.com/moses-smt/mosesdecoder
+wget https://github.com/moses-smt/mosesdecoder/archive/master.zip
+
+#extract package
+echo "Unzipping..."
+tar -zxvf aclImdb_v1.tar.gz
+unzip master.zip
+
+#copy train and test set to the imdb directory
+#in order to process them when training
+mkdir -p imdb/train
+mkdir -p imdb/test
+
+cp -r aclImdb/train/pos/ imdb/train/pos
+cp -r aclImdb/train/neg/ imdb/train/neg
+
+cp -r aclImdb/test/pos/ imdb/test/pos
+cp -r aclImdb/test/neg/ imdb/test/neg
+
+#remove compressed package
+rm aclImdb_v1.tar.gz
+rm master.zip
+
+echo "Done."
diff --git a/understand_sentiment/dataprovider.py b/understand_sentiment/dataprovider.py
new file mode 100755
index 0000000000000000000000000000000000000000..976351ab7015fe136d270e97b0c767ac7fc63112
--- /dev/null
+++ b/understand_sentiment/dataprovider.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddle.trainer.PyDataProvider2 import *
+
+
+def hook(settings, dictionary, **kwargs):
+ settings.word_dict = dictionary
+ settings.input_types = {
+ 'word': integer_value_sequence(len(settings.word_dict)),
+ 'label': integer_value(2)
+ }
+ settings.logger.info('dict len : %d' % (len(settings.word_dict)))
+
+
+@provider(init_hook=hook)
+def process(settings, file_name):
+ with open(file_name, 'r') as fdata:
+ for line_count, line in enumerate(fdata):
+ label, comment = line.strip().split('\t\t')
+ label = int(label)
+ words = comment.split()
+ word_slot = [
+ settings.word_dict[w] for w in words if w in settings.word_dict
+ ]
+ yield {'word': word_slot, 'label': label}
diff --git a/understand_sentiment/image/lstm.png b/understand_sentiment/image/lstm.png
new file mode 100644
index 0000000000000000000000000000000000000000..376afe66c61fe53d4ba383823ac685f4d6941161
Binary files /dev/null and b/understand_sentiment/image/lstm.png differ
diff --git a/understand_sentiment/image/rnn.png b/understand_sentiment/image/rnn.png
new file mode 100755
index 0000000000000000000000000000000000000000..1ab1fa974c8aadaf50a15d220f047b243d4d4c1a
Binary files /dev/null and b/understand_sentiment/image/rnn.png differ
diff --git a/understand_sentiment/image/stacked_lstm.jpg b/understand_sentiment/image/stacked_lstm.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4239055050966e0095e188a8c81d860711bce29d
Binary files /dev/null and b/understand_sentiment/image/stacked_lstm.jpg differ
diff --git a/understand_sentiment/image/text_cnn.png b/understand_sentiment/image/text_cnn.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb31c3fd6f68b83390591b306aff8f8f021ec497
Binary files /dev/null and b/understand_sentiment/image/text_cnn.png differ
diff --git a/understand_sentiment/predict.py b/understand_sentiment/predict.py
new file mode 100755
index 0000000000000000000000000000000000000000..8ec490f64691924013200a3d0038d39aa834b038
--- /dev/null
+++ b/understand_sentiment/predict.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+import numpy as np
+from optparse import OptionParser
+from py_paddle import swig_paddle, DataProviderConverter
+from paddle.trainer.PyDataProvider2 import integer_value_sequence
+from paddle.trainer.config_parser import parse_config
+"""
+Usage: run following command to show help message.
+ python predict.py -h
+"""
+
+
+class SentimentPrediction():
+ def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
+ """
+        train_conf: trainer config file.
+        dict_file: word dictionary file name.
+        model_dir: directory of the model.
+        label_file: label dictionary file name (optional).
+ """
+ self.train_conf = train_conf
+ self.dict_file = dict_file
+ self.word_dict = {}
+ self.dict_dim = self.load_dict()
+ self.model_dir = model_dir
+ if model_dir is None:
+ self.model_dir = os.path.dirname(train_conf)
+
+ self.label = None
+ if label_file is not None:
+ self.load_label(label_file)
+
+ conf = parse_config(train_conf, "is_predict=1")
+ self.network = swig_paddle.GradientMachine.createFromConfigProto(
+ conf.model_config)
+ self.network.loadParameters(self.model_dir)
+ input_types = [integer_value_sequence(self.dict_dim)]
+ self.converter = DataProviderConverter(input_types)
+
+ def load_dict(self):
+ """
+ Load dictionary from self.dict_file.
+ """
+ for line_count, line in enumerate(open(self.dict_file, 'r')):
+ self.word_dict[line.strip().split('\t')[0]] = line_count
+ return len(self.word_dict)
+
+ def load_label(self, label_file):
+ """
+ Load label.
+ """
+ self.label = {}
+ for v in open(label_file, 'r'):
+ self.label[int(v.split('\t')[1])] = v.split('\t')[0]
+
+ def get_index(self, data):
+ """
+        Transform words into integer indices according to the dictionary.
+ """
+ words = data.strip().split()
+ word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
+ return word_slot
+
+ def batch_predict(self, data_batch):
+ input = self.converter(data_batch)
+ output = self.network.forwardTest(input)
+ prob = output[0]["value"]
+ labs = np.argsort(-prob)
+ for idx, lab in enumerate(labs):
+ if self.label is None:
+ print("predicting label is %d" % (lab[0]))
+ else:
+ print("predicting label is %s" % (self.label[lab[0]]))
+
+
+def option_parser():
+    usage = "cat input_file | python predict.py -n config -w model_dir -d dictionary -b label_file"
+ parser = OptionParser(usage="usage: %s [options]" % usage)
+ parser.add_option(
+ "-n",
+ "--tconf",
+ action="store",
+ dest="train_conf",
+ help="network config")
+ parser.add_option(
+ "-d",
+ "--dict",
+ action="store",
+ dest="dict_file",
+ help="dictionary file")
+ parser.add_option(
+ "-b",
+ "--label",
+ action="store",
+ dest="label",
+ default=None,
+        help="label file")
+ parser.add_option(
+ "-c",
+ "--batch_size",
+ type="int",
+ action="store",
+ dest="batch_size",
+ default=1,
+ help="the batch size for prediction")
+ parser.add_option(
+ "-w",
+ "--model",
+ action="store",
+ dest="model_path",
+ default=None,
+ help="model path")
+ return parser.parse_args()
+
+
+def main():
+ options, args = option_parser()
+ train_conf = options.train_conf
+ batch_size = options.batch_size
+ dict_file = options.dict_file
+ model_path = options.model_path
+ label = options.label
+ swig_paddle.initPaddle("--use_gpu=0")
+ predict = SentimentPrediction(train_conf, dict_file, model_path, label)
+
+ batch = []
+ for line in sys.stdin:
+ batch.append([predict.get_index(line)])
+ if len(batch) == batch_size:
+ predict.batch_predict(batch)
+ batch = []
+ if len(batch) > 0:
+ predict.batch_predict(batch)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/understand_sentiment/predict.sh b/understand_sentiment/predict.sh
new file mode 100755
index 0000000000000000000000000000000000000000..20adee8a465ad2b78066dccd9efac2743f583350
--- /dev/null
+++ b/understand_sentiment/predict.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+#Note the default model is pass-00002, you should make sure the model path
+#exists or change the model path.
+model=model_output/pass-00002/
+config=trainer_config.py
+label=data/pre-imdb/labels.list
+cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
+ --tconf=$config \
+ --model=$model \
+ --label=$label \
+ --dict=./data/pre-imdb/dict.txt \
+ --batch_size=1
diff --git a/understand_sentiment/preprocess.py b/understand_sentiment/preprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..29b3682b747c66574590de5ea70574981cc536bb
--- /dev/null
+++ b/understand_sentiment/preprocess.py
@@ -0,0 +1,359 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import random
+import operator
+import math
+import numpy as np
+from subprocess import Popen, PIPE
+from os.path import join as join_path
+from optparse import OptionParser
+
+from paddle.utils.preprocess_util import *
+"""
+Usage: run following command to show help message.
+ python preprocess.py -h
+"""
+
+
+def save_dict(dict, filename, is_reverse=True):
+ """
+ Save dictionary into file.
+ dict: input dictionary.
+ filename: output file name, string.
+ is_reverse: True, descending order by value.
+ False, ascending order by value.
+ """
+ f = open(filename, 'w')
+ for k, v in sorted(dict.items(), key=operator.itemgetter(1),\
+ reverse=is_reverse):
+ f.write('%s\t%s\n' % (k, v))
+ f.close()
+
+
+def tokenize(sentences):
+ """
+ Use tokenizer.perl to tokenize input sentences.
+ tokenizer.perl is tool of Moses.
+ sentences : a list of input sentences.
+ return: a list of processed text.
+ """
+ dir = './data/mosesdecoder-master/scripts/tokenizer/tokenizer.perl'
+ tokenizer_cmd = [dir, '-l', 'en', '-q', '-']
+ assert isinstance(sentences, list)
+ text = "\n".join(sentences)
+ tokenizer = Popen(tokenizer_cmd, stdin=PIPE, stdout=PIPE)
+ tok_text, _ = tokenizer.communicate(text)
+ toks = tok_text.split('\n')[:-1]
+ return toks
+
+
+def read_lines(path):
+ """
+ path: String, file path.
+ return a list of sequence.
+ """
+ seqs = []
+ with open(path, 'r') as f:
+ for line in f.readlines():
+ line = line.strip()
+ if len(line):
+ seqs.append(line)
+ return seqs
+
+
+class SentimentDataSetCreate():
+ """
+ A class to process data for sentiment analysis task.
+ """
+
+ def __init__(self,
+ data_path,
+ output_path,
+                 use_tokenizer=True,
+                 multi_lines=False):
+        """
+        data_path: string, training and testing dataset path
+        output_path: string, output path, store processed dataset
+        use_tokenizer: whether to use Moses' tokenizer to tokenize text.
+        multi_lines: whether a file has multi lines.
+ In order to shuffle fully, it needs to read all files into
+ memory, then shuffle them if one file has multi lines.
+ """
+ self.output_path = output_path
+ self.data_path = data_path
+
+ self.train_dir = 'train'
+ self.test_dir = 'test'
+
+ self.train_list = "train.list"
+ self.test_list = "test.list"
+
+ self.label_list = "labels.list"
+ self.classes_num = 0
+
+ self.batch_size = 50000
+ self.batch_dir = 'batches'
+
+ self.dict_file = "dict.txt"
+ self.dict_with_test = False
+ self.dict_size = 0
+ self.word_count = {}
+
+        self.tokenizer = use_tokenizer
+ self.overwrite = False
+
+ self.multi_lines = multi_lines
+
+ self.train_dir = join_path(data_path, self.train_dir)
+ self.test_dir = join_path(data_path, self.test_dir)
+ self.train_list = join_path(output_path, self.train_list)
+ self.test_list = join_path(output_path, self.test_list)
+ self.label_list = join_path(output_path, self.label_list)
+ self.dict_file = join_path(output_path, self.dict_file)
+
+ def data_list(self, path):
+ """
+ create dataset from path
+ path: data path
+ return: data list
+ """
+ label_set = get_label_set_from_dir(path)
+ data = []
+ for lab_name in label_set.keys():
+ file_paths = list_files(join_path(path, lab_name))
+ for p in file_paths:
+ data.append({"label" : label_set[lab_name],\
+ "seq_path": p})
+ return data, label_set
+
+ def create_dict(self, data):
+ """
+ create dict for input data.
+        data: list, [sequence, sequence, ...]
+ """
+ for seq in data:
+ for w in seq.strip().lower().split():
+ if w not in self.word_count:
+ self.word_count[w] = 1
+ else:
+ self.word_count[w] += 1
+
+ def create_dataset(self):
+ """
+ create file batches and dictionary of train data set.
+ If the self.overwrite is false and train.list already exists in
+ self.output_path, this function will not create and save file
+ batches from the data set path.
+ return: dictionary size, class number.
+ """
+ out_path = self.output_path
+ if out_path and not os.path.exists(out_path):
+ os.makedirs(out_path)
+
+ # If self.overwrite is false or self.train_list has existed,
+ # it will not process dataset.
+ if not (self.overwrite or not os.path.exists(self.train_list)):
+ print "%s already exists." % self.train_list
+ return
+
+ # Preprocess train data.
+ train_data, train_lab_set = self.data_list(self.train_dir)
+ print "processing train set..."
+ file_lists = self.save_data(train_data, "train", self.batch_size, True,
+ True)
+ save_list(file_lists, self.train_list)
+
+ # If have test data path, preprocess test data.
+ if os.path.exists(self.test_dir):
+ test_data, test_lab_set = self.data_list(self.test_dir)
+ assert (train_lab_set == test_lab_set)
+ print "processing test set..."
+ file_lists = self.save_data(test_data, "test", self.batch_size,
+ False, self.dict_with_test)
+ save_list(file_lists, self.test_list)
+
+ # save labels set.
+ save_dict(train_lab_set, self.label_list, False)
+ self.classes_num = len(train_lab_set.keys())
+
+ # save dictionary.
+ save_dict(self.word_count, self.dict_file, True)
+ self.dict_size = len(self.word_count)
+
+ def save_data(self,
+ data,
+ prefix="",
+ batch_size=50000,
+ is_shuffle=False,
+ build_dict=False):
+ """
+ Create batches for a Dataset object.
+ data: the Dataset object to process.
+ prefix: the prefix of each batch.
+ batch_size: number of data in each batch.
+ build_dict: whether to build dictionary for data
+
+ return: list of batch names
+ """
+ if is_shuffle and self.multi_lines:
+ return self.save_data_multi_lines(data, prefix, batch_size,
+ build_dict)
+
+ if is_shuffle:
+ random.shuffle(data)
+ num_batches = int(math.ceil(len(data) / float(batch_size)))
+ batch_names = []
+ for i in range(num_batches):
+ batch_name = join_path(self.output_path,
+ "%s_part_%03d" % (prefix, i))
+ begin = i * batch_size
+ end = min((i + 1) * batch_size, len(data))
+ # read a batch of data
+ label_list, data_list = self.get_data_list(begin, end, data)
+ if build_dict:
+ self.create_dict(data_list)
+ self.save_file(label_list, data_list, batch_name)
+ batch_names.append(batch_name)
+
+ return batch_names
+
+ def get_data_list(self, begin, end, data):
+ """
+ begin: int, begining index of data.
+ end: int, ending index of data.
+ data: a list of {"seq_path": seqquence path, "label": label index}
+
+ return a list of label and a list of sequence.
+ """
+ label_list = []
+ data_list = []
+ for j in range(begin, end):
+ seqs = read_lines(data[j]["seq_path"])
+ lab = int(data[j]["label"])
+ #File may have multiple lines.
+ for seq in seqs:
+ data_list.append(seq)
+ label_list.append(lab)
+ if self.tokenizer:
+ data_list = tokenize(data_list)
+ return label_list, data_list
+
+ def save_data_multi_lines(self,
+ data,
+ prefix="",
+ batch_size=50000,
+ build_dict=False):
+ """
+ In order to shuffle fully, there is no need to load all data if
+ each file only contains one sample, it only needs to shuffle list
+ of file name. But one file contains multi lines, each line is one
+ sample. It needs to read all data into memory to shuffle fully.
+        This interface is mainly for data containing multi lines in each
+        file, which consumes more memory if there is a great amount of data.
+
+ data: the Dataset object to process.
+ prefix: the prefix of each batch.
+ batch_size: number of data in each batch.
+ build_dict: whether to build dictionary for data
+
+ return: list of batch names
+ """
+ assert self.multi_lines
+ label_list = []
+ data_list = []
+
+ # read all data
+ label_list, data_list = self.get_data_list(0, len(data), data)
+ if build_dict:
+ self.create_dict(data_list)
+
+ length = len(label_list)
+ perm_list = np.array([i for i in xrange(length)])
+ random.shuffle(perm_list)
+
+ num_batches = int(math.ceil(length / float(batch_size)))
+ batch_names = []
+ for i in range(num_batches):
+ batch_name = join_path(self.output_path,
+ "%s_part_%03d" % (prefix, i))
+ begin = i * batch_size
+ end = min((i + 1) * batch_size, length)
+ sub_label = [label_list[perm_list[i]] for i in range(begin, end)]
+ sub_data = [data_list[perm_list[i]] for i in range(begin, end)]
+ self.save_file(sub_label, sub_data, batch_name)
+ batch_names.append(batch_name)
+
+ return batch_names
+
+ def save_file(self, label_list, data_list, filename):
+ """
+ Save data into file.
+ label_list: a list of int value.
+        data_list: a list of sequences.
+ filename: output file name.
+ """
+ f = open(filename, 'w')
+ print "saving file: %s" % filename
+ for lab, seq in zip(label_list, data_list):
+ f.write('%s\t\t%s\n' % (lab, seq))
+ f.close()
+
+
+def option_parser():
+    parser = OptionParser(usage="usage: python preprocess.py "\
+ "-i data_dir [options]")
+ parser.add_option(
+ "-i",
+ "--data",
+ action="store",
+ dest="input",
+ help="Input data directory.")
+ parser.add_option(
+ "-o",
+ "--output",
+ action="store",
+ dest="output",
+ default=None,
+ help="Output directory.")
+ parser.add_option(
+ "-t",
+ "--tokenizer",
+ action="store",
+ dest="use_tokenizer",
+ default=True,
+ help="Whether to use tokenizer.")
+ parser.add_option("-m", "--multi_lines", action="store",
+ dest="multi_lines", default=False,
+ help="If input text files have multi lines and they "\
+                      "need to be shuffled, you should set -m True.")
+ return parser.parse_args()
+
+
+def main():
+ options, args = option_parser()
+ data_dir = options.input
+ output_dir = options.output
+ use_tokenizer = options.use_tokenizer
+ multi_lines = options.multi_lines
+ if output_dir is None:
+ outname = os.path.basename(options.input)
+ output_dir = join_path(os.path.dirname(data_dir), 'pre-' + outname)
+ data_creator = SentimentDataSetCreate(data_dir, output_dir, use_tokenizer,
+ multi_lines)
+ data_creator.create_dataset()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/understand_sentiment/test.sh b/understand_sentiment/test.sh
new file mode 100755
index 0000000000000000000000000000000000000000..8af827c3388c8df88a872bd87d121a4f9631c3ff
--- /dev/null
+++ b/understand_sentiment/test.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+function get_best_pass() {
+ cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \
+ sed -r 'N;s/Test.* classification_error_evaluator=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' |\
+ sort -n | head -n 1
+}
+
+log=train.log
+LOG=`get_best_pass $log`
+LOG=(${LOG})
+evaluate_pass="model_output/pass-${LOG[1]}"
+
+echo 'evaluating from pass '$evaluate_pass
+
+model_list=./model.list
+touch $model_list | echo $evaluate_pass > $model_list
+net_conf=trainer_config.py
+paddle train --config=$net_conf \
+ --model_list=$model_list \
+ --job=test \
+ --use_gpu=false \
+ --trainer_count=4 \
+ --config_args=is_test=1 \
+ 2>&1 | tee 'test.log'
diff --git a/understand_sentiment/train.sh b/understand_sentiment/train.sh
new file mode 100755
index 0000000000000000000000000000000000000000..df8d464d557edbc2f538cb492bd29e8d32c77635
--- /dev/null
+++ b/understand_sentiment/train.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+paddle train --config=trainer_config.py \
+ --save_dir=./model_output \
+ --job=train \
+ --use_gpu=false \
+ --trainer_count=4 \
+ --num_passes=10 \
+ --log_period=10 \
+ --dot_period=20 \
+ --show_parameter_stats_period=100 \
+ --test_all_data_in_one_period=1 \
+ 2>&1 | tee 'train.log'
diff --git a/understand_sentiment/trainer_config.py b/understand_sentiment/trainer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9b98634bda18e4659c9aeaa8eeffcc52c13e1c
--- /dev/null
+++ b/understand_sentiment/trainer_config.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from os.path import join as join_path
+from paddle.trainer_config_helpers import *
+# whether this config is used for test
+is_test = get_config_arg('is_test', bool, False)
+# whether this config is used for prediction
+is_predict = get_config_arg('is_predict', bool, False)
+
+data_dir = "./data/pre-imdb"
+train_list = "train.list"
+test_list = "test.list"
+dict_file = "dict.txt"
+
+dict_dim = len(open(join_path(data_dir, "dict.txt")).readlines())
+class_dim = len(open(join_path(data_dir, 'labels.list')).readlines())
+
+if not is_predict:
+ train_list = join_path(data_dir, train_list)
+ test_list = join_path(data_dir, test_list)
+ dict_file = join_path(data_dir, dict_file)
+ train_list = train_list if not is_test else None
+ word_dict = dict()
+ with open(dict_file, 'r') as f:
+ for i, line in enumerate(open(dict_file, 'r')):
+ word_dict[line.split('\t')[0]] = i
+
+ define_py_data_sources2(
+ train_list,
+ test_list,
+ module="dataprovider",
+ obj="process",
+ args={'dictionary': word_dict})
+
+################## Algorithm Config #####################
+
+settings(
+ batch_size=128,
+ learning_rate=2e-3,
+ learning_method=AdamOptimizer(),
+ average_window=0.5,
+ regularization=L2Regularization(8e-4),
+ gradient_clipping_threshold=25)
+
+#################### Network Config ######################
+
+
+def convolution_net(input_dim,
+ class_dim=2,
+ emb_dim=128,
+ hid_dim=128,
+ is_predict=False):
+ data = data_layer("word", input_dim)
+ emb = embedding_layer(input=data, size=emb_dim)
+ conv_3 = sequence_conv_pool(input=emb, context_len=3, hidden_size=hid_dim)
+ conv_4 = sequence_conv_pool(input=emb, context_len=4, hidden_size=hid_dim)
+ output = fc_layer(
+ input=[conv_3, conv_4], size=class_dim, act=SoftmaxActivation())
+
+ if not is_predict:
+ lbl = data_layer("label", 1)
+ outputs(classification_cost(input=output, label=lbl))
+ else:
+ outputs(output)
+
+
+def stacked_lstm_net(input_dim,
+ class_dim=2,
+ emb_dim=128,
+ hid_dim=512,
+ stacked_num=3,
+ is_predict=False):
+ """
+    A wrapper for the sentiment classification task.
+    This network uses a bi-directional recurrent network,
+    consisting of three LSTM layers. This configuration refers to
+    the paper at the following url, but uses fewer layers.
+    http://www.aclweb.org/anthology/P15-1109
+
+ input_dim: here is word dictionary dimension.
+ class_dim: number of categories.
+ emb_dim: dimension of word embedding.
+ hid_dim: dimension of hidden layer.
+    stacked_num: number of stacked lstm-hidden layers.
+    is_predict: is predicting or not.
+                Some layers are not needed in the network when predicting.
+ """
+ assert stacked_num % 2 == 1
+
+ layer_attr = ExtraLayerAttribute(drop_rate=0.5)
+ fc_para_attr = ParameterAttribute(learning_rate=1e-3)
+ lstm_para_attr = ParameterAttribute(initial_std=0., learning_rate=1.)
+ para_attr = [fc_para_attr, lstm_para_attr]
+ bias_attr = ParameterAttribute(initial_std=0., l2_rate=0.)
+ relu = ReluActivation()
+ linear = LinearActivation()
+
+ data = data_layer("word", input_dim)
+ emb = embedding_layer(input=data, size=emb_dim)
+
+ fc1 = fc_layer(input=emb, size=hid_dim, act=linear, bias_attr=bias_attr)
+ lstm1 = lstmemory(
+ input=fc1, act=relu, bias_attr=bias_attr, layer_attr=layer_attr)
+
+ inputs = [fc1, lstm1]
+ for i in range(2, stacked_num + 1):
+ fc = fc_layer(
+ input=inputs,
+ size=hid_dim,
+ act=linear,
+ param_attr=para_attr,
+ bias_attr=bias_attr)
+ lstm = lstmemory(
+ input=fc,
+ reverse=(i % 2) == 0,
+ act=relu,
+ bias_attr=bias_attr,
+ layer_attr=layer_attr)
+ inputs = [fc, lstm]
+
+ fc_last = pooling_layer(input=inputs[0], pooling_type=MaxPooling())
+ lstm_last = pooling_layer(input=inputs[1], pooling_type=MaxPooling())
+ output = fc_layer(
+ input=[fc_last, lstm_last],
+ size=class_dim,
+ act=SoftmaxActivation(),
+ bias_attr=bias_attr,
+ param_attr=para_attr)
+
+ if is_predict:
+ outputs(output)
+ else:
+ outputs(classification_cost(input=output, label=data_layer('label', 1)))
+
+
+stacked_lstm_net(
+ dict_dim, class_dim=class_dim, stacked_num=3, is_predict=is_predict)
+# convolution_net(dict_dim, class_dim=class_dim, is_predict=is_predict)