Unverified	Commit f97790a0	authored by LiuHao	committed by GitHub

Sentiment analysis code usability improvements (#3420)

* update

* update

* update

* fix run_ernie.sh

* Update README.md
Parent e5c84957
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This module provide nets for text classification This module provide nets for text classification
""" """
import paddle.fluid as fluid import paddle.fluid as fluid
def bow_net(data, def bow_net(data,
seq_len,
label, label,
dict_dim, dict_dim,
emb_dim=128, emb_dim=128,
hid_dim=128, hid_dim=128,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
is_infer=False): is_prediction=False):
""" """
Bow net Bow net
""" """
# embedding layer # embedding layer
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
# bow layer # bow layer
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow) bow_tanh = fluid.layers.tanh(bow)
...@@ -39,7 +27,7 @@ def bow_net(data, ...@@ -39,7 +27,7 @@ def bow_net(data,
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
# softmax layer # softmax layer
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
if is_infer: if is_prediction:
return prediction return prediction
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -49,6 +37,7 @@ def bow_net(data, ...@@ -49,6 +37,7 @@ def bow_net(data,
def cnn_net(data, def cnn_net(data,
seq_len,
label, label,
dict_dim, dict_dim,
emb_dim=128, emb_dim=128,
...@@ -56,13 +45,13 @@ def cnn_net(data, ...@@ -56,13 +45,13 @@ def cnn_net(data,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
win_size=3, win_size=3,
is_infer=False): is_prediction=False):
""" """
Conv net Conv net
""" """
# embedding layer # embedding layer
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
# convolution layer # convolution layer
conv_3 = fluid.nets.sequence_conv_pool( conv_3 = fluid.nets.sequence_conv_pool(
input=emb, input=emb,
...@@ -75,7 +64,7 @@ def cnn_net(data, ...@@ -75,7 +64,7 @@ def cnn_net(data,
fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2) fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2)
# softmax layer # softmax layer
prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax") prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax")
if is_infer: if is_prediction:
return prediction return prediction
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -85,6 +74,7 @@ def cnn_net(data, ...@@ -85,6 +74,7 @@ def cnn_net(data,
def lstm_net(data, def lstm_net(data,
seq_len,
label, label,
dict_dim, dict_dim,
emb_dim=128, emb_dim=128,
...@@ -92,7 +82,7 @@ def lstm_net(data, ...@@ -92,7 +82,7 @@ def lstm_net(data,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
emb_lr=30.0, emb_lr=30.0,
is_infer=False): is_prediction=False):
""" """
Lstm net Lstm net
""" """
...@@ -101,7 +91,7 @@ def lstm_net(data, ...@@ -101,7 +91,7 @@ def lstm_net(data,
input=data, input=data,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr)) param_attr=fluid.ParamAttr(learning_rate=emb_lr))
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
# Lstm layer # Lstm layer
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4) fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
...@@ -116,7 +106,7 @@ def lstm_net(data, ...@@ -116,7 +106,7 @@ def lstm_net(data,
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh') fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
# softmax layer # softmax layer
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
if is_infer: if is_prediction:
return prediction return prediction
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -126,6 +116,7 @@ def lstm_net(data, ...@@ -126,6 +116,7 @@ def lstm_net(data,
def bilstm_net(data, def bilstm_net(data,
seq_len,
label, label,
dict_dim, dict_dim,
emb_dim=128, emb_dim=128,
...@@ -133,7 +124,7 @@ def bilstm_net(data, ...@@ -133,7 +124,7 @@ def bilstm_net(data,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
emb_lr=30.0, emb_lr=30.0,
is_infer=False): is_prediction=False):
""" """
Bi-Lstm net Bi-Lstm net
""" """
...@@ -142,6 +133,8 @@ def bilstm_net(data, ...@@ -142,6 +133,8 @@ def bilstm_net(data,
input=data, input=data,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr)) param_attr=fluid.ParamAttr(learning_rate=emb_lr))
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4) fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4) rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
...@@ -161,7 +154,7 @@ def bilstm_net(data, ...@@ -161,7 +154,7 @@ def bilstm_net(data,
fc1 = fluid.layers.fc(input=lstm_concat, size=hid_dim2, act='tanh') fc1 = fluid.layers.fc(input=lstm_concat, size=hid_dim2, act='tanh')
# softmax layer # softmax layer
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
if is_infer: if is_prediction:
return prediction return prediction
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -170,6 +163,7 @@ def bilstm_net(data, ...@@ -170,6 +163,7 @@ def bilstm_net(data,
def gru_net(data, def gru_net(data,
seq_len,
label, label,
dict_dim, dict_dim,
emb_dim=128, emb_dim=128,
...@@ -177,7 +171,7 @@ def gru_net(data, ...@@ -177,7 +171,7 @@ def gru_net(data,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
emb_lr=30.0, emb_lr=30.0,
is_infer=False): is_prediction=False):
""" """
gru net gru net
""" """
...@@ -185,7 +179,7 @@ def gru_net(data, ...@@ -185,7 +179,7 @@ def gru_net(data,
input=data, input=data,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr)) param_attr=fluid.ParamAttr(learning_rate=emb_lr))
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3) fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False) gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
...@@ -196,7 +190,7 @@ def gru_net(data, ...@@ -196,7 +190,7 @@ def gru_net(data,
fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh') fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
if is_infer: if is_prediction:
return prediction return prediction
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -206,14 +200,15 @@ def gru_net(data, ...@@ -206,14 +200,15 @@ def gru_net(data,
def textcnn_net(data, def textcnn_net(data,
label, seq_len,
dict_dim, label,
emb_dim=128, dict_dim,
hid_dim=128, emb_dim=128,
hid_dim2=96, hid_dim=128,
class_dim=2, hid_dim2=96,
win_sizes=None, class_dim=2,
is_infer=False): win_sizes=None,
is_prediction=False):
""" """
Textcnn_net Textcnn_net
""" """
...@@ -222,7 +217,7 @@ def textcnn_net(data, ...@@ -222,7 +217,7 @@ def textcnn_net(data,
# embedding layer # embedding layer
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
# convolution layer # convolution layer
convs = [] convs = []
for win_size in win_sizes: for win_size in win_sizes:
...@@ -239,7 +234,7 @@ def textcnn_net(data, ...@@ -239,7 +234,7 @@ def textcnn_net(data,
fc_1 = fluid.layers.fc(input=[convs_out], size=hid_dim2, act="tanh") fc_1 = fluid.layers.fc(input=[convs_out], size=hid_dim2, act="tanh")
# softmax layer # softmax layer
prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax") prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax")
if is_infer: if is_prediction:
return prediction return prediction
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
......
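Aside for readers of this diff (not part of the commit): a hedged sketch of how the refactored nets are now driven, mirroring `create_model` in run_classifier.py further down; `dict_dim` and the max sequence length are placeholders, and the return shape follows how run_classifier.py unpacks it.

```python
import paddle.fluid as fluid
from nets import bow_net  # assumes the nets.py shown above is importable

# Padded ids, labels, and real lengths, as wired up in create_model below.
data = fluid.layers.data(name="src_ids", shape=[-1, 256, 1], dtype="int64")
label = fluid.layers.data(name="label", shape=[-1, 1], dtype="int64")
seq_len = fluid.layers.data(name="seq_len", shape=[-1, 1], dtype="int64")

# Training graph: per run_classifier.py, the nets return (cost, prediction).
ce_loss, probs = bow_net(data, seq_len, label, dict_dim=33256)

# Inference graph: is_prediction=True short-circuits to the softmax output.
probs_only = bow_net(data, seq_len, None, dict_dim=33256, is_prediction=True)
```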
@@ -27,7 +27,17 @@ import numpy as np
from preprocess.ernie import tokenization
from preprocess.padding import pad_batch_data
import io


def csv_reader(fd, delimiter='\t'):
    """Yield each line of `fd` split on `delimiter`, with no csv quoting rules."""
    def gen():
        for i in fd:
            slots = i.rstrip('\n').split(delimiter)
            if len(slots) == 1:
                yield slots,
            else:
                yield slots
    return gen()
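As an illustration (not part of the commit) of the reader that replaces `csv.reader` here — it splits purely on the delimiter, which presumably keeps behavior identical under Python 2 and 3:

```python
import io

f = io.StringIO(u"text_a\tlabel\nsome text\t1\n")
for slots in csv_reader(f):   # csv_reader as defined above
    print(slots)              # ['text_a', 'label'], then ['some text', '1']
```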
class BaseReader(object):
    """BaseReader for classify and sequence labeling task"""
@@ -66,8 +76,8 @@ class BaseReader(object):
    def _read_tsv(self, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="utf8") as f:
            reader = csv_reader(f, delimiter="\t")
            headers = next(reader)
            Example = namedtuple('Example', headers)
@@ -228,8 +238,8 @@ class ClassifyReader(BaseReader):
    def _read_tsv(self, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="utf8") as f:
            reader = csv_reader(f, delimiter="\t")
            headers = next(reader)
            text_indices = [
                index for index, h in enumerate(headers) if h != "label"
...
@@ -21,7 +21,7 @@ from __future__ import print_function

import collections
import unicodedata
import six
import io


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
@@ -69,7 +69,7 @@ def printable_text(text):

def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    fin = io.open(vocab_file, encoding="utf8")
    for num, line in enumerate(fin):
        items = convert_to_unicode(line.strip()).split("\t")
        if len(items) > 2:
...
# Sentiment Classification

* [Introduction](#introduction)
* [Quick Start](#quick-start)
* [Advanced Usage](#advanced-usage)
* [Changelog](#changelog)
* [Authors](#authors)
* [How to Contribute](#how-to-contribute)

## Introduction

Sentiment Classification (Senta for short) automatically determines the sentiment polarity of subjective Chinese text and outputs a confidence score. Sentiment is classified as positive or negative. Sentiment classification helps companies understand user consumption habits, analyze trending topics, and monitor public-opinion crises, providing useful decision support. You can try it online at [AI Open Platform - Sentiment Classification](http://ai.baidu.com/tech/nlp_apply/sentiment_classify).

Sentiment is a high-level form of human intelligence, and identifying the sentiment of text requires deep semantic modeling. Moreover, different domains (e.g., dining, sports) express sentiment differently, so training needs large-scale data covering many domains. We address both problems with deep-learning-based semantic models and large-scale data mining. For evaluation, we report results on the open-source sentiment classification dataset ChnSentiCorp; we have also open-sourced a model pretrained by Baidu on massive data, which achieves better results after fine-tuning on ChnSentiCorp (the fine-tuning procedure is described in the sections below). Results:

@@ -14,49 +23,113 @@
| ERNIE | 95.4% | 95.5% |
| ERNIE+BI-LSTM | 95.7% | 95.6% |

## Quick Start

### Installation

1. Install PaddlePaddle

This project requires PaddlePaddle Fluid 1.3.2 or later; see the [installation guide](http://www.paddlepaddle.org/#quick-start).

2. Install the code

Clone the repository:

```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/sentiment_classification
```

3. Environment dependencies

See the PaddlePaddle [installation notes](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html).

### Code Structure

The main files of this project:

```text
.
├── senta_config.json        # model configuration file
├── config.py                # configuration parsing interface
├── inference_model.py       # script that saves the inference model
├── inference_ernie_model.py # script that saves the ERNIE inference model
├── reader.py                # data reading interface
├── run_classifier.py        # main entry point: training, prediction, evaluation
├── run.sh                   # training, prediction, evaluation script
├── run_ernie_classifier.py  # main entry point for the ERNIE-based models
├── run_ernie.sh             # ERNIE-based training, prediction, evaluation script
├── utils.py                 # miscellaneous helper functions
```

### Data Preparation

#### **Custom data**

You can organize your own data for training, prediction, and evaluation to match your application. Each UTF-8 file has two tab-separated columns: the first column is space-tokenized Chinese text (the tokenization preprocessing script is described below); the second column is the sentiment label (0 for negative, 1 for positive). Note that the first line of the file must be the header "text_a\tlabel". For example:

```text
特 喜欢 这种 好看的 狗狗 1
这 真是 惊艳 世界 的 中国 黑科技 1
环境 特别 差 ,脏兮兮 的,再也 不去 了 0
```
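To make the format concrete, here is a tiny hypothetical parser (illustrative, not part of the project) for one such line:

```python
# Hypothetical helper: split a "text_a\tlabel" line into tokens and a 0/1 label.
def parse_line(line):
    text, label = line.rstrip("\n").split("\t")
    return text.split(" "), int(label)

tokens, label = parse_line(u"特 喜欢 这种 好看的 狗狗\t1")
print(tokens, label)  # ['特', '喜欢', '这种', '好看的', '狗狗'] 1
```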
Note: the PaddleNLP project provides a tokenization preprocessing script (under the preprocess directory) that you can use as follows:

```shell
python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.utf8.seg
# test.txt.utf8 is the UTF-8 file to tokenize, one text per line;
# the tokenized result is written to test.txt.utf8.seg.
```

#### Public dataset

Download the preprocessed data; after extraction, the senta_data directory contains the training set (train.tsv), the development set (dev.tsv), the test set (test.tsv), and the vocabulary (word_dict.txt).

```shell
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-dataset-1.0.0.tar.gz
tar -zxvf sentiment_classification-dataset-1.0.0.tar.gz
```

```text
.
├── train.tsv       # training set
├── dev.tsv         # development set
├── test.tsv        # test set
├── word_dict.txt   # vocabulary
```

### Single-machine Training

With the example dataset, run the following commands to train on the training set (train.tsv) and validate on the development set (dev.tsv):

```shell
# BOW, CNN, LSTM, BI-LSTM, GRU models
sh run.sh train
# ERNIE and ERNIE+BI-LSTM models
sh run_ernie.sh train
```

After training, set the init_checkpoint parameter in ```run.sh``` to run model evaluation and prediction.

```
"""
# Example output
Running type options:
  --do_train DO_TRAIN   Whether to perform training. Default: False.
...

Model config options:
  --model_type {bow_net,cnn_net,lstm_net,bilstm_net,gru_net,textcnn_net}
                        Model type to run the task. Default: textcnn_net.
  --init_checkpoint INIT_CHECKPOINT
                        Init checkpoint to resume training from. Default: .
  --save_checkpoint_dir SAVE_CHECKPOINT_DIR
                        Directory path to save checkpoints. Default: .
...
"""
```

Parameter precedence in this project is: command-line arguments > ```config.json``` > defaults. After training, model directories named ```step_xxx``` are created under ```./save_models```.
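For instance, with the PDConfig class from config.py (shown later in this diff), the three layers resolve roughly like this sketch:

```python
from config import PDConfig

# Defaults come from ArgumentGroup.add_arg; senta_config.json overrides them
# via set_default; anything passed on the command line wins over both.
config = PDConfig("senta_config.json")
config.build()             # parses sys.argv, e.g. `--lr 0.001` beats the json value
config.print_arguments()
print(config.lr)           # resolved value after all three layers
```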
### Model Evaluation

Using the pretrained model and data above, run the following command to evaluate the pretrained model on the development set (dev.tsv):

```shell
@@ -83,46 +156,76 @@
--init_checkpoint senta_model/ernie_bilstm_model/
--model_type "ernie_bilstm"
```

```
"""
# Example output
Load model from ./save_models/step_100
Final test result:
[test evaluation] avg loss: 0.339021, avg acc: 0.869691, elapsed time: 0.123983 s
"""
```

### Model Inference

With a trained model, run the following command to predict on unlabeled data (test.tsv):

```shell
# BOW, CNN, LSTM, BI-LSTM, GRU models
sh run.sh infer
# ERNIE+BI-LSTM model
sh run_ernie.sh infer
```

```
"""
# Example output
Load model from ./save_models/step_100
1 0.001659 0.998341
0 0.987223 0.012777
1 0.001365 0.998635
1 0.001875 0.998125
"""
```

### Pretrained Models

We have open-sourced sentiment classification models trained on massive data (based on CNN, BI-LSTM, ERNIE, and others) that can be used directly. Two download options are provided.

**Option 1**: via the PaddleHub command-line tool (see the PaddleHub [installation instructions](https://github.com/PaddlePaddle/PaddleHub)):

```shell
hub download sentiment_classification --output_path ./
tar -zxvf sentiment_classification-1.0.0.tar.gz
```

**Option 2**: direct download:

```shell
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-1.0.0.tar.gz
tar -zxvf sentiment_classification-1.0.0.tar.gz
```

Both options place the pretrained CNN, BI-LSTM, and other models together with the ERNIE model in the current directory; set the ```init_checkpoint``` parameter in ```run.sh``` to evaluate them or run prediction.

### Serving Deployment

To deploy a model online, use the ```inference_model.py``` and ```inference_ernie_model.py``` scripts to prune the model, saving only the network parameters and the pruned inference model:

```shell
sh run.sh save_inference_model
sh run_ernie.sh save_inference_model
```
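The pruned model can then be reloaded for serving. A minimal sketch (assumptions: CPU inference and placeholder shapes), mirroring `test_inference_model` in inference_model.py later in this diff:

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load the pruned program saved by `sh run.sh save_inference_model`.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./inference_model",
    executor=exe,
    model_filename="model.pdmodel",
    params_filename="params.pdparams")

# feed_names is [src_ids, seq_len] (see create_model); shapes are placeholders.
src_ids = np.zeros((1, 256, 1), dtype="int64")
seq_len = np.array([[10]], dtype="int64")
probs = exe.run(program,
                feed={feed_names[0]: src_ids, feed_names[1]: seq_len},
                fetch_list=fetch_targets)
print(probs[0])  # per-class softmax probabilities
```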
#### Server-side Deployment

See the official PaddlePaddle [server-side deployment](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/advanced_usage/deploy/inference/index_cn.html) documentation to bring the model online.

## Advanced Usage

### Background

Traditional sentiment classification relies on dictionaries or feature engineering, which requires tedious manual feature design and prior knowledge, understands text only at a shallow level, and generalizes poorly. To avoid these limitations, we adopt deep learning, which has advanced rapidly in recent years. Deep-learning-based sentiment classification does not depend on hand-crafted features: it understands the input text end to end and judges sentiment polarity from the learned semantic representation.

### Model Overview

For the sentiment classification task, this project open-sources a series of models that can be used configurably:

@@ -134,66 +237,126 @@
+ ERNIE (Enhanced Representation through kNowledge IntEgration), Baidu's general-purpose text semantic representation model trained on massive data and prior knowledge, fine-tuned on the sentiment classification dataset.
+ ERNIE+BI-LSTM, a BI-LSTM stacked on top of the ERNIE representation, fine-tuned on the sentiment classification dataset;

### Custom Models

You can build a custom model for your own needs as follows:

1. Define your own network

Define your model in ```models/classification/nets.py``` by simply adding a new function; suppose the function is named ```user_net``` (see the sketch after this list).

2. Update the model configuration

In ```senta_config.json```, set ```model_type``` to your custom ```user_net```.

3. Train the model

Run training, evaluation, and prediction via the ```run.sh``` script.
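A minimal sketch of such a ```user_net```, following the signature of the built-in nets in nets.py above (the body is illustrative, not a recommended architecture):

```python
import paddle.fluid as fluid

def user_net(data, seq_len, label, dict_dim, emb_dim=128, hid_dim=128,
             class_dim=2, is_prediction=False):
    """A custom net must keep the same interface as the built-in nets."""
    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    emb = fluid.layers.sequence_unpad(emb, length=seq_len)
    pooled = fluid.layers.sequence_pool(input=emb, pool_type='max')
    fc = fluid.layers.fc(input=pooled, size=hid_dim, act="tanh")
    prediction = fluid.layers.fc(input=fc, size=class_dim, act="softmax")
    if is_prediction:
        return prediction
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    return avg_cost, prediction
```

Note that ```run_classifier.py``` dispatches on ```model_type``` inside create_model, so the new name also needs a branch in that if/elif chain.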
### Fine-tuning with ERNIE

ERNIE is Baidu's general-purpose text semantic representation model trained on massive data and prior knowledge; fine-tuning on top of ERNIE improves sentiment classification performance.

#### Model Training

First download the ERNIE model:

```shell
mkdir -p pretrain_models/ernie
cd pretrain_models/ernie
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz -O ERNIE_stable-1.0.1.tar.gz
tar -zxvf ERNIE_stable-1.0.1.tar.gz
```

Then set the ```init_checkpoint``` parameter of the train function in ```run_ernie.sh``` and run:

```shell
#--init_checkpoint ./pretrain_models/ernie
sh run_ernie.sh train
```

Training uses the GPU by default; models are saved under ```./save_models/ernie/``` in directories named ```step_xxx```.

#### Model Evaluation

Based on the training results, pick the best step for evaluation, set the ```init_checkpoint``` parameter of the eval function in ```run_ernie.sh```, and run:

```shell
#--init_checkpoint ./save/step_907
sh run_ernie.sh eval

'''
# Example output
W0820 14:59:47.811139   334 device_context.cc:259] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0820 14:59:47.815557   334 device_context.cc:267] device: 0, cuDNN Version: 7.3.
Load model from ./save_models/ernie/step_907
Final validation result:
[test evaluation] avg loss: 0.260597, ave acc: 0.907336, elapsed time: 2.383077 s
'''
```

#### Model Inference

Set the ```init_checkpoint``` parameter of the infer function in ```run_ernie.sh```, then run:

```shell
#--init_checkpoint ./save/step_907
sh run_ernie.sh infer

'''
# Example output
Load model from ./save_models/ernie/step_907
Final test result:
1 0.001130 0.998870
0 0.978465 0.021535
1 0.000847 0.999153
1 0.001498 0.998502
'''
```

### Fine-tuning with ERNIE via PaddleHub

We also provide the option of loading ERNIE through PaddleHub, PaddlePaddle's pretrained-model management tool, which loads a pretrained model with a single line of code and simplifies its use and transfer learning. See [PaddleHub](https://github.com/PaddlePaddle/PaddleHub) for more details.

Note: this option requires PaddleHub, which can be installed with:

```shell
pip install paddlehub
```

Update the configuration in ```run_ernie.sh``` as follows:

```shell
# in the train() function, set the --use_paddle_hub option
--use_paddle_hub true
```

Run the following command to fine-tune:

```shell
sh run_ernie.sh train
```

After fine-tuning, to run eval or infer, update the configuration in ```run_ernie.sh``` as follows:

```shell
# in the eval() and infer() functions, set the --use_paddle_hub option
--use_paddle_hub true
```

Run eval and infer with:

```shell
sh run_ernie.sh eval
sh run_ernie.sh infer
```

## Changelog

2019/08/26: standardized configuration handling, refactored the in-module data processing code, and restructured the README for better usability.

2019/06/13: added the PaddleHub option for loading ERNIE.

## Authors

- [liuhao](https://github.com/ChinaLiuHao)

## How to Contribute

If you can fix an issue or add a new feature, feel free to submit a PR. If your PR is accepted, we will score the contribution by quality and difficulty (0-5, higher is better). Once you accumulate 10 points, you can contact us for an interview opportunity or a recommendation letter.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Senta model. Senta config.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import six import six
import json import json
import numpy as np import argparse
import paddle.fluid as fluid
def str2bool(value):
"""
String to Boolean
"""
# because argparse does not support to parse "True, False" as python
# boolean directly
return value.lower() in ("true", "t", "1")
class ArgumentGroup(object):
"""
Argument Class
"""
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, dtype, default, help, **kwargs):
"""
Add argument
"""
dtype = str2bool if dtype == bool else dtype
self._group.add_argument(
"--" + name,
default=default,
type=dtype,
help=help + ' Default: %(default)s.',
**kwargs)
class SentaConfig(object): class PDConfig(object):
""" """
Senta Config A high-level api for handling argument configs.
""" """
def __init__(self, json_file=""):
"""
Init function for PDConfig.
json_file: the path to the json configure file.
"""
assert isinstance(json_file, str)
self.args = None
self.arg_config = {}
parser = argparse.ArgumentParser()
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("ernie_config_path", str, None, "Path to the json file for ernie model config.")
model_g.add_arg("senta_config_path", str, None, "Path to the json file for senta model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints")
model_g.add_arg("model_type", str, "ernie_base", "Type of current ernie model")
model_g.add_arg("use_paddle_hub", bool, False, "Whether to load ERNIE using PaddleHub")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("lr", float, 0.002, "The Learning rate value for training.")
log_g = ArgumentGroup(parser, "logging", "logging related")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log")
log_g.add_arg('enable_ce', bool, False, 'If set, run the task with continuous evaluation logs.')
def __init__(self, config_path): data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
self._config_dict = self._parse(config_path) data_g.add_arg("data_dir", str, None, "Path to training data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("batch_size", int, 256, "Total examples' number in batch for training.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
data_g.add_arg("num_labels", int, 2, "label number")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("train_set", str, None, "Path to training data.")
data_g.add_arg("test_set", str, None, "Path to test data.")
data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("label_map_config", str, None, "label_map_path.")
data_g.add_arg("do_lower_case", bool, True, "Whether to lower case the input text. Should be True for uncased models and False for cased models")
def _parse(self, config_path): run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("task_name", str, None,
"The name of task to perform sentiment classification.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation.")
run_type_g.add_arg("do_infer", bool, True, "Whether to perform inference.")
run_type_g.add_arg("do_save_inference_model", bool, True, "Whether to save inference model")
run_type_g.add_arg("inference_model_dir", str, None, "Path to save inference model")
custom_g = ArgumentGroup(parser, "Customize options", "")
self.custom_g = custom_g
self.parser = parser
self.arglist = [a.dest for a in self.parser._actions]
self.json_config = None
if json_file != "":
self.load_json(json_file)
def load_json(self, file_path):
"""load json config """
if not os.path.exists(file_path):
raise Warning("the json file %s does not exist." % file_path)
return
try: try:
with open(config_path) as json_file: with open(file_path, "r") as fin:
config_dict = json.load(json_file) self.json_config = json.load(fin)
except Exception: except Exception as e:
raise IOError("Error in parsing bert model config file '%s'" % raise IOError("Error in parsing json config file '%s'" % file_path)
config_path)
else: for name in self.json_config:
return config_dict # use `six.string_types` but not `str` for compatible with python2 and python3
if not isinstance(self.json_config[name], (int, float, bool, six.string_types)):
continue
def __getitem__(self, key): if name in self.arglist:
return self._config_dict[key] self.set_default(name, self.json_config[name])
else:
self.custom_g.add_arg(name,
type(self.json_config[name]),
self.json_config[name],
"customized options")
def print_config(self): def set_default(self, name, value):
""" for arg in self.parser._actions:
Print Config if arg.dest == name:
""" arg.default = value
for arg, value in sorted(six.iteritems(self._config_dict)):
def build(self):
self.args = self.parser.parse_args()
self.arg_config = vars(self.args)
def print_arguments(self):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(self.arg_config)):
print('%s: %s' % (arg, value)) print('%s: %s' % (arg, value))
print('------------------------------------------------') print('------------------------------------------------')
def add_arg(self, name, dtype, default, descrip):
self.custom_g.add_arg(name, dtype, default, descrip)
def __add__(self, new_arg):
assert isinstance(new_arg, list) or isinstance(new_arg, tuple)
assert len(new_arg) >= 3
assert self.args is None
name = new_arg[0]
dtype = new_arg[1]
dvalue = new_arg[2]
desc = new_arg[3] if len(new_arg) == 4 else "Description is not provided."
self.add_arg(name, dtype, dvalue, desc)
return self
def __getattr__(self, name):
if name in self.arg_config:
return self.arg_config[name]
if name in self.json_config:
return self.json_config[name]
raise Warning("The argument %s is not defined." % name)
if __name__ == '__main__':
pd_config = PDConfig('senta_config.json')
pd_config.add_arg("my_age", int, 18, "I am forever 18.")
pd_config.build()
pd_config.print_arguments()
print(pd_config.use_cuda)
print(pd_config.model_type)
# -*- coding: utf_8 -*-
import os
import sys
sys.path.append("../")

import paddle
import paddle.fluid as fluid
import numpy as np

from models.model_check import check_cuda
from config import PDConfig
from run_classifier import create_model
import utils
import reader


def do_save_inference_model(args):

    if args.use_cuda:
        dev_count = fluid.core.get_cuda_device_count()
        place = fluid.CUDAPlace(0)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)

    test_prog = fluid.Program()
    startup_prog = fluid.Program()

    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            infer_pyreader, probs, feed_target_names = create_model(
                args,
                pyreader_name='infer_reader',
                num_labels=args.num_labels,
                is_prediction=True)

    test_prog = test_prog.clone(for_test=True)
    exe.run(startup_prog)

    assert (args.init_checkpoint)

    if args.init_checkpoint:
        utils.init_checkpoint(exe, args.init_checkpoint, test_prog)

    fluid.io.save_inference_model(
        args.inference_model_dir,
        feeded_var_names=feed_target_names,
        target_vars=[probs],
        executor=exe,
        main_program=test_prog,
        model_filename="model.pdmodel",
        params_filename="params.pdparams")

    print("save inference model at %s" % (args.inference_model_dir))


def inference(exe, test_program, test_pyreader, fetch_list, infer_phrase):
    """
    Inference Function
    """
    print("=================")
    test_pyreader.start()
    while True:
        try:
            np_props = exe.run(program=test_program, fetch_list=fetch_list,
                               return_numpy=True)
            for probs in np_props[0]:
                print("%d\t%f\t%f" % (np.argmax(probs), probs[0], probs[1]))
        except fluid.core.EOFException:
            test_pyreader.reset()
            break


def test_inference_model(args):
    if args.use_cuda:
        dev_count = fluid.core.get_cuda_device_count()
        place = fluid.CUDAPlace(0)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)
    test_prog = fluid.Program()
    startup_prog = fluid.Program()

    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            infer_pyreader, probs, feed_target_names = create_model(
                args,
                pyreader_name='infer_reader',
                num_labels=args.num_labels,
                is_prediction=True)

    test_prog = test_prog.clone(for_test=True)
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    processor = reader.SentaProcessor(data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      random_seed=args.random_seed,
                                      max_seq_len=args.max_seq_len)

    num_labels = len(processor.get_labels())

    assert (args.inference_model_dir)
    infer_program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname=args.inference_model_dir,
        executor=exe,
        model_filename="model.pdmodel",
        params_filename="params.pdparams")

    infer_data_generator = processor.data_generator(
        batch_size=args.batch_size,
        phase="infer",
        epoch=1,
        shuffle=False)

    infer_pyreader.decorate_sample_list_generator(infer_data_generator)
    inference(exe, test_prog, infer_pyreader,
              [probs.name], "infer")


if __name__ == "__main__":
    args = PDConfig('senta_config.json')
    args.build()
    args.print_arguments()

    check_cuda(args.use_cuda)

    if args.do_save_inference_model:
        do_save_inference_model(args)
    else:
        test_inference_model(args)
# -*- coding: utf_8 -*-
import os
import sys
sys.path.append("../")
sys.path.append("../models/classification")

import paddle
import paddle.fluid as fluid
import numpy as np

from models.model_check import check_cuda
from config import PDConfig
from run_ernie_classifier import create_model
import utils
import reader

from run_ernie_classifier import ernie_pyreader
from models.representation.ernie import ErnieConfig
from models.representation.ernie import ernie_encoder
from preprocess.ernie import task_reader


def do_save_inference_model(args):

    ernie_config = ErnieConfig(args.ernie_config_path)
    ernie_config.print_config()

    if args.use_cuda:
        dev_count = fluid.core.get_cuda_device_count()
        place = fluid.CUDAPlace(0)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)

    test_prog = fluid.Program()
    startup_prog = fluid.Program()

    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            infer_pyreader, ernie_inputs, labels = ernie_pyreader(
                args,
                pyreader_name="infer_reader")

            embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
            probs = create_model(args,
                                 embeddings,
                                 labels=labels,
                                 is_prediction=True)

    test_prog = test_prog.clone(for_test=True)
    exe.run(startup_prog)

    assert (args.init_checkpoint)

    if args.init_checkpoint:
        utils.init_checkpoint(exe, args.init_checkpoint, test_prog)

    fluid.io.save_inference_model(
        args.inference_model_dir,
        feeded_var_names=[ernie_inputs["src_ids"].name,
                          ernie_inputs["sent_ids"].name,
                          ernie_inputs["pos_ids"].name,
                          ernie_inputs["input_mask"].name,
                          ernie_inputs["seq_lens"].name],
        target_vars=[probs],
        executor=exe,
        main_program=test_prog,
        model_filename="model.pdmodel",
        params_filename="params.pdparams")

    print("save inference model at %s" % (args.inference_model_dir))


def inference(exe, test_program, test_pyreader, fetch_list, infer_phrase):
    """
    Inference Function
    """
    print("=================")
    test_pyreader.start()
    while True:
        try:
            np_props = exe.run(program=test_program, fetch_list=fetch_list,
                               return_numpy=True)
            for probs in np_props[0]:
                print("%d\t%f\t%f" % (np.argmax(probs), probs[0], probs[1]))
        except fluid.core.EOFException:
            test_pyreader.reset()
            break


def test_inference_model(args):
    ernie_config = ErnieConfig(args.ernie_config_path)
    ernie_config.print_config()

    if args.use_cuda:
        dev_count = fluid.core.get_cuda_device_count()
        place = fluid.CUDAPlace(0)
    else:
        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)

    reader = task_reader.ClassifyReader(
        vocab_path=args.vocab_path,
        label_map_config=args.label_map_config,
        max_seq_len=args.max_seq_len,
        do_lower_case=args.do_lower_case,
        random_seed=args.random_seed)

    test_prog = fluid.Program()
    startup_prog = fluid.Program()

    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            infer_pyreader, ernie_inputs, labels = ernie_pyreader(
                args,
                pyreader_name="infer_pyreader")

            embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
            probs = create_model(
                args,
                embeddings,
                labels=labels,
                is_prediction=True)

    test_prog = test_prog.clone(for_test=True)
    exe.run(startup_prog)

    assert (args.inference_model_dir)

    infer_data_generator = reader.data_generator(
        input_file=args.test_set,
        batch_size=args.batch_size,
        phase="infer",
        epoch=1,
        shuffle=False)

    infer_program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname=args.inference_model_dir,
        executor=exe,
        model_filename="model.pdmodel",
        params_filename="params.pdparams")

    infer_pyreader.decorate_batch_generator(infer_data_generator)
    inference(exe, test_prog, infer_pyreader,
              [probs.name], "infer")


if __name__ == "__main__":
    args = PDConfig()
    args.build()
    args.print_arguments()

    check_cuda(args.use_cuda)

    if args.do_save_inference_model:
        do_save_inference_model(args)
    else:
        test_inference_model(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Senta Reader Senta Reader
""" """
...@@ -25,38 +12,39 @@ from utils import data_reader ...@@ -25,38 +12,39 @@ from utils import data_reader
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
class SentaProcessor(object): class SentaProcessor(object):
""" """
Processor class for data convertors for senta Processor class for data convertors for senta
""" """
def __init__(self, data_dir, vocab_path, random_seed=None): def __init__(self,
data_dir,
vocab_path,
random_seed,
max_seq_len):
self.data_dir = data_dir self.data_dir = data_dir
self.vocab = load_vocab(vocab_path) self.vocab = load_vocab(vocab_path)
self.num_examples = {"train": -1, "dev": -1, "infer": -1} self.num_examples = {"train": -1, "dev": -1, "infer": -1}
np.random.seed(random_seed) np.random.seed(random_seed)
self.max_seq_len = max_seq_len
def get_train_examples(self, data_dir, epoch): def get_train_examples(self, data_dir, epoch, max_seq_len):
""" """
Load training examples Load training examples
""" """
return data_reader((self.data_dir + "/train.tsv"), self.vocab, return data_reader((self.data_dir + "/train.tsv"), self.vocab, self.num_examples, "train", epoch, max_seq_len)
self.num_examples, "train", epoch)
def get_dev_examples(self, data_dir, epoch): def get_dev_examples(self, data_dir, epoch, max_seq_len):
""" """
Load dev examples Load dev examples
""" """
return data_reader((self.data_dir + "/dev.tsv"), self.vocab, return data_reader((self.data_dir + "/dev.tsv"), self.vocab, self.num_examples, "dev", epoch, max_seq_len)
self.num_examples, "dev", epoch)
def get_test_examples(self, data_dir, epoch): def get_test_examples(self, data_dir, epoch, max_seq_len):
""" """
Load test examples Load test examples
""" """
return data_reader((self.data_dir + "/test.tsv"), self.vocab, return data_reader((self.data_dir + "/test.tsv"), self.vocab, self.num_examples, "infer", epoch, max_seq_len)
self.num_examples, "infer", epoch)
def get_labels(self): def get_labels(self):
""" """
...@@ -84,14 +72,12 @@ class SentaProcessor(object): ...@@ -84,14 +72,12 @@ class SentaProcessor(object):
Generate data for train, dev or infer Generate data for train, dev or infer
""" """
if phase == "train": if phase == "train":
return paddle.batch( return paddle.batch(self.get_train_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
self.get_train_examples(self.data_dir, epoch), batch_size) #return self.get_train_examples(self.data_dir, epoch, self.max_seq_len)
elif phase == "dev": elif phase == "dev":
return paddle.batch( return paddle.batch(self.get_dev_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
self.get_dev_examples(self.data_dir, epoch), batch_size)
elif phase == "infer": elif phase == "infer":
return paddle.batch( return paddle.batch(self.get_test_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
self.get_test_examples(self.data_dir, epoch), batch_size)
else: else:
raise ValueError( raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'infer'].") "Unknown phase, which should be in ['train', 'dev', 'infer'].")
#! /bin/bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=12
export FLAGS_fraction_of_gpu_memory_to_use=0.95
export CPU_NUM=1
@@ -16,9 +16,9 @@ train() {
        --task_name ${TASK_NAME} \
        --use_cuda true \
        --do_train true \
        --do_val false \
        --do_infer false \
        --batch_size 8 \
        --data_dir ${DATA_PATH} \
        --vocab_path ${DATA_PATH}/word_dict.txt \
        --checkpoints ${CKPT_PATH} \
@@ -59,6 +59,15 @@ infer() {
        --senta_config_path ./senta_config.json
}

# run_save_inference_model
save_inference_model() {
    python -u inference_model.py \
        --use_cuda false \
        --do_save_inference_model true \
        --init_checkpoint ${MODEL_PATH} \
        --inference_model_dir ./inference_model
}

main() {
    local cmd=${1:-help}
    case "${cmd}" in
@@ -71,13 +80,16 @@ main() {
        infer)
            infer "$@";
            ;;
        save_inference_model)
            save_inference_model "$@";
            ;;
        help)
            echo "Usage: ${BASH_SOURCE} {train|eval|infer|save_inference_model}";
            return 0;
            ;;
        *)
            echo "Unsupported command [${cmd}]";
            echo "Usage: ${BASH_SOURCE} {train|eval|infer|save_inference_model}";
            return 1;
            ;;
    esac
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Sentiment Classification Task Sentiment Classification Task
""" """
...@@ -34,93 +21,56 @@ from nets import cnn_net ...@@ -34,93 +21,56 @@ from nets import cnn_net
from nets import bilstm_net from nets import bilstm_net
from nets import gru_net from nets import gru_net
from models.model_check import check_cuda from models.model_check import check_cuda
from config import PDConfig
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader import reader
from config import SentaConfig
from utils import ArgumentGroup, print_arguments
from utils import init_checkpoint from utils import init_checkpoint
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("senta_config_path", str, None, "Path to the json file for senta model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("lr", float, 0.002, "The Learning rate value for training.")
log_g = ArgumentGroup(parser, "logging", "logging related")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, None, "Path to training data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("batch_size", int, 256, "Total examples' number in batch for training.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("task_name", str, None,
"The name of task to perform sentiment classification.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation.")
run_type_g.add_arg("do_infer", bool, True, "Whether to perform inference.")
parser.add_argument('--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
args = parser.parse_args()
# yapf: enable.
def create_model(args, def create_model(args,
pyreader_name, pyreader_name,
senta_config,
num_labels, num_labels,
is_inference=False): is_prediction=False):
""" """
Create Model for sentiment classification Create Model for sentiment classification
""" """
pyreader = fluid.layers.py_reader( data = fluid.layers.data(
capacity=16, name="src_ids", shape=[-1, args.max_seq_len, 1], dtype='int64')
shapes=([-1, 1], [-1, 1]), label = fluid.layers.data(
dtypes=('int64', 'int64'), name="label", shape=[-1, 1], dtype="int64")
lod_levels=(1, 0), seq_len = fluid.layers.data(
name=pyreader_name, name="seq_len", shape=[-1, 1], dtype="int64")
use_double_buffer=False)
data_reader = fluid.io.PyReader(feed_list=[data, label, seq_len],
if senta_config['model_type'] == "bilstm_net": capacity=4, iterable=False)
if args.model_type == "bilstm_net":
network = bilstm_net network = bilstm_net
elif senta_config['model_type'] == "bow_net": elif args.model_type == "bow_net":
network = bow_net network = bow_net
elif senta_config['model_type'] == "cnn_net": elif args.model_type == "cnn_net":
network = cnn_net network = cnn_net
elif senta_config['model_type'] == "lstm_net": elif args.model_type == "lstm_net":
network = lstm_net network = lstm_net
elif senta_config['model_type'] == "gru_net": elif args.model_type == "gru_net":
network = gru_net network = gru_net
else: else:
raise ValueError("Unknown network type!") raise ValueError("Unknown network type!")
if is_inference: if is_prediction:
data, label = fluid.layers.read_file(pyreader) probs = network(data, seq_len, None, args.vocab_size, is_prediction=is_prediction)
probs = network(data, None, senta_config["vocab_size"], is_infer=is_inference)
print("create inference model...") print("create inference model...")
return pyreader, probs return data_reader, probs, [data.name, seq_len.name]
data, label = fluid.layers.read_file(pyreader) ce_loss, probs = network(data, seq_len, label, args.vocab_size, is_prediction=is_prediction)
ce_loss, probs = network(data, label, senta_config["vocab_size"], is_infer=is_inference)
loss = fluid.layers.mean(x=ce_loss) loss = fluid.layers.mean(x=ce_loss)
num_seqs = fluid.layers.create_tensor(dtype='int64') num_seqs = fluid.layers.create_tensor(dtype='int64')
accuracy = fluid.layers.accuracy(input=probs, label=label, total=num_seqs) accuracy = fluid.layers.accuracy(input=probs, label=label, total=num_seqs)
return pyreader, loss, accuracy, num_seqs return data_reader, loss, accuracy, num_seqs
...@@ -132,6 +82,7 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase): ...@@ -132,6 +82,7 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
total_cost, total_acc, total_num_seqs = [], [], [] total_cost, total_acc, total_num_seqs = [], [], []
time_begin = time.time() time_begin = time.time()
while True: while True:
#print("===============")
try: try:
np_loss, np_acc, np_num_seqs = exe.run(program=test_program, np_loss, np_acc, np_num_seqs = exe.run(program=test_program,
fetch_list=fetch_list, fetch_list=fetch_list,
...@@ -174,8 +125,6 @@ def main(args): ...@@ -174,8 +125,6 @@ def main(args):
""" """
Main Function Main Function
""" """
senta_config = SentaConfig(args.senta_config_path)
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
dev_count = fluid.core.get_cuda_device_count() dev_count = fluid.core.get_cuda_device_count()
...@@ -187,10 +136,10 @@ def main(args): ...@@ -187,10 +136,10 @@ def main(args):
task_name = args.task_name.lower() task_name = args.task_name.lower()
processor = reader.SentaProcessor(data_dir=args.data_dir, processor = reader.SentaProcessor(data_dir=args.data_dir,
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
random_seed=args.random_seed) random_seed=args.random_seed,
max_seq_len=args.max_seq_len)
num_labels = len(processor.get_labels()) num_labels = len(processor.get_labels())
if not (args.do_train or args.do_val or args.do_infer): if not (args.do_train or args.do_val or args.do_infer):
raise ValueError("For args `do_train`, `do_val` and `do_infer`, at " raise ValueError("For args `do_train`, `do_val` and `do_infer`, at "
"least one of them must be True.") "least one of them must be True.")
...@@ -220,12 +169,11 @@ def main(args): ...@@ -220,12 +169,11 @@ def main(args):
with fluid.program_guard(train_program, startup_prog): with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_pyreader, loss, accuracy, num_seqs = create_model( train_reader, loss, accuracy, num_seqs = create_model(
args, args,
pyreader_name='train_reader', pyreader_name='train_reader',
senta_config=senta_config,
num_labels=num_labels, num_labels=num_labels,
is_inference=False) is_prediction=False)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr)
sgd_optimizer.minimize(loss) sgd_optimizer.minimize(loss)
...@@ -237,28 +185,36 @@ def main(args): ...@@ -237,28 +185,36 @@ def main(args):
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
if args.do_val: if args.do_val:
test_data_generator = processor.data_generator(
batch_size=args.batch_size,
phase='dev',
epoch=1,
shuffle=False)
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
test_pyreader, loss, accuracy, num_seqs = create_model( test_reader, loss, accuracy, num_seqs = create_model(
args, args,
pyreader_name='test_reader', pyreader_name='test_reader',
senta_config=senta_config,
num_labels=num_labels, num_labels=num_labels,
is_inference=False) is_prediction=False)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
if args.do_infer: if args.do_infer:
infer_data_generator = processor.data_generator(
batch_size=args.batch_size,
phase='infer',
epoch=1,
shuffle=False)
infer_prog = fluid.Program() infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog): with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
infer_pyreader, prop = create_model( infer_reader, prop, _ = create_model(
args, args,
pyreader_name='infer_reader', pyreader_name='infer_reader',
senta_config=senta_config,
num_labels=num_labels, num_labels=num_labels,
is_inference=True) is_prediction=True)
infer_prog = infer_prog.clone(for_test=True) infer_prog = infer_prog.clone(for_test=True)
exe.run(startup_prog) exe.run(startup_prog)
...@@ -281,14 +237,18 @@ def main(args): ...@@ -281,14 +237,18 @@ def main(args):
if args.do_train: if args.do_train:
train_exe = exe train_exe = exe
train_pyreader.decorate_paddle_reader(train_data_generator) train_reader.decorate_sample_list_generator(train_data_generator)
else: else:
train_exe = None train_exe = None
if args.do_val or args.do_infer: if args.do_val:
test_exe = exe
test_reader.decorate_sample_list_generator(test_data_generator)
if args.do_infer:
test_exe = exe test_exe = exe
infer_reader.decorate_sample_list_generator(infer_data_generator)
if args.do_train: if args.do_train:
train_pyreader.start() train_reader.start()
steps = 0 steps = 0
total_cost, total_acc, total_num_seqs = [], [], [] total_cost, total_acc, total_num_seqs = [], [], []
time_begin = time.time() time_begin = time.time()
...@@ -335,55 +295,32 @@ def main(args): ...@@ -335,55 +295,32 @@ def main(args):
# evaluate dev set # evaluate dev set
if args.do_val: if args.do_val:
print("do evalatation") print("do evalatation")
test_pyreader.decorate_paddle_reader( evaluate(exe, test_prog, test_reader,
processor.data_generator(
batch_size=args.batch_size,
phase='dev',
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], [loss.name, accuracy.name, num_seqs.name],
"dev") "dev")
except fluid.core.EOFException: except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps)) save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
train_pyreader.reset() train_reader.reset()
break break
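The try/except flow above is the standard PyReader drive loop; a minimal sketch, assuming the train_reader, train_program, loss and exe defined earlier:

train_reader.start()                       # begin consuming the decorated generator
try:
    while True:
        loss_val, = exe.run(program=train_program,
                            fetch_list=[loss.name])
except fluid.core.EOFException:            # raised once the generator is exhausted
    train_reader.reset()                   # rewind the reader before any further pass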
# final eval on dev set # final eval on dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_paddle_reader(
processor.data_generator(
batch_size=args.batch_size, phase='dev', epoch=1,
shuffle=False))
print("Final validation result:") print("Final validation result:")
evaluate(exe, test_prog, test_pyreader, evaluate(exe, test_prog, test_reader,
[loss.name, accuracy.name, num_seqs.name], "dev") [loss.name, accuracy.name, num_seqs.name], "dev")
test_pyreader.decorate_paddle_reader(
processor.data_generator(
batch_size=args.batch_size, phase='infer', epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], "infer")
# final eval on test set # final eval on test set
if args.do_infer: if args.do_infer:
infer_pyreader.decorate_paddle_reader(
processor.data_generator(
batch_size=args.batch_size,
phase='infer',
epoch=1,
shuffle=False))
print("Final test result:") print("Final test result:")
inference(exe, infer_prog, infer_pyreader, inference(exe, infer_prog, infer_reader,
[prop.name], "infer") [prop.name], "infer")
if __name__ == "__main__": if __name__ == "__main__":
print_arguments(args) args = PDConfig('senta_config.json')
args.build()
args.print_arguments()
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
main(args) main(args)
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
export FLAGS_fraction_of_gpu_memory_to_use=0.95 export FLAGS_fraction_of_gpu_memory_to_use=0.95
export FLAGS_enable_parallel_graph=1 export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1 export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=3 export CUDA_VISIBLE_DEVICES=12
export CPU_NUM=1 export CPU_NUM=1
ERNIE_PRETRAIN=./senta_model/ernie_pretrain_model/ ERNIE_PRETRAIN=./ernie_pretrain_model/
DATA_PATH=./senta_data DATA_PATH=./senta_data
MODEL_SAVE_PATH=./save_models MODEL_SAVE_PATH=./save_models
...@@ -17,7 +17,7 @@ train() { ...@@ -17,7 +17,7 @@ train() {
--do_val true \ --do_val true \
--do_infer false \ --do_infer false \
--use_paddle_hub false \ --use_paddle_hub false \
--batch_size 24 \ --batch_size 4 \
--init_checkpoint $ERNIE_PRETRAIN/params \ --init_checkpoint $ERNIE_PRETRAIN/params \
--train_set $DATA_PATH/train.tsv \ --train_set $DATA_PATH/train.tsv \
--dev_set $DATA_PATH/dev.tsv \ --dev_set $DATA_PATH/dev.tsv \
...@@ -25,8 +25,8 @@ train() { ...@@ -25,8 +25,8 @@ train() {
--vocab_path $ERNIE_PRETRAIN/vocab.txt \ --vocab_path $ERNIE_PRETRAIN/vocab.txt \
--checkpoints $MODEL_SAVE_PATH \ --checkpoints $MODEL_SAVE_PATH \
--save_steps 5000 \ --save_steps 5000 \
--validation_steps 100 \ --validation_steps 5000 \
--epoch 10 \ --epoch 2 \
--max_seq_len 256 \ --max_seq_len 256 \
--ernie_config_path $ERNIE_PRETRAIN/ernie_config.json \ --ernie_config_path $ERNIE_PRETRAIN/ernie_config.json \
--model_type "ernie_base" \ --model_type "ernie_base" \
...@@ -45,8 +45,8 @@ evaluate() { ...@@ -45,8 +45,8 @@ evaluate() {
--do_val true \ --do_val true \
--do_infer false \ --do_infer false \
--use_paddle_hub false \ --use_paddle_hub false \
--batch_size 24 \ --batch_size 4 \
--init_checkpoint ./save_models/step_5000/ \ --init_checkpoint ./save_models/step_4801/ \
--dev_set $DATA_PATH/dev.tsv \ --dev_set $DATA_PATH/dev.tsv \
--vocab_path $ERNIE_PRETRAIN/vocab.txt \ --vocab_path $ERNIE_PRETRAIN/vocab.txt \
--max_seq_len 256 \ --max_seq_len 256 \
...@@ -61,8 +61,8 @@ evaluate() { ...@@ -61,8 +61,8 @@ evaluate() {
--do_val true \ --do_val true \
--do_infer false \ --do_infer false \
--use_paddle_hub false \ --use_paddle_hub false \
--batch_size 24 \ --batch_size 4 \
--init_checkpoint ./save_models/step_5000/ \ --init_checkpoint ./save_models/step_4801/ \
--dev_set $DATA_PATH/test.tsv \ --dev_set $DATA_PATH/test.tsv \
--vocab_path $ERNIE_PRETRAIN/vocab.txt \ --vocab_path $ERNIE_PRETRAIN/vocab.txt \
--max_seq_len 256 \ --max_seq_len 256 \
...@@ -80,8 +80,8 @@ infer() { ...@@ -80,8 +80,8 @@ infer() {
--do_val false \ --do_val false \
--do_infer true \ --do_infer true \
--use_paddle_hub false \ --use_paddle_hub false \
--batch_size 24 \ --batch_size 4 \
--init_checkpoint ./save_models/step_5000 \ --init_checkpoint ./save_models/step_4801 \
--test_set $DATA_PATH/test.tsv \ --test_set $DATA_PATH/test.tsv \
--vocab_path $ERNIE_PRETRAIN/vocab.txt \ --vocab_path $ERNIE_PRETRAIN/vocab.txt \
--max_seq_len 256 \ --max_seq_len 256 \
...@@ -90,6 +90,20 @@ infer() { ...@@ -90,6 +90,20 @@ infer() {
--num_labels 2 --num_labels 2
} }
# run_save_inference_model
save_inference_model() {
python -u inference_model_ernie.py \
--use_cuda true \
--do_save_inference_model true \
--init_checkpoint ./save_models/step_4801/ \
--inference_model_dir ./inference_model \
--ernie_config_path $ERNIE_PRETRAIN/ernie_config.json \
--model_type "ernie_base" \
--vocab_path $ERNIE_PRETRAIN/vocab.txt \
--test_set ${DATA_PATH}/test.tsv \
--batch_size 4
}
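A directory written by save_inference_model can be reloaded without any of the training code; a minimal sketch using the standard fluid API (the path matches --inference_model_dir above; feed_data is a hypothetical list of preprocessed input tensors):

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
infer_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="./inference_model", executor=exe)
probs = exe.run(infer_prog,
                feed=dict(zip(feed_names, feed_data)),  # feed_data: hypothetical prepared batch
                fetch_list=fetch_targets)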
main() { main() {
local cmd=${1:-help} local cmd=${1:-help}
case "${cmd}" in case "${cmd}" in
...@@ -102,13 +116,16 @@ main() { ...@@ -102,13 +116,16 @@ main() {
infer) infer)
infer "$@"; infer "$@";
;; ;;
save_inference_model)
save_inference_model "$@";
;;
help) help)
echo "Usage: ${BASH_SOURCE} {train|eval|infer}"; echo "Usage: ${BASH_SOURCE} {train|eval|infer|save_inference_model}";
return 0; return 0;
;; ;;
*) *)
echo "Unsupport commend [${cmd}]"; echo "Unsupport commend [${cmd}]";
echo "Usage: ${BASH_SOURCE} {train|eval|infer}"; echo "Usage: ${BASH_SOURCE} {train|eval|infer|save_inference_model}";
return 1; return 1;
;; ;;
esac esac
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Sentiment Classification Task Sentiment Classification Task
""" """
...@@ -43,54 +30,35 @@ from nets import ernie_bilstm_net ...@@ -43,54 +30,35 @@ from nets import ernie_bilstm_net
from preprocess.ernie import task_reader from preprocess.ernie import task_reader
from models.representation.ernie import ErnieConfig from models.representation.ernie import ErnieConfig
from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub
from models.representation.ernie import ernie_pyreader #from models.representation.ernie import ernie_pyreader
from models.model_check import check_cuda from models.model_check import check_cuda
from utils import ArgumentGroup from config import PDConfig
from utils import print_arguments
from utils import init_checkpoint from utils import init_checkpoint
# yapf: disable def ernie_pyreader(args, pyreader_name):
parser = argparse.ArgumentParser(__doc__) src_ids = fluid.layers.data(
model_g = ArgumentGroup(parser, "model", "model configuration and paths.") name="src_ids", shape=[-1, args.max_seq_len, 1], dtype="int64")
model_g.add_arg("ernie_config_path", str, None, "Path to the json file for ernie model config.") sent_ids = fluid.layers.data(
model_g.add_arg("senta_config_path", str, None, "Path to the json file for senta model config.") name="sent_ids", shape=[-1, args.max_seq_len, 1], dtype="int64")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.") pos_ids = fluid.layers.data(
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints") name="pos_ids", shape=[-1, args.max_seq_len, 1], dtype="int64")
model_g.add_arg("model_type", str, "ernie_base", "Type of current ernie model") input_mask = fluid.layers.data(
model_g.add_arg("use_paddle_hub", bool, False, "Whether to load ERNIE using PaddleHub") name="input_mask", shape=[-1, args.max_seq_len, 1], dtype="float32")
labels = fluid.layers.data(
train_g = ArgumentGroup(parser, "training", "training options.") name="labels", shape=[-1, 1], dtype="int64")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.") seq_lens = fluid.layers.data(
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.") name="seq_lens", shape=[-1, 1], dtype="int64")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("lr", float, 0.002, "The Learning rate value for training.") pyreader = fluid.io.PyReader(feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens],
capacity=4, iterable=False)
log_g = ArgumentGroup(parser, "logging", "logging related") ernie_inputs = {
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.") "src_ids": src_ids,
log_g.add_arg("verbose", bool, False, "Whether to output verbose log") "sent_ids": sent_ids,
"pos_ids": pos_ids,
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options") "input_mask": input_mask,
data_g.add_arg("data_dir", str, None, "Path to training data.") "seq_lens": seq_lens}
data_g.add_arg("vocab_path", str, None, "Vocabulary path.") return pyreader, ernie_inputs, labels
data_g.add_arg("batch_size", int, 256, "Total examples' number in batch for training.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
data_g.add_arg("num_labels", int, 2, "label number")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
data_g.add_arg("train_set", str, None, "Path to training data.")
data_g.add_arg("test_set", str, None, "Path to test data.")
data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("label_map_config", str, None, "label_map_path.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation.")
run_type_g.add_arg("do_infer", bool, True, "Whether to perform inference.")
args = parser.parse_args()
# yapf: enable.
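For orientation, the locally defined ernie_pyreader above is wired up in main() roughly as follows; a minimal sketch, with the train-phase generator arguments assumed by analogy with the dev/infer calls further down:

train_pyreader, ernie_inputs, labels = ernie_pyreader(
    args, pyreader_name="train_pyreader")
train_pyreader.decorate_batch_generator(
    reader.data_generator(
        input_file=args.train_set,    # assumed: mirrors the dev_set/test_set calls below
        batch_size=args.batch_size,
        phase="train",
        epoch=args.epoch,
        shuffle=True))                # assumed: training batches shuffled, dev/infer not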
def create_model(args, def create_model(args,
embeddings, embeddings,
...@@ -174,7 +142,6 @@ def main(args): ...@@ -174,7 +142,6 @@ def main(args):
""" """
Main Function Main Function
""" """
args = parser.parse_args()
ernie_config = ErnieConfig(args.ernie_config_path) ernie_config = ErnieConfig(args.ernie_config_path)
ernie_config.print_config() ernie_config.print_config()
...@@ -224,7 +191,7 @@ def main(args): ...@@ -224,7 +191,7 @@ def main(args):
# create ernie_pyreader # create ernie_pyreader
train_pyreader, ernie_inputs, labels = ernie_pyreader( train_pyreader, ernie_inputs, labels = ernie_pyreader(
args, args,
pyreader_name='train_reader') pyreader_name='train_pyreader')
# get ernie_embeddings # get ernie_embeddings
if args.use_paddle_hub: if args.use_paddle_hub:
...@@ -239,10 +206,6 @@ def main(args): ...@@ -239,10 +206,6 @@ def main(args):
labels=labels, labels=labels,
is_prediction=False) is_prediction=False)
"""
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr)
sgd_optimizer.minimize(loss)
"""
optimizer = fluid.optimizer.Adam(learning_rate=args.lr) optimizer = fluid.optimizer.Adam(learning_rate=args.lr)
optimizer.minimize(loss) optimizer.minimize(loss)
...@@ -253,6 +216,12 @@ def main(args): ...@@ -253,6 +216,12 @@ def main(args):
(lower_mem, upper_mem, unit)) (lower_mem, upper_mem, unit))
if args.do_val: if args.do_val:
test_data_generator = reader.data_generator(
input_file=args.dev_set,
batch_size=args.batch_size,
phase='dev',
epoch=1,
shuffle=False)
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
...@@ -277,12 +246,18 @@ def main(args): ...@@ -277,12 +246,18 @@ def main(args):
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
if args.do_infer: if args.do_infer:
infer_data_generator = reader.data_generator(
input_file=args.test_set,
batch_size=args.batch_size,
phase='infer',
epoch=1,
shuffle=False)
infer_prog = fluid.Program() infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog): with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
infer_pyreader, ernie_inputs, labels = ernie_pyreader( infer_pyreader, ernie_inputs, labels = ernie_pyreader(
args, args,
pyreader_name="infer_reader") pyreader_name="infer_pyreader")
# get ernie_embeddings # get ernie_embeddings
if args.use_paddle_hub: if args.use_paddle_hub:
...@@ -323,20 +298,16 @@ def main(args): ...@@ -323,20 +298,16 @@ def main(args):
main_program=infer_prog) main_program=infer_prog)
if args.do_train: if args.do_train:
exec_strategy = fluid.ExecutionStrategy() train_exe = exe
exec_strategy.num_iteration_per_drop_scope = 1 train_pyreader.decorate_batch_generator(train_data_generator)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
loss_name=loss.name,
exec_strategy=exec_strategy,
main_program=train_program)
train_pyreader.decorate_tensor_provider(train_data_generator)
else: else:
train_exe = None train_exe = None
if args.do_val or args.do_infer: if args.do_val:
test_exe = exe test_exe = exe
test_pyreader.decorate_batch_generator(test_data_generator)
if args.do_infer:
test_exe = exe
infer_pyreader.decorate_batch_generator(infer_data_generator)
if args.do_train: if args.do_train:
train_pyreader.start() train_pyreader.start()
...@@ -351,7 +322,7 @@ def main(args): ...@@ -351,7 +322,7 @@ def main(args):
else: else:
fetch_list = [] fetch_list = []
outputs = train_exe.run(fetch_list=fetch_list, return_numpy=False) outputs = train_exe.run(program=train_program, fetch_list=fetch_list, return_numpy=False)
if steps % args.skip_steps == 0: if steps % args.skip_steps == 0:
np_loss, np_acc, np_num_seqs = outputs np_loss, np_acc, np_num_seqs = outputs
np_loss = np.array(np_loss) np_loss = np.array(np_loss)
...@@ -383,30 +354,10 @@ def main(args): ...@@ -383,30 +354,10 @@ def main(args):
if steps % args.validation_steps == 0: if steps % args.validation_steps == 0:
# evaluate dev set # evaluate dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_tensor_provider(
reader.data_generator(
input_file=args.dev_set,
batch_size=args.batch_size,
phase='dev',
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader, evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], [loss.name, accuracy.name, num_seqs.name],
"dev") "dev")
test_pyreader.decorate_tensor_provider(
reader.data_generator(
input_file=args.test_set,
batch_size=args.batch_size,
phase='infer',
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
"infer")
except fluid.core.EOFException: except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps)) save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
...@@ -415,29 +366,19 @@ def main(args): ...@@ -415,29 +366,19 @@ def main(args):
# final eval on dev set # final eval on dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_tensor_provider(
reader.data_generator(
input_file=args.dev_set,
batch_size=args.batch_size, phase='dev', epoch=1,
shuffle=False))
print("Final validation result:") print("Final validation result:")
evaluate(exe, test_prog, test_pyreader, evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name], "dev") [loss.name, accuracy.name, num_seqs.name], "dev")
# final eval on test set # final eval on test set
if args.do_infer: if args.do_infer:
infer_pyreader.decorate_tensor_provider(
reader.data_generator(
input_file=args.test_set,
batch_size=args.batch_size,
phase='infer',
epoch=1,
shuffle=False))
print("Final test result:") print("Final test result:")
infer(exe, infer_prog, infer_pyreader, infer(exe, infer_prog, infer_pyreader,
[probs.name], "infer") [probs.name], "infer")
if __name__ == "__main__": if __name__ == "__main__":
print_arguments(args) args = PDConfig()
args.build()
args.print_arguments()
check_cuda(args.use_cuda) check_cuda(args.use_cuda)
main(args) main(args)
{ {
"model_type": "bilstm_net", "model_type": "bilstm_net",
"vocab_size": 33256 "vocab_size": 33256,
"max_seq_len": 256
} }
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Arguments for configuration Arguments for configuration
""" """
...@@ -44,7 +31,6 @@ class ArgumentGroup(object): ...@@ -44,7 +31,6 @@ class ArgumentGroup(object):
""" """
Argument Class Argument Class
""" """
def __init__(self, parser, title, des): def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des) self._group = parser.add_argument_group(title=title, description=des)
...@@ -93,12 +79,13 @@ def init_checkpoint(exe, init_checkpoint_path, main_program): ...@@ -93,12 +79,13 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
predicate=existed_persitables) predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path)) print("Load model from {}".format(init_checkpoint_path))
def data_reader(file_path, word_dict, num_examples, phrase, epoch): def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len):
""" """
Convert word sequences into slots Convert word sequences into slots
""" """
unk_id = len(word_dict) unk_id = len(word_dict)
pad_id = 0
all_data = [] all_data = []
with io.open(file_path, "r", encoding='utf8') as fin: with io.open(file_path, "r", encoding='utf8') as fin:
for line in fin: for line in fin:
...@@ -109,28 +96,31 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch): ...@@ -109,28 +96,31 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch):
sys.stderr.write("[NOTICE] Wrongly formatted line!") sys.stderr.write("[NOTICE] Wrongly formatted line!")
continue continue
label = int(cols[1]) label = int(cols[1])
wids = [ wids = [word_dict[x] if x in word_dict else unk_id
word_dict[x] if x in word_dict else unk_id for x in cols[0].split(" ")]
for x in cols[0].split(" ") seq_len = len(wids)
] if seq_len < max_seq_len:
all_data.append((wids, label)) for i in range(max_seq_len - seq_len):
wids.append(pad_id)
else:
wids = wids[:max_seq_len]
seq_len = max_seq_len
all_data.append((wids, label, seq_len))
if phrase == "train": if phrase == "train":
random.shuffle(all_data) random.shuffle(all_data)
num_examples[phrase] = len(all_data) num_examples[phrase] = len(all_data)
def reader(): def reader():
""" """
Reader Function Reader Function
""" """
for epoch_index in range(epoch): for epoch_index in range(epoch):
for doc, label in all_data: for doc, label, seq_len in all_data:
yield doc, label yield doc, label, seq_len
return reader return reader
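The fixed-length padding introduced in data_reader above can be checked in isolation; a toy example with illustrative values (the real sequence length is kept alongside the padded ids so the padded positions can be stripped inside the network):

max_seq_len, pad_id = 5, 0
wids = [7, 3, 9]                   # token ids after vocabulary lookup
seq_len = len(wids)                # real length, recorded before padding
if seq_len < max_seq_len:
    wids = wids + [pad_id] * (max_seq_len - seq_len)
else:
    wids, seq_len = wids[:max_seq_len], max_seq_len
print(wids, seq_len)               # -> [7, 3, 9, 0, 0] 3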
def load_vocab(file_path): def load_vocab(file_path):
""" """
load the given vocabulary load the given vocabulary
......