Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
f97790a0
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f97790a0
编写于
9月 26, 2019
作者:
L
LiuHao
提交者:
GitHub
9月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
情感分析代码易用性优化 (#3420)
* update * update * update * fix run_ernie.sh * Update README.md
上级
e5c84957
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
919 addition
and
466 deletion
+919
-466
PaddleNLP/models/classification/nets.py
PaddleNLP/models/classification/nets.py
+32
-37
PaddleNLP/preprocess/ernie/task_reader.py
PaddleNLP/preprocess/ernie/task_reader.py
+15
-5
PaddleNLP/preprocess/ernie/tokenization.py
PaddleNLP/preprocess/ernie/tokenization.py
+2
-2
PaddleNLP/sentiment_classification/README.md
PaddleNLP/sentiment_classification/README.md
+246
-83
PaddleNLP/sentiment_classification/config.py
PaddleNLP/sentiment_classification/config.py
+157
-35
PaddleNLP/sentiment_classification/inference_model.py
PaddleNLP/sentiment_classification/inference_model.py
+127
-0
PaddleNLP/sentiment_classification/inference_model_ernie.py
PaddleNLP/sentiment_classification/inference_model_ernie.py
+152
-0
PaddleNLP/sentiment_classification/reader.py
PaddleNLP/sentiment_classification/reader.py
+16
-30
PaddleNLP/sentiment_classification/run.sh
PaddleNLP/sentiment_classification/run.sh
+17
-5
PaddleNLP/sentiment_classification/run_classifier.py
PaddleNLP/sentiment_classification/run_classifier.py
+56
-119
PaddleNLP/sentiment_classification/run_ernie.sh
PaddleNLP/sentiment_classification/run_ernie.sh
+30
-13
PaddleNLP/sentiment_classification/run_ernie_classifier.py
PaddleNLP/sentiment_classification/run_ernie_classifier.py
+51
-110
PaddleNLP/sentiment_classification/senta_config.json
PaddleNLP/sentiment_classification/senta_config.json
+2
-1
PaddleNLP/sentiment_classification/utils.py
PaddleNLP/sentiment_classification/utils.py
+16
-26
未找到文件。
PaddleNLP/models/classification/nets.py
浏览文件 @
f97790a0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provide nets for text classification
"""
import
paddle.fluid
as
fluid
def
bow_net
(
data
,
seq_len
,
label
,
dict_dim
,
emb_dim
=
128
,
hid_dim
=
128
,
hid_dim2
=
96
,
class_dim
=
2
,
is_
infer
=
False
):
is_
prediction
=
False
):
"""
Bow net
"""
# embedding layer
emb
=
fluid
.
layers
.
embedding
(
input
=
data
,
size
=
[
dict_dim
,
emb_dim
])
emb
=
fluid
.
layers
.
sequence_unpad
(
emb
,
length
=
seq_len
)
# bow layer
bow
=
fluid
.
layers
.
sequence_pool
(
input
=
emb
,
pool_type
=
'sum'
)
bow_tanh
=
fluid
.
layers
.
tanh
(
bow
)
...
...
@@ -39,7 +27,7 @@ def bow_net(data,
fc_2
=
fluid
.
layers
.
fc
(
input
=
fc_1
,
size
=
hid_dim2
,
act
=
"tanh"
)
# softmax layer
prediction
=
fluid
.
layers
.
fc
(
input
=
[
fc_2
],
size
=
class_dim
,
act
=
"softmax"
)
if
is_
infer
:
if
is_
prediction
:
return
prediction
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
...
...
@@ -49,6 +37,7 @@ def bow_net(data,
def
cnn_net
(
data
,
seq_len
,
label
,
dict_dim
,
emb_dim
=
128
,
...
...
@@ -56,13 +45,13 @@ def cnn_net(data,
hid_dim2
=
96
,
class_dim
=
2
,
win_size
=
3
,
is_
infer
=
False
):
is_
prediction
=
False
):
"""
Conv net
"""
# embedding layer
emb
=
fluid
.
layers
.
embedding
(
input
=
data
,
size
=
[
dict_dim
,
emb_dim
])
emb
=
fluid
.
layers
.
sequence_unpad
(
emb
,
length
=
seq_len
)
# convolution layer
conv_3
=
fluid
.
nets
.
sequence_conv_pool
(
input
=
emb
,
...
...
@@ -75,7 +64,7 @@ def cnn_net(data,
fc_1
=
fluid
.
layers
.
fc
(
input
=
[
conv_3
],
size
=
hid_dim2
)
# softmax layer
prediction
=
fluid
.
layers
.
fc
(
input
=
[
fc_1
],
size
=
class_dim
,
act
=
"softmax"
)
if
is_
infer
:
if
is_
prediction
:
return
prediction
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
...
...
@@ -85,6 +74,7 @@ def cnn_net(data,
def
lstm_net
(
data
,
seq_len
,
label
,
dict_dim
,
emb_dim
=
128
,
...
...
@@ -92,7 +82,7 @@ def lstm_net(data,
hid_dim2
=
96
,
class_dim
=
2
,
emb_lr
=
30.0
,
is_
infer
=
False
):
is_
prediction
=
False
):
"""
Lstm net
"""
...
...
@@ -101,7 +91,7 @@ def lstm_net(data,
input
=
data
,
size
=
[
dict_dim
,
emb_dim
],
param_attr
=
fluid
.
ParamAttr
(
learning_rate
=
emb_lr
))
emb
=
fluid
.
layers
.
sequence_unpad
(
emb
,
length
=
seq_len
)
# Lstm layer
fc0
=
fluid
.
layers
.
fc
(
input
=
emb
,
size
=
hid_dim
*
4
)
...
...
@@ -116,7 +106,7 @@ def lstm_net(data,
fc1
=
fluid
.
layers
.
fc
(
input
=
lstm_max_tanh
,
size
=
hid_dim2
,
act
=
'tanh'
)
# softmax layer
prediction
=
fluid
.
layers
.
fc
(
input
=
fc1
,
size
=
class_dim
,
act
=
'softmax'
)
if
is_
infer
:
if
is_
prediction
:
return
prediction
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
...
...
@@ -126,6 +116,7 @@ def lstm_net(data,
def
bilstm_net
(
data
,
seq_len
,
label
,
dict_dim
,
emb_dim
=
128
,
...
...
@@ -133,7 +124,7 @@ def bilstm_net(data,
hid_dim2
=
96
,
class_dim
=
2
,
emb_lr
=
30.0
,
is_
infer
=
False
):
is_
prediction
=
False
):
"""
Bi-Lstm net
"""
...
...
@@ -143,6 +134,8 @@ def bilstm_net(data,
size
=
[
dict_dim
,
emb_dim
],
param_attr
=
fluid
.
ParamAttr
(
learning_rate
=
emb_lr
))
emb
=
fluid
.
layers
.
sequence_unpad
(
emb
,
length
=
seq_len
)
fc0
=
fluid
.
layers
.
fc
(
input
=
emb
,
size
=
hid_dim
*
4
)
rfc0
=
fluid
.
layers
.
fc
(
input
=
emb
,
size
=
hid_dim
*
4
)
lstm_h
,
c
=
fluid
.
layers
.
dynamic_lstm
(
...
...
@@ -161,7 +154,7 @@ def bilstm_net(data,
fc1
=
fluid
.
layers
.
fc
(
input
=
lstm_concat
,
size
=
hid_dim2
,
act
=
'tanh'
)
# softmax layer
prediction
=
fluid
.
layers
.
fc
(
input
=
fc1
,
size
=
class_dim
,
act
=
'softmax'
)
if
is_
infer
:
if
is_
prediction
:
return
prediction
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
...
...
@@ -170,6 +163,7 @@ def bilstm_net(data,
def
gru_net
(
data
,
seq_len
,
label
,
dict_dim
,
emb_dim
=
128
,
...
...
@@ -177,7 +171,7 @@ def gru_net(data,
hid_dim2
=
96
,
class_dim
=
2
,
emb_lr
=
30.0
,
is_
infer
=
False
):
is_
prediction
=
False
):
"""
gru net
"""
...
...
@@ -185,7 +179,7 @@ def gru_net(data,
input
=
data
,
size
=
[
dict_dim
,
emb_dim
],
param_attr
=
fluid
.
ParamAttr
(
learning_rate
=
emb_lr
))
emb
=
fluid
.
layers
.
sequence_unpad
(
emb
,
length
=
seq_len
)
fc0
=
fluid
.
layers
.
fc
(
input
=
emb
,
size
=
hid_dim
*
3
)
gru_h
=
fluid
.
layers
.
dynamic_gru
(
input
=
fc0
,
size
=
hid_dim
,
is_reverse
=
False
)
...
...
@@ -196,7 +190,7 @@ def gru_net(data,
fc1
=
fluid
.
layers
.
fc
(
input
=
gru_max_tanh
,
size
=
hid_dim2
,
act
=
'tanh'
)
prediction
=
fluid
.
layers
.
fc
(
input
=
fc1
,
size
=
class_dim
,
act
=
'softmax'
)
if
is_
infer
:
if
is_
prediction
:
return
prediction
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
x
=
cost
)
...
...
@@ -206,6 +200,7 @@ def gru_net(data,
def
textcnn_net
(
data
,
seq_len
,
label
,
dict_dim
,
emb_dim
=
128
,
...
...
@@ -213,7 +208,7 @@ def textcnn_net(data,
hid_dim2
=
96
,
class_dim
=
2
,
win_sizes
=
None
,
is_infer
=
False
):
is_prediction
=
False
):
"""
Textcnn_net
"""
...
...
@@ -222,7 +217,7 @@ def textcnn_net(data,
# embedding layer
emb
=
fluid
.
layers
.
embedding
(
input
=
data
,
size
=
[
dict_dim
,
emb_dim
])
emb
=
fluid
.
layers
.
sequence_unpad
(
emb
,
length
=
seq_len
)
# convolution layer
convs
=
[]
for
win_size
in
win_sizes
:
...
...
@@ -239,7 +234,7 @@ def textcnn_net(data,
fc_1
=
fluid
.
layers
.
fc
(
input
=
[
convs_out
],
size
=
hid_dim2
,
act
=
"tanh"
)
# softmax layer
prediction
=
fluid
.
layers
.
fc
(
input
=
[
fc_1
],
size
=
class_dim
,
act
=
"softmax"
)
if
is_
infer
:
if
is_
prediction
:
return
prediction
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
...
...
PaddleNLP/preprocess/ernie/task_reader.py
浏览文件 @
f97790a0
...
...
@@ -27,7 +27,17 @@ import numpy as np
from
preprocess.ernie
import
tokenization
from
preprocess.padding
import
pad_batch_data
import
io
def
csv_reader
(
fd
,
delimiter
=
'
\t
'
):
def
gen
():
for
i
in
fd
:
slots
=
i
.
rstrip
(
'
\n
'
).
split
(
delimiter
)
if
len
(
slots
)
==
1
:
yield
slots
,
else
:
yield
slots
return
gen
()
class
BaseReader
(
object
):
"""BaseReader for classify and sequence labeling task"""
...
...
@@ -66,8 +76,8 @@ class BaseReader(object):
def
_read_tsv
(
self
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf8"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
with
io
.
open
(
input_file
,
"r"
,
encoding
=
"utf8"
)
as
f
:
reader
=
csv
_reader
(
f
,
delimiter
=
"
\t
"
)
headers
=
next
(
reader
)
Example
=
namedtuple
(
'Example'
,
headers
)
...
...
@@ -228,8 +238,8 @@ class ClassifyReader(BaseReader):
def
_read_tsv
(
self
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf8"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
with
io
.
open
(
input_file
,
"r"
,
encoding
=
"utf8"
)
as
f
:
reader
=
csv
_reader
(
f
,
delimiter
=
"
\t
"
)
headers
=
next
(
reader
)
text_indices
=
[
index
for
index
,
h
in
enumerate
(
headers
)
if
h
!=
"label"
...
...
PaddleNLP/preprocess/ernie/tokenization.py
浏览文件 @
f97790a0
...
...
@@ -21,7 +21,7 @@ from __future__ import print_function
import
collections
import
unicodedata
import
six
import
io
def
convert_to_unicode
(
text
):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
...
...
@@ -69,7 +69,7 @@ def printable_text(text):
def
load_vocab
(
vocab_file
):
"""Loads a vocabulary file into a dictionary."""
vocab
=
collections
.
OrderedDict
()
fin
=
open
(
vocab_file
,
encoding
=
"utf8"
)
fin
=
io
.
open
(
vocab_file
,
encoding
=
"utf8"
)
for
num
,
line
in
enumerate
(
fin
):
items
=
convert_to_unicode
(
line
.
strip
()).
split
(
"
\t
"
)
if
len
(
items
)
>
2
:
...
...
PaddleNLP/sentiment_classification/README.md
浏览文件 @
f97790a0
#
# 简介
#
情感倾向分析
情感倾向分析(Sentiment Classification,简称Senta)针对带有主观描述的中文文本,可自动判断该文本的情感极性类别并给出相应的置信度。情感类型分为积极、消极。情感倾向分析能够帮助企业理解用户消费习惯、分析热点话题和危机舆情监控,为企业提供有利的决策支持。可通过
[
AI开放平台-情感倾向分析
](
http://ai.baidu.com/tech/nlp_apply/sentiment_classify
)
线上体验。
*
[
模型简介
](
#模型简介
)
*
[
快速开始
](
#快速开始
)
*
[
进阶使用
](
#进阶使用
)
*
[
版本更新
](
#版本更新
)
*
[
作者
](
#作者
)
*
[
如何贡献代码
](
#如何贡献代码
)
## 模型简介
情感倾向分析(Sentiment Classification,简称Senta)针对带有主观描述的中文文本,可自动判断该文本的情感极性类别并给出相应的置信度。情感类型分为积极、消极。情感倾向分析能够帮助企业理解用户消费习惯、分析热点话题和危机舆情监控,为企业提供有利的决策支持。可通过
[
AI开放平台-情感倾向分析
](
http://ai.baidu.com/tech/nlp_apply/sentiment_classify
)
线上体验。
情感是人类的一种高级智能行为,为了识别文本的情感倾向,需要深入的语义建模。另外,不同领域(如餐饮、体育)在情感的表达各不相同,因而需要有大规模覆盖各个领域的数据进行模型训练。为此,我们通过基于深度学习的语义模型和大规模数据挖掘解决上述两个问题。效果上,我们基于开源情感倾向分类数据集ChnSentiCorp进行评测;此外,我们还开源了百度基于海量数据训练好的模型,该模型在ChnSentiCorp数据集上fine-tune之后(基于开源模型进行Finetune的方法请见下面章节),可以得到更好的效果。具体数据如下所示:
...
...
@@ -14,49 +23,113 @@
| ERNIE | 95.1% | 95.4% | ERNIE |95.4% | 95.5% |
| ERNIE+BI-LSTM | 95.3% | 95.2% | ERNIE+BI-LSTM |95.7% | 95.6% |
## 快速开始
### 安装说明
## 快速开始
1.
PaddlePaddle 安装
本项目依赖于 PaddlePaddle Fluid 1.3.2 及以上版本,请参考
[
安装指南
](
http://www.paddlepaddle.org/#quick-start
)
进行安装
2.
代码安装
克隆代码库到本地
```
shell
git clone https://github.com/PaddlePaddle/models.git
cd
models/PaddleNLP/sentiment_classification
```
3.
环境依赖
请参考 PaddlePaddle
[
安装说明
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html
)
部分的内容
### 代码结构说明
以下是本项目主要代码结构及说明:
```
text
.
├── senta_config.json # 配置文件
├── config.py # 配置文件读取接口
├── inference_model.py # 保存 inference_model 的脚本
├── inference_ernie_model.py # 保存 inference_ernie__model 的脚本
├── reader.py # 数据读取接口
├── run_classifier.py # 项目的主程序入口,包括训练、预测、评估
├── run.sh # 训练、预测、评估运行脚本
├── run_ernie_classifier.py # 基于ERNIE表示的项目的主程序入口
├── run_ernie.sh # 基于ERNIE的训练、预测、评估运行脚本
├── utils.py # 其它功能函数脚本
```
### 数据准备
本项目依赖于 Paddlepaddle 1.3.2 及以上版本,请参考
[
安装指南
](
http://www.paddlepaddle.org/#quick-start
)
进行安装
#### **自定义数据**
python版本依赖python 2.7
训练、预测、评估使用的数据可以由用户根据实际的应用场景,自己组织数据。数据由两列组成,以制表符分隔,第一列是以空格分词的中文文本(分词预处理方法将在下文具体说明),文件为utf8编码;第二列是情感倾向分类的类别(0表示消极;1表示积极),注意数据文件第一行固定表示为"text_a
\t
label"
注意:该模型同时支持cpu和gpu训练和预测,用户可以根据自身需求,选择安装对应的paddlepaddle-gpu或paddlepaddle版本。
```
text
特 喜欢 这种 好看的 狗狗 1
这 真是 惊艳 世界 的 中国 黑科技 1
环境 特别 差 ,脏兮兮 的,再也 不去 了 0
```
#### 安装代码
注:PaddleNLP 项目提供了分词预处理脚本(在preprocess目录下),可供用户使用,具体使用方法如下:
克隆数据集代码库到本地
```
shell
git clone https://github.com/PaddlePaddle/models.git
cd
models/PaddleNLP/sentiment_classification
python tokenizer.py
--test_data_dir
./test.txt.utf8
--batch_size
1
>
test.txt.utf8.seg
#其中test.txt.utf8为待分词的文件,一条文本数据一行,utf8编码,分词结果存放在test.txt.utf8.seg文件中。
```
####
数据准备
####
公开数据集
下载经过预处理的数据,文件解压之后,senta_data目录下会存在训练数据(train.tsv)、开发集数据(dev.tsv)、测试集数据(test.tsv)以及对应的词典(word_dict.txt)
```
shell
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-dataset-1.0.0.tar.gz
tar
-zxvf
sentiment_classification-dataset-1.0.0.tar.gz
```
#### 模型下载
```
text
.
├── train.tsv # 训练集
├── train.tsv # 验证集
├── test.tsv # 测试集
├── word_dict.txt # 词典
```
我们开源了基于ChnSentiCorp数据训练的情感倾向性分类模型(基于BOW、CNN、LSTM、ERNIE多种模型训练),可供用户直接使用。我们提供了两种下载方式:
### 单机训练
方式一:基于PaddleHub命令行工具(PaddleHub安装方式 https://github.com/PaddlePaddle/PaddleHub )
基于示例的数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证
```
shell
hub download sentiment_classification
--output_path
./
tar
-zxvf
sentiment_classification-1.0.0.tar.gz
# BOW、CNN、LSTM、BI-LSTM、GRU模型
sh run.sh train
# ERNIE、ERNIE+BI-LSTM模型
sh run_ernie.sh train
```
训练完成后,可修改
```run.sh```
中init_checkpoint参数,进行模型评估和预测
方式二:直接下载
```
shell
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-1.0.0.tar.gz
tar
-zxvf
sentiment_classification-1.0.0.tar.gz
```
#### 模型评估
"""
# 输出结果示例
Running type options:
--do_train DO_TRAIN Whether to perform training. Default: False.
...
Model config options:
--model_type {bow_net,cnn_net,lstm_net,bilstm_net,gru_net,textcnn_net}
Model type to run the task. Default: textcnn_net.
--init_checkpoint INIT_CHECKPOINT
Init checkpoint to resume training from. Default: .
--save_checkpoint_dir SAVE_CHECKPOINT_DIR
Directory path to save checkpoints Default: .
...
"""
```
本项目参数控制优先级:命令行参数 > ```config.json ``` > 默认值。训练完成后,会在```./save_models``` 目录下生成以 ```step_xxx ``` 命名的模型目录。
### 模型评估
基于上面的预训练模型和数据,可以运行下面的命令进行测试,查看预训练模型在开发集(dev.tsv)上的评测效果
```
shell
...
...
@@ -83,46 +156,76 @@ senta_config.json中需要修改如下:
--init_checkpoint senta_model/ernie_bilstm_model/
--model_type "ernie_bilstm"
```
```
"""
# 输出结果示例
Load model from ./save_models/step_100
Final test result:
[test evaluation] avg loss: 0.339021, avg acc: 0.869691, elapsed time: 0.123983 s
"""
```
我们也提供了使用PaddleHub加载ERNIE模型的选项,PaddleHub是PaddlePaddle的预训练模型管理工具,可以一行代码完成预训练模型的加载,简化预训练模型的使用和迁移学习。更多相关的介绍,可以查看
[
PaddleHub
](
https://github.com/PaddlePaddle/PaddleHub
)
### 模型推断
如果想使用该功能,需要修改run_ernie.sh中的配置如下:
利用已有模型,可以运行下面命令,对未知label的数据(test.tsv)进行预测
```
shell
# 在eval()函数中,修改如下参数:
--use_paddle_hub
true
# BOW、CNN、LSTM、BI-LSTM、GRU模型
sh run.sh infer
#ERNIE+BI-LSTM模型
sh run_ernie.sh infer
```
注意:使用该选项需要先安装PaddleHub,安装命令如下
```
"""
# 输出结果示例
Load model from ./save_models/step_100
1 0.001659 0.998341
0 0.987223 0.012777
1 0.001365 0.998635
1 0.001875 0.998125
"""
```
### 预训练模型
我们开源了基于海量数据训练好的情感倾向分类模型(基于CNN、BI-LSTM、ERNIE等模型训练),可供用户直接使用,我们提供两种下载方式。
**方式一**
:基于PaddleHub命令行工具(PaddleHub
[
安装方式
](
https://github.com/PaddlePaddle/PaddleHub
)
)
```
shell
$
pip
install
paddlehub
hub download sentiment_classification
--output_path
./
tar
-zxvf
sentiment_classification-1.0.0.tar.gz
```
#### 模型训练
**方式二**
:直接下载脚本
基于示例的数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证
```
shell
# BOW、CNN、LSTM、BI-LSTM、GRU模型
sh run.sh train
# ERNIE、ERNIE+BI-LSTM模型
sh run_ernie.sh train
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-1.0.0.tar.gz
tar
-zxvf
sentiment_classification-1.0.0.tar.gz
```
训练完成后,可修改
```run.sh```
中init_checkpoint参数,进行模型评估和预测
#### 模型预测
以上两种方式会将预训练的 CNN、BI-LSTM等模型和 ERNIE模型,保存在当前目录下,可直接修改
```run.sh```
脚本中的
```init_checkpoint```
参数进行评估、预测。
### 服务部署
为了将模型应用于线上部署,可以利用
```inference_model.py```
、
```inference_ernie_model.py```
脚本对模型进行裁剪,只保存网络参数及裁剪后的模型。运行命令如下:
利用已有模型,可以运行下面命令,对未知label的数据(test.tsv)进行预测
```
shell
# BOW、CNN、LSTM、BI-LSTM、GRU模型
sh run.sh infer
#ERNIE+BI-LSTM模型
sh run_ernie.sh infer
sh run.sh save_inference_model
sh run_ernie.sh save_inference_model
```
#### 服务器部署
请参考PaddlePaddle官方提供的
[
服务器端部署
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/advanced_usage/deploy/inference/index_cn.html
)
文档进行部署上线。
## 进阶使用
###
# 任务定义
###
背景介绍
传统的情感分类主要基于词典或者特征工程的方式进行分类,这种方法需要繁琐的人工特征设计和先验知识,理解停留于浅层并且扩展泛化能力差。为了避免传统方法的局限,我们采用近年来飞速发展的深度学习技术。基于深度学习的情感分类不依赖于人工特征,它能够端到端的对输入文本进行语义理解,并基于语义表示进行情感倾向的判断。
#### 模型原理介绍
### 模型概览
本项目针对情感倾向性分类问题,开源了一系列模型,供用户可配置地使用:
...
...
@@ -134,66 +237,126 @@ sh run_ernie.sh infer
+
ERNIE(Enhanced Representation through kNowledge IntEgration),百度自研基于海量数据和先验知识训练的通用文本语义表示模型,并基于此在情感倾向分类数据集上进行fine-tune获得。
+
ERNIE+BI-LSTM,基于ERNIE语义表示对接上层BI-LSTM模型,并基于此在情感倾向分类数据集上进行Fine-tune获得;
###
# 数据格式说明
###
自定义模型
训练、预测、评估使用的数据可以由用户根据实际的应用场景,自己组织数据。数据由两列组成,以制表符分隔,第一列是以空格分词的中文文本(分词预处理方法将在下文具体说明),文件为utf8编码;第二列是情感倾向分类的类别(0表示消极;1表示积极),注意数据文件第一行固定表示为"text_a
\t
label"
可以根据自己的需求,组建自定义的模型,具体方法如下所示:
```
text
特 喜欢 这种 好看的 狗狗 1
这 真是 惊艳 世界 的 中国 黑科技 1
环境 特别 差 ,脏兮兮 的,再也 不去 了 0
1.
定义自己的网络结构
用户可以在
```models/classification/nets.py```
中,定义自己的模型,只需要增加新的函数即可。假设用户自定义的函数名为
```user_net```
2.
更改模型配置
在
```senta_config.json```
中需要将
```model_type```
改为用户自定义的
```user_net```
3.
模型训练
通过
```run.sh```
脚本运行训练、评估、预测。
### 基于 ERNIE 进行 Finetune
ERNIE 是百度自研的基于海量数据和先验知识训练的通用文本语义表示模型,基于 ERNIE 进行 Finetune,能够提升对话情绪识别的效果。
#### 模型训练
需要先下载 ERNIE 模型,使用如下命令:
```
shell
mkdir
-p
pretrain_models/ernie
cd
pretrain_models/ernie
wget
--no-check-certificate
https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz
-O
ERNIE_stable-1.0.1.tar.gz
tar
-zxvf
ERNIE_stable-1.0.1.tar.gz
```
注:本项目额外提供了分词预处理脚本(在本项目的preprocess目录下),可供用户使用,具体使用方法如下:
然后修改
```run_ernie.sh```
脚本中train 函数的
```init_checkpoint```
参数,再执行命令:
```
shell
python tokenizer.py
--test_data_dir
./test.txt.utf8
--batch_size
1
>
test.txt.utf8.seg
#--init_checkpoint ./pretrain_models/ernie
sh run_ernie.sh train
```
#其中test.txt.utf8为待分词的文件,一条文本数据一行,utf8编码,分词结果存放在test.txt.utf8.seg文件中。
默认使用GPU进行训练,模型保存在
```./save_models/ernie/```
目录下,以
```step_xxx ```
命名。
#### 模型评估
根据训练结果,可选择最优的step进行评估,修改
```run_ernie.sh```
脚本中 eval 函数
```init_checkpoint```
参数,然后执行
```
shell
#--init_checkpoint./save/step_907
sh run_ernie.sh
eval
'''
# 输出结果示例
W0820 14:59:47.811139 334 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0820 14:59:47.815557 334 device_context.cc:267] device: 0, cuDNN Version: 7.3.
Load model from ./save_models/ernie/step_907
Final validation result:
[test evaluation] avg loss: 0.260597, ave acc: 0.907336, elapsed time: 2.383077 s
'''
```
####
代码结构说明
####
模型推断
```
text
.
├── senta_config.json # 模型配置文件
├── config.py # 定义了该项目模型的相关配置,包括具体模型类别、以及模型的超参数
├── reader.py # 定义了读入数据,加载词典的功能
├── run_classifier.py # 该项目的主函数,封装包括训练、预测、评估的部分
├── run_ernie_classifier.py # 基于ERNIE表示的项目的主函数
├── run_ernie.sh # 基于ERNIE的训练、预测、评估运行脚本
├── run.sh # 训练、预测、评估运行脚本
├── utils.py # 定义了其他常用的功能函数
修改
```run_ernie.sh```
脚本中 infer 函数
```init_checkpoint```
参数,然后执行
```
shell
#--init_checkpoint./save/step_907
sh run_ernie.sh infer
'''
# 输出结果示例
Load model from ./save_models/ernie/step_907
Final test result:
1 0.001130 0.998870
0 0.978465 0.021535
1 0.000847 0.999153
1 0.001498 0.998502
'''
```
###
# 如何组建自己的模型
###
基于 PaddleHub 加载 ERNIE 进行 Finetune
可以根据自己的需求,组建自定义的模型,具体方法如下所示:
我们也提供了使用 PaddleHub 加载 ERNIE 模型的选项,PaddleHub 是 PaddlePaddle 的预训练模型管理工具,可以一行代码完成预训练模型的加载,简化预训练模型的使用和迁移学习。更多相关的介绍,可以查看
[
PaddleHub
](
https://github.com/PaddlePaddle/PaddleHub
)
1.
定义自己的网络结构
用户可以在
```models/classification/nets.py```
中,定义自己的模型,只需要增加新的函数即可。假设用户自定义的函数名为
```user_net```
2.
更改模型配置
在
```senta_config.json```
中需要将
```model_type```
改为用户自定义的
```user_net```
3.
模型训练、评估、预测需要在 run.sh 、run_ernie.sh 中将模型、数据、词典路径等配置进行修改
注意:使用该选项需要先安装PaddleHub,安装命令如下
```
shell
pip
install
paddlehub
```
需要修改
```run_ernie.sh```
中的配置如下:
#### 如何基于百度开源模型进行Finetune
用户可基于百度开源模型在自有数据上实现Finetune训练,以期获得更好的效果提升;如『简介』部分中,我们基于百度开源模型在ChnSentiCorp数据集上Finetune后可以得到更好的效果,具体模型Finetune方法如下所示,如果用户基于开源BI-LSTM模型进行Finetune,需要修改run.sh和senta_config.json文件;
run.sh脚本修改如下:
```
shell
# 在train()函数中,增加--init_checkpoint选项;修改--vocab_path
--init_checkpoint
senta_model/bilstm_model/params
--vocab_path
senta_model/bilstm_model/word_dict.txt
# 在train()函数中,修改--use_paddle_hub选项
--use_paddle_hub
true
```
senta_config.json中需要修改如下:
执行以下命令进行 Finetune
```
shell
# vob_size大小对应为上面senta_model/bilstm_model//word_dict.txt,词典大小
"vocab_size"
: 1256606
sh run_ernie.sh train
```
如果用户基于开源ERNIE+BI-LSTM模型进行Finetune,需要更新run_ernie.sh脚本,具体修改如下:
Finetune 结束后,进行 eval 或者 infer 时,需要修改
```run_ernie.sh```
中的配置如下:
```
shell
# 在train()函数中,修改--init_checkpoint选项;修改--model_type
--init_checkpoint
senta_model/ernie_bilstm_model
--model_type
"ernie_bilstm"
# 在eval()和infer()函数中,修改--use_paddle_hub选项
--use_paddle_hub
true
```
执行以下命令进行 eval 和 infer
```
shell
sh run_ernie.sh
eval
sh run_ernie.sh infer
```
## 版本更新
2019/08/26 规范化配置的使用,对模块内数据处理代码进行了重构,更新README结构,提高易用性。
2019/06/13 添加PaddleHub调用ERNIE方式。
## 作者
-
[
liuhao
](
https://github.com/ChinaLiuHao
)
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
PaddleNLP/sentiment_classification/config.py
浏览文件 @
f97790a0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Senta
model
.
Senta
config
.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
six
import
json
import
numpy
as
np
import
paddle.fluid
as
fluid
import
argparse
def
str2bool
(
value
):
"""
String to Boolean
"""
# because argparse does not support to parse "True, False" as python
# boolean directly
return
value
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
class
ArgumentGroup
(
object
):
"""
Argument Class
"""
def
__init__
(
self
,
parser
,
title
,
des
):
self
.
_group
=
parser
.
add_argument_group
(
title
=
title
,
description
=
des
)
def
add_arg
(
self
,
name
,
dtype
,
default
,
help
,
**
kwargs
):
"""
Add argument
"""
dtype
=
str2bool
if
dtype
==
bool
else
dtype
self
.
_group
.
add_argument
(
"--"
+
name
,
default
=
default
,
type
=
dtype
,
help
=
help
+
' Default: %(default)s.'
,
**
kwargs
)
class
Senta
Config
(
object
):
class
PD
Config
(
object
):
"""
Senta Config
A high-level api for handling argument configs.
"""
def
__init__
(
self
,
json_file
=
""
):
"""
Init function for PDConfig.
json_file: the path to the json configure file.
"""
assert
isinstance
(
json_file
,
str
)
self
.
args
=
None
self
.
arg_config
=
{}
parser
=
argparse
.
ArgumentParser
()
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"model configuration and paths."
)
model_g
.
add_arg
(
"ernie_config_path"
,
str
,
None
,
"Path to the json file for ernie model config."
)
model_g
.
add_arg
(
"senta_config_path"
,
str
,
None
,
"Path to the json file for senta model config."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"checkpoints"
,
str
,
"checkpoints"
,
"Path to save checkpoints"
)
model_g
.
add_arg
(
"model_type"
,
str
,
"ernie_base"
,
"Type of current ernie model"
)
model_g
.
add_arg
(
"use_paddle_hub"
,
bool
,
False
,
"Whether to load ERNIE using PaddleHub"
)
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
10
,
"Number of epoches for training."
)
train_g
.
add_arg
(
"save_steps"
,
int
,
10000
,
"The steps interval to save checkpoints."
)
train_g
.
add_arg
(
"validation_steps"
,
int
,
1000
,
"The steps interval to evaluate model performance."
)
train_g
.
add_arg
(
"lr"
,
float
,
0.002
,
"The Learning rate value for training."
)
log_g
=
ArgumentGroup
(
parser
,
"logging"
,
"logging related"
)
log_g
.
add_arg
(
"skip_steps"
,
int
,
10
,
"The steps interval to print loss."
)
log_g
.
add_arg
(
"verbose"
,
bool
,
False
,
"Whether to output verbose log"
)
log_g
.
add_arg
(
'enable_ce'
,
bool
,
False
,
'If set, run the task with continuous evaluation logs.'
)
data_g
=
ArgumentGroup
(
parser
,
"data"
,
"Data paths, vocab paths and data processing options"
)
data_g
.
add_arg
(
"data_dir"
,
str
,
None
,
"Path to training data."
)
data_g
.
add_arg
(
"vocab_path"
,
str
,
None
,
"Vocabulary path."
)
data_g
.
add_arg
(
"batch_size"
,
int
,
256
,
"Total examples' number in batch for training."
)
data_g
.
add_arg
(
"random_seed"
,
int
,
0
,
"Random seed."
)
data_g
.
add_arg
(
"num_labels"
,
int
,
2
,
"label number"
)
data_g
.
add_arg
(
"max_seq_len"
,
int
,
512
,
"Number of words of the longest sequence."
)
data_g
.
add_arg
(
"train_set"
,
str
,
None
,
"Path to training data."
)
data_g
.
add_arg
(
"test_set"
,
str
,
None
,
"Path to test data."
)
data_g
.
add_arg
(
"dev_set"
,
str
,
None
,
"Path to validation data."
)
data_g
.
add_arg
(
"label_map_config"
,
str
,
None
,
"label_map_path."
)
data_g
.
add_arg
(
"do_lower_case"
,
bool
,
True
,
"Whether to lower case the input text. Should be True for uncased models and False for cased models"
)
def
__init__
(
self
,
config_path
):
self
.
_config_dict
=
self
.
_parse
(
config_path
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"task_name"
,
str
,
None
,
"The name of task to perform sentiment classification."
)
run_type_g
.
add_arg
(
"do_train"
,
bool
,
True
,
"Whether to perform training."
)
run_type_g
.
add_arg
(
"do_val"
,
bool
,
True
,
"Whether to perform evaluation."
)
run_type_g
.
add_arg
(
"do_infer"
,
bool
,
True
,
"Whether to perform inference."
)
run_type_g
.
add_arg
(
"do_save_inference_model"
,
bool
,
True
,
"Whether to save inference model"
)
run_type_g
.
add_arg
(
"inference_model_dir"
,
str
,
None
,
"Path to save inference model"
)
def
_parse
(
self
,
config_path
):
custom_g
=
ArgumentGroup
(
parser
,
"Customize options"
,
""
)
self
.
custom_g
=
custom_g
self
.
parser
=
parser
self
.
arglist
=
[
a
.
dest
for
a
in
self
.
parser
.
_actions
]
self
.
json_config
=
None
if
json_file
!=
""
:
self
.
load_json
(
json_file
)
def
load_json
(
self
,
file_path
):
"""load json config """
if
not
os
.
path
.
exists
(
file_path
):
raise
Warning
(
"the json file %s does not exist."
%
file_path
)
return
try
:
with
open
(
config_path
)
as
json_file
:
config_dict
=
json
.
load
(
json_file
)
except
Exception
:
raise
IOError
(
"Error in parsing bert model config file '%s'"
%
config_path
)
with
open
(
file_path
,
"r"
)
as
fin
:
self
.
json_config
=
json
.
load
(
fin
)
except
Exception
as
e
:
raise
IOError
(
"Error in parsing json config file '%s'"
%
file_path
)
for
name
in
self
.
json_config
:
# use `six.string_types` but not `str` for compatible with python2 and python3
if
not
isinstance
(
self
.
json_config
[
name
],
(
int
,
float
,
bool
,
six
.
string_types
)):
continue
if
name
in
self
.
arglist
:
self
.
set_default
(
name
,
self
.
json_config
[
name
])
else
:
return
config_dict
self
.
custom_g
.
add_arg
(
name
,
type
(
self
.
json_config
[
name
]),
self
.
json_config
[
name
],
"customized options"
)
def
__getitem__
(
self
,
key
):
return
self
.
_config_dict
[
key
]
def
set_default
(
self
,
name
,
value
):
for
arg
in
self
.
parser
.
_actions
:
if
arg
.
dest
==
name
:
arg
.
default
=
value
def
print_config
(
self
):
"""
Print Config
"""
for
arg
,
value
in
sorted
(
six
.
iteritems
(
self
.
_config_dict
)):
def
build
(
self
):
self
.
args
=
self
.
parser
.
parse_args
()
self
.
arg_config
=
vars
(
self
.
args
)
def
print_arguments
(
self
):
print
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
self
.
arg_config
)):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
def
add_arg
(
self
,
name
,
dtype
,
default
,
descrip
):
self
.
custom_g
.
add_arg
(
name
,
dtype
,
default
,
descrip
)
def
__add__
(
self
,
new_arg
):
assert
isinstance
(
new_arg
,
list
)
or
isinstance
(
new_arg
,
tuple
)
assert
len
(
new_arg
)
>=
3
assert
self
.
args
is
None
name
=
new_arg
[
0
]
dtype
=
new_arg
[
1
]
dvalue
=
new_arg
[
2
]
desc
=
new_arg
[
3
]
if
len
(
new_arg
)
==
4
else
"Description is not provided."
self
.
add_arg
(
name
,
dtype
,
dvalue
,
desc
)
return
self
def
__getattr__
(
self
,
name
):
if
name
in
self
.
arg_config
:
return
self
.
arg_config
[
name
]
if
name
in
self
.
json_config
:
return
self
.
json_config
[
name
]
raise
Warning
(
"The argument %s is not defined."
%
name
)
if
__name__
==
'__main__'
:
pd_config
=
PDConfig
(
'senta_config.json'
)
pd_config
.
add_arg
(
"my_age"
,
int
,
18
,
"I am forever 18."
)
pd_config
.
build
()
pd_config
.
print_arguments
()
print
(
pd_config
.
use_cuda
)
print
(
pd_config
.
model_type
)
PaddleNLP/sentiment_classification/inference_model.py
0 → 100644
浏览文件 @
f97790a0
# -*- coding: utf_8 -*-
import
os
import
sys
sys
.
path
.
append
(
"../"
)
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
from
models.model_check
import
check_cuda
from
config
import
PDConfig
from
run_classifier
import
create_model
import
utils
import
reader
def
do_save_inference_model
(
args
):
if
args
.
use_cuda
:
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
place
=
fluid
.
CUDAPlace
(
0
)
else
:
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
test_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
infer_pyreader
,
probs
,
feed_target_names
=
create_model
(
args
,
pyreader_name
=
'infer_reader'
,
num_labels
=
args
.
num_labels
,
is_prediction
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
exe
.
run
(
startup_prog
)
assert
(
args
.
init_checkpoint
)
if
args
.
init_checkpoint
:
utils
.
init_checkpoint
(
exe
,
args
.
init_checkpoint
,
test_prog
)
fluid
.
io
.
save_inference_model
(
args
.
inference_model_dir
,
feeded_var_names
=
feed_target_names
,
target_vars
=
[
probs
],
executor
=
exe
,
main_program
=
test_prog
,
model_filename
=
"model.pdmodel"
,
params_filename
=
"params.pdparams"
)
print
(
"save inference model at %s"
%
(
args
.
inference_model_dir
))
def
inference
(
exe
,
test_program
,
test_pyreader
,
fetch_list
,
infer_phrase
):
"""
Inference Function
"""
print
(
"================="
)
test_pyreader
.
start
()
while
True
:
try
:
np_props
=
exe
.
run
(
program
=
test_program
,
fetch_list
=
fetch_list
,
return_numpy
=
True
)
for
probs
in
np_props
[
0
]:
print
(
"%d
\t
%f
\t
%f"
%
(
np
.
argmax
(
probs
),
probs
[
0
],
probs
[
1
]))
except
fluid
.
core
.
EOFException
:
test_pyreader
.
reset
()
break
def
test_inference_model
(
args
):
if
args
.
use_cuda
:
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
place
=
fluid
.
CUDAPlace
(
0
)
else
:
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
test_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
infer_pyreader
,
probs
,
feed_target_names
=
create_model
(
args
,
pyreader_name
=
'infer_reader'
,
num_labels
=
args
.
num_labels
,
is_prediction
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
processor
=
reader
.
SentaProcessor
(
data_dir
=
args
.
data_dir
,
vocab_path
=
args
.
vocab_path
,
random_seed
=
args
.
random_seed
,
max_seq_len
=
args
.
max_seq_len
)
num_labels
=
len
(
processor
.
get_labels
())
assert
(
args
.
inference_model_dir
)
infer_program
,
feed_names
,
fetch_targets
=
fluid
.
io
.
load_inference_model
(
dirname
=
args
.
inference_model_dir
,
executor
=
exe
,
model_filename
=
"model.pdmodel"
,
params_filename
=
"params.pdparams"
)
infer_data_generator
=
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
"infer"
,
epoch
=
1
,
shuffle
=
False
)
infer_pyreader
.
decorate_sample_list_generator
(
infer_data_generator
)
inference
(
exe
,
test_prog
,
infer_pyreader
,
[
probs
.
name
],
"infer"
)
if
__name__
==
"__main__"
:
args
=
PDConfig
(
'senta_config.json'
)
args
.
build
()
args
.
print_arguments
()
check_cuda
(
args
.
use_cuda
)
if
args
.
do_save_inference_model
:
do_save_inference_model
(
args
)
else
:
test_inference_model
(
args
)
PaddleNLP/sentiment_classification/inference_model_ernie.py
0 → 100644
浏览文件 @
f97790a0
# -*- coding: utf_8 -*-
import
os
import
sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../models/classification"
)
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
from
models.model_check
import
check_cuda
from
config
import
PDConfig
from
run_ernie_classifier
import
create_model
import
utils
import
reader
from
run_ernie_classifier
import
ernie_pyreader
from
models.representation.ernie
import
ErnieConfig
from
models.representation.ernie
import
ernie_encoder
from
preprocess.ernie
import
task_reader
def
do_save_inference_model
(
args
):
ernie_config
=
ErnieConfig
(
args
.
ernie_config_path
)
ernie_config
.
print_config
()
if
args
.
use_cuda
:
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
place
=
fluid
.
CUDAPlace
(
0
)
else
:
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
test_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
infer_pyreader
,
ernie_inputs
,
labels
=
ernie_pyreader
(
args
,
pyreader_name
=
"infer_reader"
)
embeddings
=
ernie_encoder
(
ernie_inputs
,
ernie_config
=
ernie_config
)
probs
=
create_model
(
args
,
embeddings
,
labels
=
labels
,
is_prediction
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
exe
.
run
(
startup_prog
)
assert
(
args
.
init_checkpoint
)
if
args
.
init_checkpoint
:
utils
.
init_checkpoint
(
exe
,
args
.
init_checkpoint
,
test_prog
)
fluid
.
io
.
save_inference_model
(
args
.
inference_model_dir
,
feeded_var_names
=
[
ernie_inputs
[
"src_ids"
].
name
,
ernie_inputs
[
"sent_ids"
].
name
,
ernie_inputs
[
"pos_ids"
].
name
,
ernie_inputs
[
"input_mask"
].
name
,
ernie_inputs
[
"seq_lens"
].
name
],
target_vars
=
[
probs
],
executor
=
exe
,
main_program
=
test_prog
,
model_filename
=
"model.pdmodel"
,
params_filename
=
"params.pdparams"
)
print
(
"save inference model at %s"
%
(
args
.
inference_model_dir
))
def
inference
(
exe
,
test_program
,
test_pyreader
,
fetch_list
,
infer_phrase
):
"""
Inference Function
"""
print
(
"================="
)
test_pyreader
.
start
()
while
True
:
try
:
np_props
=
exe
.
run
(
program
=
test_program
,
fetch_list
=
fetch_list
,
return_numpy
=
True
)
for
probs
in
np_props
[
0
]:
print
(
"%d
\t
%f
\t
%f"
%
(
np
.
argmax
(
probs
),
probs
[
0
],
probs
[
1
]))
except
fluid
.
core
.
EOFException
:
test_pyreader
.
reset
()
break
def
test_inference_model
(
args
):
ernie_config
=
ErnieConfig
(
args
.
ernie_config_path
)
ernie_config
.
print_config
()
if
args
.
use_cuda
:
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
place
=
fluid
.
CUDAPlace
(
0
)
else
:
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
reader
=
task_reader
.
ClassifyReader
(
vocab_path
=
args
.
vocab_path
,
label_map_config
=
args
.
label_map_config
,
max_seq_len
=
args
.
max_seq_len
,
do_lower_case
=
args
.
do_lower_case
,
random_seed
=
args
.
random_seed
)
test_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
infer_pyreader
,
ernie_inputs
,
labels
=
ernie_pyreader
(
args
,
pyreader_name
=
"infer_pyreader"
)
embeddings
=
ernie_encoder
(
ernie_inputs
,
ernie_config
=
ernie_config
)
probs
=
create_model
(
args
,
embeddings
,
labels
=
labels
,
is_prediction
=
True
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
exe
.
run
(
startup_prog
)
assert
(
args
.
inference_model_dir
)
infer_data_generator
=
reader
.
data_generator
(
input_file
=
args
.
test_set
,
batch_size
=
args
.
batch_size
,
phase
=
"infer"
,
epoch
=
1
,
shuffle
=
False
)
infer_program
,
feed_names
,
fetch_targets
=
fluid
.
io
.
load_inference_model
(
dirname
=
args
.
inference_model_dir
,
executor
=
exe
,
model_filename
=
"model.pdmodel"
,
params_filename
=
"params.pdparams"
)
infer_pyreader
.
decorate_batch_generator
(
infer_data_generator
)
inference
(
exe
,
test_prog
,
infer_pyreader
,
[
probs
.
name
],
"infer"
)
if
__name__
==
"__main__"
:
args
=
PDConfig
()
args
.
build
()
args
.
print_arguments
()
check_cuda
(
args
.
use_cuda
)
if
args
.
do_save_inference_model
:
do_save_inference_model
(
args
)
else
:
test_inference_model
(
args
)
PaddleNLP/sentiment_classification/reader.py
浏览文件 @
f97790a0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Senta Reader
"""
...
...
@@ -25,38 +12,39 @@ from utils import data_reader
import
paddle
import
paddle.fluid
as
fluid
class
SentaProcessor
(
object
):
"""
Processor class for data convertors for senta
"""
def
__init__
(
self
,
data_dir
,
vocab_path
,
random_seed
=
None
):
def
__init__
(
self
,
data_dir
,
vocab_path
,
random_seed
,
max_seq_len
):
self
.
data_dir
=
data_dir
self
.
vocab
=
load_vocab
(
vocab_path
)
self
.
num_examples
=
{
"train"
:
-
1
,
"dev"
:
-
1
,
"infer"
:
-
1
}
np
.
random
.
seed
(
random_seed
)
self
.
max_seq_len
=
max_seq_len
def
get_train_examples
(
self
,
data_dir
,
epoch
):
def
get_train_examples
(
self
,
data_dir
,
epoch
,
max_seq_len
):
"""
Load training examples
"""
return
data_reader
((
self
.
data_dir
+
"/train.tsv"
),
self
.
vocab
,
self
.
num_examples
,
"train"
,
epoch
)
return
data_reader
((
self
.
data_dir
+
"/train.tsv"
),
self
.
vocab
,
self
.
num_examples
,
"train"
,
epoch
,
max_seq_len
)
def
get_dev_examples
(
self
,
data_dir
,
epoch
):
def
get_dev_examples
(
self
,
data_dir
,
epoch
,
max_seq_len
):
"""
Load dev examples
"""
return
data_reader
((
self
.
data_dir
+
"/dev.tsv"
),
self
.
vocab
,
self
.
num_examples
,
"dev"
,
epoch
)
return
data_reader
((
self
.
data_dir
+
"/dev.tsv"
),
self
.
vocab
,
self
.
num_examples
,
"dev"
,
epoch
,
max_seq_len
)
def
get_test_examples
(
self
,
data_dir
,
epoch
):
def
get_test_examples
(
self
,
data_dir
,
epoch
,
max_seq_len
):
"""
Load test examples
"""
return
data_reader
((
self
.
data_dir
+
"/test.tsv"
),
self
.
vocab
,
self
.
num_examples
,
"infer"
,
epoch
)
return
data_reader
((
self
.
data_dir
+
"/test.tsv"
),
self
.
vocab
,
self
.
num_examples
,
"infer"
,
epoch
,
max_seq_len
)
def
get_labels
(
self
):
"""
...
...
@@ -84,14 +72,12 @@ class SentaProcessor(object):
Generate data for train, dev or infer
"""
if
phase
==
"train"
:
return
paddle
.
batch
(
self
.
get_train_examples
(
self
.
data_dir
,
epoch
),
batch_size
)
return
paddle
.
batch
(
self
.
get_train_examples
(
self
.
data_dir
,
epoch
,
self
.
max_seq_len
),
batch_size
)
#return self.get_train_examples(self.data_dir, epoch, self.max_seq_len
)
elif
phase
==
"dev"
:
return
paddle
.
batch
(
self
.
get_dev_examples
(
self
.
data_dir
,
epoch
),
batch_size
)
return
paddle
.
batch
(
self
.
get_dev_examples
(
self
.
data_dir
,
epoch
,
self
.
max_seq_len
),
batch_size
)
elif
phase
==
"infer"
:
return
paddle
.
batch
(
self
.
get_test_examples
(
self
.
data_dir
,
epoch
),
batch_size
)
return
paddle
.
batch
(
self
.
get_test_examples
(
self
.
data_dir
,
epoch
,
self
.
max_seq_len
),
batch_size
)
else
:
raise
ValueError
(
"Unknown phase, which should be in ['train', 'dev', 'infer']."
)
PaddleNLP/sentiment_classification/run.sh
浏览文件 @
f97790a0
#! /bin/bash
export
FLAGS_enable_parallel_graph
=
1
export
FLAGS_sync_nccl_allreduce
=
1
export
CUDA_VISIBLE_DEVICES
=
1
export
CUDA_VISIBLE_DEVICES
=
1
2
export
FLAGS_fraction_of_gpu_memory_to_use
=
0.95
export
CPU_NUM
=
1
...
...
@@ -16,9 +16,9 @@ train() {
--task_name
${
TASK_NAME
}
\
--use_cuda
true
\
--do_train
true
\
--do_val
tru
e
\
--do_val
fals
e
\
--do_infer
false
\
--batch_size
16
\
--batch_size
8
\
--data_dir
${
DATA_PATH
}
\
--vocab_path
${
DATA_PATH
}
/word_dict.txt
\
--checkpoints
${
CKPT_PATH
}
\
...
...
@@ -59,6 +59,15 @@ infer() {
--senta_config_path
./senta_config.json
}
# run_save_inference_model
save_inference_model
()
{
python
-u
inference_model.py
\
--use_cuda
false
\
--do_save_inference_model
true
\
--init_checkpoint
${
MODEL_PATH
}
\
--inference_model_dir
./inference_model
}
main
()
{
local
cmd
=
${
1
:-
help
}
case
"
${
cmd
}
"
in
...
...
@@ -71,13 +80,16 @@ main() {
infer
)
infer
"
$@
"
;
;;
save_inference_model
)
save_inference_model
"
$@
"
;
;;
help
)
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer}"
;
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer
|save_inference_model
}"
;
return
0
;
;;
*
)
echo
"Unsupport commend [
${
cmd
}
]"
;
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer}"
;
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer
|save_inference_model
}"
;
return
1
;
;;
esac
...
...
PaddleNLP/sentiment_classification/run_classifier.py
浏览文件 @
f97790a0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sentiment Classification Task
"""
...
...
@@ -34,93 +21,56 @@ from nets import cnn_net
from
nets
import
bilstm_net
from
nets
import
gru_net
from
models.model_check
import
check_cuda
from
config
import
PDConfig
import
paddle
import
paddle.fluid
as
fluid
import
reader
from
config
import
SentaConfig
from
utils
import
ArgumentGroup
,
print_arguments
from
utils
import
init_checkpoint
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"model configuration and paths."
)
model_g
.
add_arg
(
"senta_config_path"
,
str
,
None
,
"Path to the json file for senta model config."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"checkpoints"
,
str
,
"checkpoints"
,
"Path to save checkpoints"
)
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
10
,
"Number of epoches for training."
)
train_g
.
add_arg
(
"save_steps"
,
int
,
10000
,
"The steps interval to save checkpoints."
)
train_g
.
add_arg
(
"validation_steps"
,
int
,
1000
,
"The steps interval to evaluate model performance."
)
train_g
.
add_arg
(
"lr"
,
float
,
0.002
,
"The Learning rate value for training."
)
log_g
=
ArgumentGroup
(
parser
,
"logging"
,
"logging related"
)
log_g
.
add_arg
(
"skip_steps"
,
int
,
10
,
"The steps interval to print loss."
)
log_g
.
add_arg
(
"verbose"
,
bool
,
False
,
"Whether to output verbose log"
)
data_g
=
ArgumentGroup
(
parser
,
"data"
,
"Data paths, vocab paths and data processing options"
)
data_g
.
add_arg
(
"data_dir"
,
str
,
None
,
"Path to training data."
)
data_g
.
add_arg
(
"vocab_path"
,
str
,
None
,
"Vocabulary path."
)
data_g
.
add_arg
(
"batch_size"
,
int
,
256
,
"Total examples' number in batch for training."
)
data_g
.
add_arg
(
"random_seed"
,
int
,
0
,
"Random seed."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"task_name"
,
str
,
None
,
"The name of task to perform sentiment classification."
)
run_type_g
.
add_arg
(
"do_train"
,
bool
,
True
,
"Whether to perform training."
)
run_type_g
.
add_arg
(
"do_val"
,
bool
,
True
,
"Whether to perform evaluation."
)
run_type_g
.
add_arg
(
"do_infer"
,
bool
,
True
,
"Whether to perform inference."
)
parser
.
add_argument
(
'--enable_ce'
,
action
=
'store_true'
,
help
=
'If set, run the task with continuous evaluation logs.'
)
args
=
parser
.
parse_args
()
# yapf: enable.
def
create_model
(
args
,
pyreader_name
,
senta_config
,
num_labels
,
is_
inference
=
False
):
is_
prediction
=
False
):
"""
Create Model for sentiment classification
"""
pyreader
=
fluid
.
layers
.
py_reader
(
capacity
=
16
,
shapes
=
([
-
1
,
1
],
[
-
1
,
1
]),
dtypes
=
(
'int64'
,
'int64'
),
lod_levels
=
(
1
,
0
),
name
=
pyreader_name
,
use_double_buffer
=
False
)
data
=
fluid
.
layers
.
data
(
name
=
"src_ids"
,
shape
=
[
-
1
,
args
.
max_seq_len
,
1
],
dtype
=
'int64'
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
)
seq_len
=
fluid
.
layers
.
data
(
name
=
"seq_len"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
)
data_reader
=
fluid
.
io
.
PyReader
(
feed_list
=
[
data
,
label
,
seq_len
],
capacity
=
4
,
iterable
=
False
)
if
senta_config
[
'model_type'
]
==
"bilstm_net"
:
if
args
.
model_type
==
"bilstm_net"
:
network
=
bilstm_net
elif
senta_config
[
'model_type'
]
==
"bow_net"
:
elif
args
.
model_type
==
"bow_net"
:
network
=
bow_net
elif
senta_config
[
'model_type'
]
==
"cnn_net"
:
elif
args
.
model_type
==
"cnn_net"
:
network
=
cnn_net
elif
senta_config
[
'model_type'
]
==
"lstm_net"
:
elif
args
.
model_type
==
"lstm_net"
:
network
=
lstm_net
elif
senta_config
[
'model_type'
]
==
"gru_net"
:
elif
args
.
model_type
==
"gru_net"
:
network
=
gru_net
else
:
raise
ValueError
(
"Unknown network type!"
)
if
is_inference
:
data
,
label
=
fluid
.
layers
.
read_file
(
pyreader
)
probs
=
network
(
data
,
None
,
senta_config
[
"vocab_size"
],
is_infer
=
is_inference
)
if
is_prediction
:
probs
=
network
(
data
,
seq_len
,
None
,
args
.
vocab_size
,
is_prediction
=
is_prediction
)
print
(
"create inference model..."
)
return
pyreader
,
probs
return
data_reader
,
probs
,
[
data
.
name
,
seq_len
.
name
]
data
,
label
=
fluid
.
layers
.
read_file
(
pyreader
)
ce_loss
,
probs
=
network
(
data
,
label
,
senta_config
[
"vocab_size"
],
is_infer
=
is_inference
)
ce_loss
,
probs
=
network
(
data
,
seq_len
,
label
,
args
.
vocab_size
,
is_prediction
=
is_prediction
)
loss
=
fluid
.
layers
.
mean
(
x
=
ce_loss
)
num_seqs
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
accuracy
=
fluid
.
layers
.
accuracy
(
input
=
probs
,
label
=
label
,
total
=
num_seqs
)
return
py
reader
,
loss
,
accuracy
,
num_seqs
return
data_
reader
,
loss
,
accuracy
,
num_seqs
...
...
@@ -132,6 +82,7 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
total_cost
,
total_acc
,
total_num_seqs
=
[],
[],
[]
time_begin
=
time
.
time
()
while
True
:
#print("===============")
try
:
np_loss
,
np_acc
,
np_num_seqs
=
exe
.
run
(
program
=
test_program
,
fetch_list
=
fetch_list
,
...
...
@@ -174,8 +125,6 @@ def main(args):
"""
Main Function
"""
senta_config
=
SentaConfig
(
args
.
senta_config_path
)
if
args
.
use_cuda
:
place
=
fluid
.
CUDAPlace
(
int
(
os
.
getenv
(
'FLAGS_selected_gpus'
,
'0'
)))
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
...
...
@@ -187,10 +136,10 @@ def main(args):
task_name
=
args
.
task_name
.
lower
()
processor
=
reader
.
SentaProcessor
(
data_dir
=
args
.
data_dir
,
vocab_path
=
args
.
vocab_path
,
random_seed
=
args
.
random_seed
)
random_seed
=
args
.
random_seed
,
max_seq_len
=
args
.
max_seq_len
)
num_labels
=
len
(
processor
.
get_labels
())
if
not
(
args
.
do_train
or
args
.
do_val
or
args
.
do_infer
):
raise
ValueError
(
"For args `do_train`, `do_val` and `do_infer`, at "
"least one of them must be True."
)
...
...
@@ -220,12 +169,11 @@ def main(args):
with
fluid
.
program_guard
(
train_program
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
train_
py
reader
,
loss
,
accuracy
,
num_seqs
=
create_model
(
train_reader
,
loss
,
accuracy
,
num_seqs
=
create_model
(
args
,
pyreader_name
=
'train_reader'
,
senta_config
=
senta_config
,
num_labels
=
num_labels
,
is_
inference
=
False
)
is_
prediction
=
False
)
sgd_optimizer
=
fluid
.
optimizer
.
Adagrad
(
learning_rate
=
args
.
lr
)
sgd_optimizer
.
minimize
(
loss
)
...
...
@@ -237,28 +185,36 @@ def main(args):
(
lower_mem
,
upper_mem
,
unit
))
if
args
.
do_val
:
test_data_generator
=
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
shuffle
=
False
)
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
test_
py
reader
,
loss
,
accuracy
,
num_seqs
=
create_model
(
test_reader
,
loss
,
accuracy
,
num_seqs
=
create_model
(
args
,
pyreader_name
=
'test_reader'
,
senta_config
=
senta_config
,
num_labels
=
num_labels
,
is_
inference
=
False
)
is_
prediction
=
False
)
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
if
args
.
do_infer
:
infer_data_generator
=
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'infer'
,
epoch
=
1
,
shuffle
=
False
)
infer_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
infer_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
infer_
pyreader
,
prop
=
create_model
(
infer_
reader
,
prop
,
_
=
create_model
(
args
,
pyreader_name
=
'infer_reader'
,
senta_config
=
senta_config
,
num_labels
=
num_labels
,
is_
inference
=
True
)
is_
prediction
=
True
)
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
exe
.
run
(
startup_prog
)
...
...
@@ -281,14 +237,18 @@ def main(args):
if
args
.
do_train
:
train_exe
=
exe
train_
pyreader
.
decorate_paddle_reade
r
(
train_data_generator
)
train_
reader
.
decorate_sample_list_generato
r
(
train_data_generator
)
else
:
train_exe
=
None
if
args
.
do_val
or
args
.
do_infer
:
if
args
.
do_val
:
test_exe
=
exe
test_reader
.
decorate_sample_list_generator
(
test_data_generator
)
if
args
.
do_infer
:
test_exe
=
exe
infer_reader
.
decorate_sample_list_generator
(
infer_data_generator
)
if
args
.
do_train
:
train_
py
reader
.
start
()
train_reader
.
start
()
steps
=
0
total_cost
,
total_acc
,
total_num_seqs
=
[],
[],
[]
time_begin
=
time
.
time
()
...
...
@@ -335,55 +295,32 @@ def main(args):
# evaluate dev set
if
args
.
do_val
:
print
(
"do evalatation"
)
test_pyreader
.
decorate_paddle_reader
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
evaluate
(
exe
,
test_prog
,
test_reader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
except
fluid
.
core
.
EOFException
:
save_path
=
os
.
path
.
join
(
args
.
checkpoints
,
"step_"
+
str
(
steps
))
fluid
.
io
.
save_persistables
(
exe
,
save_path
,
train_program
)
train_
py
reader
.
reset
()
train_reader
.
reset
()
break
# final eval on dev set
if
args
.
do_val
:
test_pyreader
.
decorate_paddle_reader
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
shuffle
=
False
))
print
(
"Final validation result:"
)
evaluate
(
exe
,
test_prog
,
test_
py
reader
,
evaluate
(
exe
,
test_prog
,
test_reader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
test_pyreader
.
decorate_paddle_reader
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'infer'
,
epoch
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"infer"
)
# final eval on test set
if
args
.
do_infer
:
infer_pyreader
.
decorate_paddle_reader
(
processor
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'infer'
,
epoch
=
1
,
shuffle
=
False
))
print
(
"Final test result:"
)
inference
(
exe
,
infer_prog
,
infer_
py
reader
,
inference
(
exe
,
infer_prog
,
infer_reader
,
[
prop
.
name
],
"infer"
)
if
__name__
==
"__main__"
:
print_arguments
(
args
)
args
=
PDConfig
(
'senta_config.json'
)
args
.
build
()
args
.
print_arguments
()
check_cuda
(
args
.
use_cuda
)
main
(
args
)
PaddleNLP/sentiment_classification/run_ernie.sh
浏览文件 @
f97790a0
...
...
@@ -2,9 +2,9 @@
export
FLAGS_fraction_of_gpu_memory_to_use
=
0.95
export
FLAGS_enable_parallel_graph
=
1
export
FLAGS_sync_nccl_allreduce
=
1
export
CUDA_VISIBLE_DEVICES
=
3
export
CUDA_VISIBLE_DEVICES
=
12
export
CPU_NUM
=
1
ERNIE_PRETRAIN
=
./
senta_model/
ernie_pretrain_model/
ERNIE_PRETRAIN
=
./ernie_pretrain_model/
DATA_PATH
=
./senta_data
MODEL_SAVE_PATH
=
./save_models
...
...
@@ -17,7 +17,7 @@ train() {
--do_val
true
\
--do_infer
false
\
--use_paddle_hub
false
\
--batch_size
2
4
\
--batch_size
4
\
--init_checkpoint
$ERNIE_PRETRAIN
/params
\
--train_set
$DATA_PATH
/train.tsv
\
--dev_set
$DATA_PATH
/dev.tsv
\
...
...
@@ -25,8 +25,8 @@ train() {
--vocab_path
$ERNIE_PRETRAIN
/vocab.txt
\
--checkpoints
$MODEL_SAVE_PATH
\
--save_steps
5000
\
--validation_steps
1
00
\
--epoch
10
\
--validation_steps
50
00
\
--epoch
2
\
--max_seq_len
256
\
--ernie_config_path
$ERNIE_PRETRAIN
/ernie_config.json
\
--model_type
"ernie_base"
\
...
...
@@ -45,8 +45,8 @@ evaluate() {
--do_val
true
\
--do_infer
false
\
--use_paddle_hub
false
\
--batch_size
2
4
\
--init_checkpoint
./save_models/step_
5000
/
\
--batch_size
4
\
--init_checkpoint
./save_models/step_
4801
/
\
--dev_set
$DATA_PATH
/dev.tsv
\
--vocab_path
$ERNIE_PRETRAIN
/vocab.txt
\
--max_seq_len
256
\
...
...
@@ -61,8 +61,8 @@ evaluate() {
--do_val
true
\
--do_infer
false
\
--use_paddle_hub
false
\
--batch_size
2
4
\
--init_checkpoint
./save_models/step_
5000
/
\
--batch_size
4
\
--init_checkpoint
./save_models/step_
4801
/
\
--dev_set
$DATA_PATH
/test.tsv
\
--vocab_path
$ERNIE_PRETRAIN
/vocab.txt
\
--max_seq_len
256
\
...
...
@@ -80,8 +80,8 @@ infer() {
--do_val
false
\
--do_infer
true
\
--use_paddle_hub
false
\
--batch_size
2
4
\
--init_checkpoint
./save_models/step_
5000
\
--batch_size
4
\
--init_checkpoint
./save_models/step_
4801
\
--test_set
$DATA_PATH
/test.tsv
\
--vocab_path
$ERNIE_PRETRAIN
/vocab.txt
\
--max_seq_len
256
\
...
...
@@ -90,6 +90,20 @@ infer() {
--num_labels
2
}
# run_save_inference_model
save_inference_model
()
{
python
-u
inference_model_ernie.py
\
--use_cuda
true
\
--do_save_inference_model
true
\
--init_checkpoint
./save_models/step_4801/
\
--inference_model_dir
./inference_model
\
--ernie_config_path
$ERNIE_PRETRAIN
/ernie_config.json
\
--model_type
"ernie_base"
\
--vocab_path
$ERNIE_PRETRAIN
/vocab.txt
\
--test_set
${
DATA_PATH
}
/test.tsv
\
--batch_size
4
}
main
()
{
local
cmd
=
${
1
:-
help
}
case
"
${
cmd
}
"
in
...
...
@@ -102,13 +116,16 @@ main() {
infer
)
infer
"
$@
"
;
;;
save_inference_model
)
save_inference_model
"
$@
"
;
;;
help
)
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer}"
;
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer
|save_inference_model
}"
;
return
0
;
;;
*
)
echo
"Unsupport commend [
${
cmd
}
]"
;
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer}"
;
echo
"Usage:
${
BASH_SOURCE
}
{train|eval|infer
|save_inference_model
}"
;
return
1
;
;;
esac
...
...
PaddleNLP/sentiment_classification/run_ernie_classifier.py
浏览文件 @
f97790a0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sentiment Classification Task
"""
...
...
@@ -43,54 +30,35 @@ from nets import ernie_bilstm_net
from
preprocess.ernie
import
task_reader
from
models.representation.ernie
import
ErnieConfig
from
models.representation.ernie
import
ernie_encoder
,
ernie_encoder_with_paddle_hub
from
models.representation.ernie
import
ernie_pyreader
#
from models.representation.ernie import ernie_pyreader
from
models.model_check
import
check_cuda
from
utils
import
ArgumentGroup
from
utils
import
print_arguments
from
config
import
PDConfig
from
utils
import
init_checkpoint
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"model configuration and paths."
)
model_g
.
add_arg
(
"ernie_config_path"
,
str
,
None
,
"Path to the json file for ernie model config."
)
model_g
.
add_arg
(
"senta_config_path"
,
str
,
None
,
"Path to the json file for senta model config."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"checkpoints"
,
str
,
"checkpoints"
,
"Path to save checkpoints"
)
model_g
.
add_arg
(
"model_type"
,
str
,
"ernie_base"
,
"Type of current ernie model"
)
model_g
.
add_arg
(
"use_paddle_hub"
,
bool
,
False
,
"Whether to load ERNIE using PaddleHub"
)
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
10
,
"Number of epoches for training."
)
train_g
.
add_arg
(
"save_steps"
,
int
,
10000
,
"The steps interval to save checkpoints."
)
train_g
.
add_arg
(
"validation_steps"
,
int
,
1000
,
"The steps interval to evaluate model performance."
)
train_g
.
add_arg
(
"lr"
,
float
,
0.002
,
"The Learning rate value for training."
)
log_g
=
ArgumentGroup
(
parser
,
"logging"
,
"logging related"
)
log_g
.
add_arg
(
"skip_steps"
,
int
,
10
,
"The steps interval to print loss."
)
log_g
.
add_arg
(
"verbose"
,
bool
,
False
,
"Whether to output verbose log"
)
data_g
=
ArgumentGroup
(
parser
,
"data"
,
"Data paths, vocab paths and data processing options"
)
data_g
.
add_arg
(
"data_dir"
,
str
,
None
,
"Path to training data."
)
data_g
.
add_arg
(
"vocab_path"
,
str
,
None
,
"Vocabulary path."
)
data_g
.
add_arg
(
"batch_size"
,
int
,
256
,
"Total examples' number in batch for training."
)
data_g
.
add_arg
(
"random_seed"
,
int
,
0
,
"Random seed."
)
data_g
.
add_arg
(
"num_labels"
,
int
,
2
,
"label number"
)
data_g
.
add_arg
(
"max_seq_len"
,
int
,
512
,
"Number of words of the longest seqence."
)
data_g
.
add_arg
(
"train_set"
,
str
,
None
,
"Path to training data."
)
data_g
.
add_arg
(
"test_set"
,
str
,
None
,
"Path to test data."
)
data_g
.
add_arg
(
"dev_set"
,
str
,
None
,
"Path to validation data."
)
data_g
.
add_arg
(
"label_map_config"
,
str
,
None
,
"label_map_path."
)
data_g
.
add_arg
(
"do_lower_case"
,
bool
,
True
,
"Whether to lower case the input text. Should be True for uncased models and False for cased models."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"do_train"
,
bool
,
True
,
"Whether to perform training."
)
run_type_g
.
add_arg
(
"do_val"
,
bool
,
True
,
"Whether to perform evaluation."
)
run_type_g
.
add_arg
(
"do_infer"
,
bool
,
True
,
"Whether to perform inference."
)
args
=
parser
.
parse_args
()
# yapf: enable.
def
ernie_pyreader
(
args
,
pyreader_name
):
src_ids
=
fluid
.
layers
.
data
(
name
=
"src_ids"
,
shape
=
[
-
1
,
args
.
max_seq_len
,
1
],
dtype
=
"int64"
)
sent_ids
=
fluid
.
layers
.
data
(
name
=
"sent_ids"
,
shape
=
[
-
1
,
args
.
max_seq_len
,
1
],
dtype
=
"int64"
)
pos_ids
=
fluid
.
layers
.
data
(
name
=
"pos_ids"
,
shape
=
[
-
1
,
args
.
max_seq_len
,
1
],
dtype
=
"int64"
)
input_mask
=
fluid
.
layers
.
data
(
name
=
"input_mask"
,
shape
=
[
-
1
,
args
.
max_seq_len
,
1
],
dtype
=
"float32"
)
labels
=
fluid
.
layers
.
data
(
name
=
"labels"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
)
seq_lens
=
fluid
.
layers
.
data
(
name
=
"seq_lens"
,
shape
=
[
-
1
,
1
],
dtype
=
"int64"
)
pyreader
=
fluid
.
io
.
PyReader
(
feed_list
=
[
src_ids
,
sent_ids
,
pos_ids
,
input_mask
,
labels
,
seq_lens
],
capacity
=
4
,
iterable
=
False
)
ernie_inputs
=
{
"src_ids"
:
src_ids
,
"sent_ids"
:
sent_ids
,
"pos_ids"
:
pos_ids
,
"input_mask"
:
input_mask
,
"seq_lens"
:
seq_lens
}
return
pyreader
,
ernie_inputs
,
labels
def
create_model
(
args
,
embeddings
,
...
...
@@ -174,7 +142,6 @@ def main(args):
"""
Main Function
"""
args
=
parser
.
parse_args
()
ernie_config
=
ErnieConfig
(
args
.
ernie_config_path
)
ernie_config
.
print_config
()
...
...
@@ -224,7 +191,7 @@ def main(args):
# create ernie_pyreader
train_pyreader
,
ernie_inputs
,
labels
=
ernie_pyreader
(
args
,
pyreader_name
=
'train_reader'
)
pyreader_name
=
'train_
py
reader'
)
# get ernie_embeddings
if
args
.
use_paddle_hub
:
...
...
@@ -239,10 +206,6 @@ def main(args):
labels
=
labels
,
is_prediction
=
False
)
"""
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr)
sgd_optimizer.minimize(loss)
"""
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
args
.
lr
)
optimizer
.
minimize
(
loss
)
...
...
@@ -253,6 +216,12 @@ def main(args):
(
lower_mem
,
upper_mem
,
unit
))
if
args
.
do_val
:
test_data_generator
=
reader
.
data_generator
(
input_file
=
args
.
dev_set
,
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
shuffle
=
False
)
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
...
...
@@ -277,12 +246,18 @@ def main(args):
test_prog
=
test_prog
.
clone
(
for_test
=
True
)
if
args
.
do_infer
:
infer_data_generator
=
reader
.
data_generator
(
input_file
=
args
.
test_set
,
batch_size
=
args
.
batch_size
,
phase
=
'infer'
,
epoch
=
1
,
shuffle
=
False
)
infer_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
infer_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
infer_pyreader
,
ernie_inputs
,
labels
=
ernie_pyreader
(
args
,
pyreader_name
=
"infer_reader"
)
pyreader_name
=
"infer_
py
reader"
)
# get ernie_embeddings
if
args
.
use_paddle_hub
:
...
...
@@ -323,20 +298,16 @@ def main(args):
main_program
=
infer_prog
)
if
args
.
do_train
:
exec_strategy
=
fluid
.
ExecutionStrategy
()
exec_strategy
.
num_iteration_per_drop_scope
=
1
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
args
.
use_cuda
,
loss_name
=
loss
.
name
,
exec_strategy
=
exec_strategy
,
main_program
=
train_program
)
train_pyreader
.
decorate_tensor_provider
(
train_data_generator
)
train_exe
=
exe
train_pyreader
.
decorate_batch_generator
(
train_data_generator
)
else
:
train_exe
=
None
if
args
.
do_val
or
args
.
do_infer
:
if
args
.
do_val
:
test_exe
=
exe
test_pyreader
.
decorate_batch_generator
(
test_data_generator
)
if
args
.
do_infer
:
test_exe
=
exe
infer_pyreader
.
decorate_batch_generator
(
infer_data_generator
)
if
args
.
do_train
:
train_pyreader
.
start
()
...
...
@@ -351,7 +322,7 @@ def main(args):
else
:
fetch_list
=
[]
outputs
=
train_exe
.
run
(
fetch_list
=
fetch_list
,
return_numpy
=
False
)
outputs
=
train_exe
.
run
(
program
=
train_program
,
fetch_list
=
fetch_list
,
return_numpy
=
False
)
if
steps
%
args
.
skip_steps
==
0
:
np_loss
,
np_acc
,
np_num_seqs
=
outputs
np_loss
=
np
.
array
(
np_loss
)
...
...
@@ -383,30 +354,10 @@ def main(args):
if
steps
%
args
.
validation_steps
==
0
:
# evaluate dev set
if
args
.
do_val
:
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
input_file
=
args
.
dev_set
,
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
input_file
=
args
.
test_set
,
batch_size
=
args
.
batch_size
,
phase
=
'infer'
,
epoch
=
1
,
shuffle
=
False
))
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"infer"
)
except
fluid
.
core
.
EOFException
:
save_path
=
os
.
path
.
join
(
args
.
checkpoints
,
"step_"
+
str
(
steps
))
fluid
.
io
.
save_persistables
(
exe
,
save_path
,
train_program
)
...
...
@@ -415,29 +366,19 @@ def main(args):
# final eval on dev set
if
args
.
do_val
:
test_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
input_file
=
args
.
dev_set
,
batch_size
=
args
.
batch_size
,
phase
=
'dev'
,
epoch
=
1
,
shuffle
=
False
))
print
(
"Final validation result:"
)
evaluate
(
exe
,
test_prog
,
test_pyreader
,
[
loss
.
name
,
accuracy
.
name
,
num_seqs
.
name
],
"dev"
)
# final eval on test set
if
args
.
do_infer
:
infer_pyreader
.
decorate_tensor_provider
(
reader
.
data_generator
(
input_file
=
args
.
test_set
,
batch_size
=
args
.
batch_size
,
phase
=
'infer'
,
epoch
=
1
,
shuffle
=
False
))
print
(
"Final test result:"
)
infer
(
exe
,
infer_prog
,
infer_pyreader
,
[
probs
.
name
],
"infer"
)
if
__name__
==
"__main__"
:
print_arguments
(
args
)
args
=
PDConfig
()
args
.
build
()
args
.
print_arguments
()
check_cuda
(
args
.
use_cuda
)
main
(
args
)
PaddleNLP/sentiment_classification/senta_config.json
浏览文件 @
f97790a0
{
"model_type"
:
"bilstm_net"
,
"vocab_size"
:
33256
"vocab_size"
:
33256
,
"max_seq_len"
:
256
}
PaddleNLP/sentiment_classification/utils.py
浏览文件 @
f97790a0
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Arguments for configuration
"""
...
...
@@ -44,7 +31,6 @@ class ArgumentGroup(object):
"""
Argument Class
"""
def
__init__
(
self
,
parser
,
title
,
des
):
self
.
_group
=
parser
.
add_argument_group
(
title
=
title
,
description
=
des
)
...
...
@@ -94,11 +80,12 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
print
(
"Load model from {}"
.
format
(
init_checkpoint_path
))
def
data_reader
(
file_path
,
word_dict
,
num_examples
,
phrase
,
epoch
):
def
data_reader
(
file_path
,
word_dict
,
num_examples
,
phrase
,
epoch
,
max_seq_len
):
"""
Convert word sequence into slot
"""
unk_id
=
len
(
word_dict
)
pad_id
=
0
all_data
=
[]
with
io
.
open
(
file_path
,
"r"
,
encoding
=
'utf8'
)
as
fin
:
for
line
in
fin
:
...
...
@@ -109,11 +96,16 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch):
sys
.
stderr
.
write
(
"[NOTICE] Error Format Line!"
)
continue
label
=
int
(
cols
[
1
])
wids
=
[
word_dict
[
x
]
if
x
in
word_dict
else
unk_id
for
x
in
cols
[
0
].
split
(
" "
)
]
all_data
.
append
((
wids
,
label
))
wids
=
[
word_dict
[
x
]
if
x
in
word_dict
else
unk_id
for
x
in
cols
[
0
].
split
(
" "
)]
seq_len
=
len
(
wids
)
if
seq_len
<
max_seq_len
:
for
i
in
range
(
max_seq_len
-
seq_len
):
wids
.
append
(
pad_id
)
else
:
wids
=
wids
[:
max_seq_len
]
seq_len
=
max_seq_len
all_data
.
append
((
wids
,
label
,
seq_len
))
if
phrase
==
"train"
:
random
.
shuffle
(
all_data
)
...
...
@@ -125,12 +117,10 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch):
Reader Function
"""
for
epoch_index
in
range
(
epoch
):
for
doc
,
label
in
all_data
:
yield
doc
,
label
for
doc
,
label
,
seq_len
in
all_data
:
yield
doc
,
label
,
seq_len
return
reader
def
load_vocab
(
file_path
):
"""
load the given vocabulary
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录