提交 4c6e4456 编写于 作者: W wuzewu 提交者: LiuHao

Update the sentiment_classification demo and use the PaddleHub to get the ERNIE model (#2388)

上级 db9bb56d
......@@ -82,6 +82,18 @@ senta_config.json中需要修改如下:
--model_type "ernie_bilstm"
```
我们也提供了使用PaddleHub加载ERNIE模型的选项,PaddleHub是PaddlePaddle的预训练模型管理工具,可以一行代码完成预训练模型的加载,简化预训练模型的使用和迁移学习。更多相关的介绍,可以查看[PaddleHub](https://github.com/PaddlePaddle/PaddleHub)
如果想使用该功能,需要修改run_ernie.sh中的配置如下:
```shell
# 在eval()函数中,修改如下参数:
--use_paddle_hub true
```
注意:使用该选项需要先安装PaddleHub,安装命令如下
```shell
$ pip install paddlehub
```
#### 模型训练
基于示例的数据集,可以运行下面的命令,在训练集(train.tsv)上进行模型训练,并在开发集(dev.tsv)验证
......
......@@ -16,6 +16,7 @@ train() {
--do_train true \
--do_val true \
--do_infer false \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint $ERNIE_PRETRAIN/params \
--train_set $DATA_PATH/train.tsv \
......@@ -43,6 +44,7 @@ evaluate() {
--do_train false \
--do_val true \
--do_infer false \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint ./save_models/step_5000/ \
--dev_set $DATA_PATH/dev.tsv \
......@@ -58,6 +60,7 @@ evaluate() {
--do_train false \
--do_val true \
--do_infer false \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint ./save_models/step_5000/ \
--dev_set $DATA_PATH/test.tsv \
......@@ -76,6 +79,7 @@ infer() {
--do_train false \
--do_val false \
--do_infer true \
--use_paddle_hub false \
--batch_size 24 \
--init_checkpoint ./save_models/step_5000 \
--test_set $DATA_PATH/test.tsv \
......
......@@ -29,7 +29,7 @@ from nets import ernie_base_net
from nets import ernie_bilstm_net
from preprocess.ernie import task_reader
from models.representation.ernie import ErnieConfig
from models.representation.ernie import ernie_encoder
from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub
from models.representation.ernie import ernie_pyreader
from utils import ArgumentGroup
from utils import print_arguments
......@@ -43,6 +43,7 @@ model_g.add_arg("senta_config_path", str, None, "Path to the json file for senta
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints")
model_g.add_arg("model_type", str, "ernie_base", "Type of current ernie model")
model_g.add_arg("use_paddle_hub", bool, False, "Whether to load ERNIE using PaddleHub")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
......@@ -212,6 +213,9 @@ def main(args):
pyreader_name='train_reader')
# get ernie_embeddings
if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
# user defined model based on ernie embeddings
......@@ -244,6 +248,9 @@ def main(args):
pyreader_name='eval_reader')
# get ernie_embeddings
if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
# user defined model based on ernie embeddings
......@@ -264,6 +271,9 @@ def main(args):
pyreader_name="infer_reader")
# get ernie_embeddings
if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
probs = create_model(args,
......@@ -280,15 +290,23 @@ def main(args):
init_checkpoint(
exe,
args.init_checkpoint,
main_program=startup_prog)
elif args.do_val or args.do_infer:
main_program=train_program)
elif args.do_val:
if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if"
"only doing validation or testing!")
init_checkpoint(
exe,
args.init_checkpoint,
main_program=test_prog)
elif args.do_infer:
if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if"
"only doing validation or testing!")
init_checkpoint(
exe,
args.init_checkpoint,
main_program=startup_prog)
main_program=infer_prog)
if args.do_train:
exec_strategy = fluid.ExecutionStrategy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册