提交 366c712b 编写于 作者: W wuzewu 提交者: bbking

Update the emotion demo and use the PaddleHub to get the ERNIE model (#2384)

上级 5c142ae7
......@@ -19,7 +19,10 @@
## 快速开始
本项目依赖于 Python2.7 和 Paddlepaddle Fluid 1.3.2,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
本项目依赖于 Python2.7、Paddlepaddle Fluid 1.4.0以及PaddleHub 0.5.0,请确保相关依赖都已安装正确
[PaddlePaddle安装指南](http://www.paddlepaddle.org/#quick-start)
[PaddleHub安装指南](https://github.com/PaddlePaddle/PaddleHub)
#### 安装代码
......@@ -169,4 +172,3 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
......@@ -201,7 +201,7 @@ def main(args):
pyreader_name='train_reader')
# get ernie_embeddings
embeddings = ernie.ernie_encoder(ernie_inputs, ernie_config=ernie_config)
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
# user defined model based on ernie embeddings
loss, accuracy, num_seqs = create_model(
......@@ -233,7 +233,7 @@ def main(args):
pyreader_name='eval_reader')
# get ernie_embeddings
embeddings = ernie.ernie_encoder(ernie_inputs, ernie_config=ernie_config)
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
# user defined model based on ernie embeddings
loss, accuracy, num_seqs = create_model(
......@@ -253,7 +253,7 @@ def main(args):
pyreader_name='infer_reader')
# get ernie_embeddings
embeddings = ernie.ernie_encoder(ernie_inputs, ernie_config=ernie_config)
embeddings = ernie.ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
probs = create_model(args,
embeddings,
......@@ -268,7 +268,7 @@ def main(args):
utils.init_checkpoint(
exe,
args.init_checkpoint,
main_program=startup_prog)
main_program=train_program)
elif args.do_val or args.do_infer:
if not args.init_checkpoint:
raise ValueError("args 'init_checkpoint' should be set if"
......
......@@ -10,6 +10,7 @@ import json
import six
import paddle.fluid as fluid
import paddlehub as hub
from models.transformer_encoder import encoder, pre_process_layer
......@@ -26,7 +27,8 @@ def ernie_pyreader(args, pyreader_name):
name=pyreader_name,
use_double_buffer=True)
(src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens) = fluid.layers.read_file(pyreader)
(src_ids, sent_ids, pos_ids, input_mask, labels,
seq_lens) = fluid.layers.read_file(pyreader)
ernie_inputs = {
"src_ids": src_ids,
......@@ -38,6 +40,42 @@ def ernie_pyreader(args, pyreader_name):
return pyreader, ernie_inputs, labels
def ernie_encoder_with_paddle_hub(ernie_inputs, max_seq_len):
ernie = hub.Module(name="ernie")
inputs, outputs, program = ernie.context(
trainable=True, max_seq_len=max_seq_len, learning_rate=1)
main_program = fluid.default_main_program()
input_dict = {
inputs["input_ids"].name: ernie_inputs["src_ids"],
inputs["segment_ids"].name: ernie_inputs["sent_ids"],
inputs["position_ids"].name: ernie_inputs["pos_ids"],
inputs["input_mask"].name: ernie_inputs["input_mask"]
}
hub.connect_program(
pre_program=main_program,
next_program=program,
input_dict=input_dict,
inplace=True,
need_log=False)
enc_out = outputs["sequence_output"]
unpad_enc_out = fluid.layers.sequence_unpad(
enc_out, length=ernie_inputs["seq_lens"])
cls_feats = outputs["pooled_output"]
embeddings = {
"sentence_embeddings": cls_feats,
"token_embeddings": unpad_enc_out,
}
for k, v in embeddings.items():
v.persistable = True
return embeddings
def ernie_encoder(ernie_inputs, ernie_config):
"""return sentence embedding and token embeddings"""
......@@ -49,7 +87,8 @@ def ernie_encoder(ernie_inputs, ernie_config):
config=ernie_config)
enc_out = ernie.get_sequence_output()
unpad_enc_out = fluid.layers.sequence_unpad(enc_out, length=ernie_inputs["seq_lens"])
unpad_enc_out = fluid.layers.sequence_unpad(
enc_out, length=ernie_inputs["seq_lens"])
cls_feats = ernie.get_pooled_output()
embeddings = {
......@@ -65,6 +104,7 @@ def ernie_encoder(ernie_inputs, ernie_config):
class ErnieConfig(object):
"""ErnieConfig"""
def __init__(self, config_path):
self._config_dict = self._parse(config_path)
......@@ -90,6 +130,7 @@ class ErnieConfig(object):
class ErnieModel(object):
"""ErnieModel"""
def __init__(self,
src_ids,
position_ids,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册