提交 550947a7 编写于 作者: W wuzewu 提交者: bbking

Add an option to control whether to load ERNIE using PaddleHub (#2401)

* Add an option to control whether to load ERNIE using PaddleHub

* Update README
上级 81d868fa
......@@ -19,10 +19,9 @@
## 快速开始
本项目依赖于 Python2.7、Paddlepaddle Fluid 1.4.0以及PaddleHub 0.5.0,请确保相关依赖都已安装正确
本项目依赖于 Paddlepaddle 1.3.2 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
[PaddlePaddle安装指南](http://www.paddlepaddle.org/#quick-start)
[PaddleHub安装指南](https://github.com/PaddlePaddle/PaddleHub)
python版本依赖python 2.7
#### 安装代码
......@@ -169,6 +168,19 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut
--init_checkpoint ./models/ernie_finetune/params
```
我们也提供了使用PaddleHub加载ERNIE模型的选项,PaddleHub是PaddlePaddle的预训练模型管理工具,可以一行代码完成预训练模型的加载,简化预训练模型的使用和迁移学习。更多相关的介绍,可以查看[PaddleHub](https://github.com/PaddlePaddle/PaddleHub)
如果想使用该功能,需要修改run_ernie.sh中的配置如下:
```shell
# 在train()函数中,修改--use_paddle_hub选项
--use_paddle_hub true
```
注意:使用该选项需要先安装PaddleHub,安装命令如下
```shell
$ pip install paddlehub
```
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
......@@ -12,6 +12,7 @@ train() {
--verbose true \
--do_train true \
--do_val true \
--use_paddle_hub false \
--batch_size 32 \
--init_checkpoint ${MODEL_PATH}/params \
--train_set ${TASK_DATA_PATH}/train.tsv \
......@@ -35,6 +36,7 @@ evaluate() {
--use_cuda true \
--verbose true \
--do_val true \
--use_paddle_hub false \
--batch_size 32 \
--init_checkpoint ${MODEL_PATH}/params \
--test_set ${TASK_DATA_PATH}/test.tsv \
......@@ -50,6 +52,7 @@ infer() {
--use_cuda true \
--verbose true \
--do_infer true \
--use_paddle_hub false \
--batch_size 32 \
--init_checkpoint ${MODEL_PATH}/params \
--infer_set ${TASK_DATA_PATH}/infer.tsv \
......
......@@ -28,6 +28,7 @@ model_g.add_arg("ernie_config_path", str, None, "Path to the json file for ernie
model_g.add_arg("senta_config_path", str, None, "Path to the json file for senta model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("output_dir", str, "checkpoints", "Path to save checkpoints")
model_g.add_arg("use_paddle_hub", bool, False, "Whether to load ERNIE using PaddleHub")
train_g = utils.ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
......@@ -201,7 +202,10 @@ def main(args):
pyreader_name='train_reader')
# get ernie_embeddings
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
if args.use_paddle_hub:
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie.ernie_encoder(ernie_inputs, ernie_config=ernie_config)
# user defined model based on ernie embeddings
loss, accuracy, num_seqs = create_model(
......@@ -233,7 +237,10 @@ def main(args):
pyreader_name='eval_reader')
# get ernie_embeddings
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
if args.use_paddle_hub:
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie.ernie_encoder(ernie_inputs, ernie_config=ernie_config)
# user defined model based on ernie embeddings
loss, accuracy, num_seqs = create_model(
......@@ -253,7 +260,10 @@ def main(args):
pyreader_name='infer_reader')
# get ernie_embeddings
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
if args.use_paddle_hub:
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie.ernie_encoder(ernie_inputs, ernie_config=ernie_config)
probs = create_model(args,
embeddings,
......
......@@ -57,8 +57,7 @@ def ernie_encoder_with_paddle_hub(ernie_inputs, max_seq_len):
pre_program=main_program,
next_program=program,
input_dict=input_dict,
inplace=True,
need_log=False)
inplace=True)
enc_out = outputs["sequence_output"]
unpad_enc_out = fluid.layers.sequence_unpad(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册