提交 4c6e4456 编写于 作者: W wuzewu 提交者: LiuHao

Update the sentiment_classification demo and use the PaddleHub to get the ERNIE model (#2388)

上级 db9bb56d
......@@ -3,8 +3,8 @@
情感倾向分析(Sentiment Classification,简称Senta)针对带有主观描述的中文文本,可自动判断该文本的情感极性类别并给出相应的置信度。情感类型分为积极、消极。情感倾向分析能够帮助企业理解用户消费习惯、分析热点话题和危机舆情监控,为企业提供有利的决策支持。可通过[AI开放平台-情感倾向分析](http://ai.baidu.com/tech/nlp_apply/sentiment_classify) 线上体验。
情感是人类的一种高级智能行为,为了识别文本的情感倾向,需要深入的语义建模。另外,不同领域(如餐饮、体育)在情感的表达各不相同,因而需要有大规模覆盖各个领域的数据进行模型训练。为此,我们通过基于深度学习的语义模型和大规模数据挖掘解决上述两个问题。效果上,我们基于开源情感倾向分类数据集ChnSentiCorp进行评测;此外,我们还开源了百度基于海量数据训练好的模型,该模型在ChnSentiCorp数据集上fine-tune之后(基于开源模型进行Finetune的方法请见下面章节),可以得到更好的效果。具体数据如下所示:
| 模型 | dev | test | 模型(finetune) |dev | test |
| 模型 | dev | test | 模型(finetune) |dev | test |
| :------| :------ | :------ | :------ |:------ | :------
| BOW | 89.8% | 90.0% | BOW |91.3% | 90.6% |
| CNN | 90.6% | 89.9% | CNN |92.4% | 91.8% |
......@@ -27,7 +27,7 @@ python版本依赖python 2.7
克隆数据集代码库到本地
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/sentiment_classification
cd models/PaddleNLP/sentiment_classification
```
#### 数据准备
......@@ -82,6 +82,18 @@ senta_config.json中需要修改如下:
--model_type "ernie_bilstm"
```
我们也提供了使用PaddleHub加载ERNIE模型的选项,PaddleHub是PaddlePaddle的预训练模型管理工具,可以一行代码完成预训练模型的加载,简化预训练模型的使用和迁移学习。更多相关的介绍,可以查看[PaddleHub](https://github.com/PaddlePaddle/PaddleHub)
如果想使用该功能,需要修改run_ernie.sh中的配置如下:
```shell
# 在eval()函数中,修改如下参数:
--use_paddle_hub true
```
注意:使用该选项需要先安装PaddleHub,安装命令如下
```shell
$ pip install paddlehub
```
#### 模型训练
基于示例的数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证
......@@ -154,7 +166,7 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut
可以根据自己的需求,组建自定义的模型,具体方法如下所示:
1. 定义自己的网络结构
1. 定义自己的网络结构
用户可以在 ```models/classification/nets.py``` 中,定义自己的模型,只需要增加新的函数即可。假设用户自定义的函数名为```user_net```
2. 更改模型配置
```senta_config.json``` 中需要将 ```model_type``` 改为用户自定义的 ```user_net```
......
......@@ -16,6 +16,7 @@ train() {
--do_train true \
--do_val true \
--do_infer false \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint $ERNIE_PRETRAIN/params \
--train_set $DATA_PATH/train.tsv \
......@@ -43,6 +44,7 @@ evaluate() {
--do_train false \
--do_val true \
--do_infer false \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint ./save_models/step_5000/ \
--dev_set $DATA_PATH/dev.tsv \
......@@ -58,6 +60,7 @@ evaluate() {
--do_train false \
--do_val true \
--do_infer false \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint ./save_models/step_5000/ \
--dev_set $DATA_PATH/test.tsv \
......@@ -76,6 +79,7 @@ infer() {
--do_train false \
--do_val false \
--do_infer true \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint ./save_models/step_5000 \
--test_set $DATA_PATH/test.tsv \
......
......@@ -29,7 +29,7 @@ from nets import ernie_base_net
from nets import ernie_bilstm_net
from preprocess.ernie import task_reader
from models.representation.ernie import ErnieConfig
from models.representation.ernie import ernie_encoder
from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub
from models.representation.ernie import ernie_pyreader
from utils import ArgumentGroup
from utils import print_arguments
......@@ -43,6 +43,7 @@ model_g.add_arg("senta_config_path", str, None, "Path to the json file for senta
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints")
model_g.add_arg("model_type", str, "ernie_base", "Type of current ernie model")
model_g.add_arg("use_paddle_hub", bool, False, "Whether to load ERNIE using PaddleHub")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
......@@ -212,7 +213,10 @@ def main(args):
pyreader_name='train_reader')
# get ernie_embeddings
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
# user defined model based on ernie embeddings
loss, accuracy, num_seqs = create_model(
......@@ -244,7 +248,10 @@ def main(args):
pyreader_name='eval_reader')
# get ernie_embeddings
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
# user defined model based on ernie embeddings
loss, accuracy, num_seqs = create_model(
......@@ -264,7 +271,10 @@ def main(args):
pyreader_name="infer_reader")
# get ernie_embeddings
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
probs = create_model(args,
embeddings,
......@@ -280,26 +290,34 @@ def main(args):
init_checkpoint(
exe,
args.init_checkpoint,
main_program=startup_prog)
elif args.do_val or args.do_infer:
main_program=train_program)
elif args.do_val:
if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if"
"only doing validation or testing!")
init_checkpoint(
exe,
args.init_checkpoint,
main_program=startup_prog)
main_program=test_prog)
elif args.do_infer:
if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if"
"only doing validation or testing!")
init_checkpoint(
exe,
args.init_checkpoint,
main_program=infer_prog)
if args.do_train:
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
loss_name=loss.name,
exec_strategy=exec_strategy,
main_program=train_program)
train_pyreader.decorate_tensor_provider(train_data_generator)
else:
train_exe = None
......@@ -362,7 +380,7 @@ def main(args):
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
"dev")
test_pyreader.decorate_tensor_provider(
reader.data_generator(
input_file=args.test_set,
......@@ -370,7 +388,7 @@ def main(args):
phase='infer',
epoch=1,
shuffle=False))
evaluate(exe, test_prog, test_pyreader,
[loss.name, accuracy.name, num_seqs.name],
"infer")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册