PaddlePaddle / PaddleHub
Unverified commit 36d7d800, authored Mar 12, 2020 by Steffy-zxf, committed via GitHub on Mar 12, 2020
Add nlp module (#433)
* add nlpmodule base class
Parent: 1e5be079
Showing 3 changed files with 212 additions and 36 deletions (+212 -36)
paddlehub/__init__.py          +1   -1
paddlehub/module/module.py     +17  -9
paddlehub/module/nlp_module.py +194 -26
paddlehub/__init__.py
...
@@ -64,4 +64,4 @@ from .finetune.strategy import CombinedStrategy
 from .autofinetune.evaluator import report_final_result
-from .module.nlp_module import BERTModule
+from .module.nlp_module import NLPPredictionModule, TransformerModule
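The practical effect of this one-line change is that the two new NLP base classes are re-exported at the package root (hub.NLPPredictionModule and hub.TransformerModule), replacing the old package-level BERTModule export. A minimal usage sketch, assuming a hypothetical custom module (the class name and asset paths are illustrative, not part of this commit):

import paddlehub as hub

class MyTextModule(hub.NLPPredictionModule):  # hypothetical subclass
    def _initialize(self):
        # assumed asset layout; a real module points these at its own files
        self.vocab_path = "assets/vocab.txt"
        self.pretrained_model_path = "assets/infer_model"
        self._set_config()  # builds the CPU/GPU predictors defined below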
paddlehub/module/module.py
-#coding:utf-8
+# coding:utf-8
 # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...
@@ -135,19 +135,27 @@ class Module(object):
         if "_is_initialize" in self.__dict__ and self._is_initialize:
             return
-        mod = self.__class__.__module__ + "." + self.__class__.__name__
-        if mod in _module_runnable_func:
-            _run_func_name = _module_runnable_func[mod]
-            self._run_func = getattr(self, _run_func_name)
-        else:
-            self._run_func = None
-        self._serving_func_name = _module_serving_func.get(mod, None)
-        self._code_version = "v2"
+        _run_func_name = self._get_func_name(self.__class__,
+                                             _module_runnable_func)
+        self._run_func = getattr(self, _run_func_name)
+        self._serving_func_name = self._get_func_name(self.__class__,
+                                                      _module_serving_func)
         self._directory = directory
         self._initialize(**kwargs)
         self._is_initialize = True
+        self._code_version = "v2"
+
+    def _get_func_name(self, current_cls, module_func_dict):
+        mod = current_cls.__module__ + "." + current_cls.__name__
+        if mod in module_func_dict:
+            _func_name = module_func_dict[mod]
+            return _func_name
+        elif current_cls.__bases__:
+            for base_class in current_cls.__bases__:
+                return self._get_func_name(base_class, module_func_dict)
+        else:
+            return None
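The new _get_func_name helper replaces the old single-class dictionary lookup with one that also consults base classes, so a subclass inherits the runnable/serving entry point registered on its parent. A standalone sketch of that lookup (the registry contents and class names are made up; in PaddleHub the registries are filled by the runnable/serving decorators):

# Keys follow "<module>.<class name>"; this assumes the snippet runs as a script.
_module_runnable_func = {"__main__.SentimentModule": "predict_sentiment"}

def _get_func_name(current_cls, module_func_dict):
    mod = current_cls.__module__ + "." + current_cls.__name__
    if mod in module_func_dict:
        return module_func_dict[mod]
    elif current_cls.__bases__:
        # note: only the first base class is ever consulted, because the loop
        # returns unconditionally on its first iteration
        for base_class in current_cls.__bases__:
            return _get_func_name(base_class, module_func_dict)
    else:
        return None

class SentimentModule(object):
    def predict_sentiment(self, texts):
        return ["positive" for _ in texts]

class DistilledSentimentModule(SentimentModule):
    pass

print(_get_func_name(SentimentModule, _module_runnable_func))           # predict_sentiment
print(_get_func_name(DistilledSentimentModule, _module_runnable_func))  # found via the base class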
     @classmethod
     def init_with_name(cls, name, version=None, **kwargs):
         fp_lock = open(os.path.join(CACHE_HOME, name), "a")
...
paddlehub/module/nlp_module.py
-#coding:utf-8
+# coding:utf-8
 # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...
@@ -17,15 +17,198 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import paddlehub as hub
-from paddlehub import logger
+import argparse
+import ast
+import json
+import os
+import re
+import six
+
 import numpy as np
 import paddle.fluid as fluid
+from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+
+import paddlehub as hub
+from paddlehub.common.logger import logger
+from paddlehub.common.utils import sys_stdin_encoding
+from paddlehub.io.parser import txt_parser
+from paddlehub.module.module import runnable
+
+
+class DataFormatError(Exception):
+    def __init__(self, *args):
+        self.args = args
+
+
+class NLPBaseModule(hub.Module):
+    def _initialize(self):
+        """
+        Initialize with the necessary elements.
+        This method must be overridden.
+        """
+        raise NotImplementedError()
+
+    def get_vocab_path(self):
+        """
+        Get the path to the vocabulary that was used for pretraining.
+
+        Returns:
+            self.vocab_path(str): the path to the vocabulary
+        """
+        return self.vocab_path
+
+class NLPPredictionModule(NLPBaseModule):
+    def _set_config(self):
+        """
+        predictor config setting
+        """
+        cpu_config = AnalysisConfig(self.pretrained_model_path)
+        cpu_config.disable_glog_info()
+        cpu_config.disable_gpu()
+        self.cpu_predictor = create_paddle_predictor(cpu_config)
+
+        try:
+            _places = os.environ["CUDA_VISIBLE_DEVICES"]
+            int(_places[0])
+            use_gpu = True
+        except:
+            use_gpu = False
+        if use_gpu:
+            gpu_config = AnalysisConfig(self.pretrained_model_path)
+            gpu_config.disable_glog_info()
+            gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
+            self.gpu_predictor = create_paddle_predictor(gpu_config)
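_set_config always builds a CPU predictor and only adds a GPU predictor when the first character of CUDA_VISIBLE_DEVICES parses as an integer. A standalone sketch of that probe (the helper name is mine; there is no Paddle dependency here):

import os

def cuda_devices_visible():
    # Mirrors the probe in _set_config: GPU inference is enabled only when
    # CUDA_VISIBLE_DEVICES is set and its first character is a digit.
    try:
        _places = os.environ["CUDA_VISIBLE_DEVICES"]
        int(_places[0])
        return True
    except (KeyError, IndexError, ValueError):
        return False

# e.g. CUDA_VISIBLE_DEVICES="0,1" -> True; unset, "" or "-1" -> False
print(cuda_devices_visible())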
+    def texts2tensor(self, texts):
+        """
+        Transform the processed texts (a list of dicts) into a PaddleTensor.
+
+        Args:
+            texts(list): each element is a dict with a 'processed' key whose value is a list of word ids, e.g.
+                         texts = [{'processed': [23, 89, 43, 906]}]
+
+        Returns:
+            tensor(PaddleTensor): tensor holding the texts' data
+        """
+        lod = [0]
+        data = []
+        for i, text in enumerate(texts):
+            data += text['processed']
+            lod.append(len(text['processed']) + lod[i])
+        tensor = PaddleTensor(np.array(data).astype('int64'))
+        tensor.name = "words"
+        tensor.lod = [lod]
+        tensor.shape = [lod[-1], 1]
+        return tensor
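A worked example of the level-of-detail (LoD) bookkeeping above, using plain Python lists so it runs without Paddle (the word ids are made up):

texts = [{'processed': [23, 89, 43, 906]}, {'processed': [14, 7]}]

lod = [0]
data = []
for i, text in enumerate(texts):
    data += text['processed']
    lod.append(len(text['processed']) + lod[i])

print(data)  # [23, 89, 43, 906, 14, 7] -> one flat int64 buffer of shape [6, 1]
print(lod)   # [0, 4, 6]                -> sample i spans data[lod[i]:lod[i+1]]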
+    def to_unicode(self, texts):
+        """
+        Convert each element of texts (list of str) to unicode under Python 2.7.
+
+        Args:
+            texts(list): each element's type is str in Python 2.7
+
+        Returns:
+            texts(list): each element's type is unicode in Python 2.7
+        """
+        if six.PY2:
+            unicode_texts = []
+            for text in texts:
+                if not isinstance(text, six.string_types):
+                    unicode_texts.append(
+                        text.decode(sys_stdin_encoding()).decode("utf8"))
+                else:
+                    unicode_texts.append(text)
+            texts = unicode_texts
+        return texts
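A side note on to_unicode: the decode branch is guarded by not isinstance(text, six.string_types), which only matches non-string objects, so under Python 2 plain byte strings pass through unchanged; the guard was presumably meant to be not isinstance(text, six.text_type), so that non-unicode values get decoded.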
+    @runnable
+    def run_cmd(self, argvs):
+        """
+        Run as a command
+        """
+        self.parser = argparse.ArgumentParser(
+            description='Run the %s module.' % self.module_name,
+            prog='hub run %s' % self.module_name,
+            usage='%(prog)s',
+            add_help=True)
+
+        self.arg_input_group = self.parser.add_argument_group(
+            title="Input options", description="Input data. Required")
+        self.arg_config_group = self.parser.add_argument_group(
+            title="Config options",
+            description=
+            "Run configuration for controlling module behavior, not required.")
+
+        self.add_module_config_arg()
+        self.add_module_input_arg()
+
+        args = self.parser.parse_args(argvs)
+
+        try:
+            input_data = self.check_input_data(args)
+        except DataFormatError and RuntimeError:
+            self.parser.print_help()
+            return None
+
+        results = self.predict(
+            texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
+
+        return results
+    def add_module_config_arg(self):
+        """
+        Add the command config options.
+        """
+        self.arg_config_group.add_argument(
+            '--use_gpu',
+            type=ast.literal_eval,
+            default=False,
+            help="whether to use GPU for prediction")
+        self.arg_config_group.add_argument(
+            '--batch_size',
+            type=int,
+            default=1,
+            help="batch size for prediction")
+
+    def add_module_input_arg(self):
+        """
+        Add the command input options.
+        """
+        self.arg_input_group.add_argument(
+            '--input_file',
+            type=str,
+            default=None,
+            help="file containing the input data")
+        self.arg_input_group.add_argument(
+            '--input_text', type=str, default=None, help="text to predict")
+
+    def check_input_data(self, args):
+        input_data = []
+        if args.input_file:
+            if not os.path.exists(args.input_file):
+                print("File %s does not exist." % args.input_file)
+                raise RuntimeError
+            else:
+                input_data = txt_parser.parse(args.input_file, use_strip=True)
+        elif args.input_text:
+            if args.input_text.strip() != '':
+                if six.PY2:
+                    input_data = [
+                        args.input_text.decode(sys_stdin_encoding()).decode(
+                            "utf8")
+                    ]
+                else:
+                    input_data = [args.input_text]
+            else:
+                print("ERROR: The input data is inconsistent with expectations.")
+
+        if input_data == []:
+            print("ERROR: The input data is inconsistent with expectations.")
+            raise DataFormatError
+
+        return input_data
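Two notes on the command-line path above. First, except DataFormatError and RuntimeError evaluates the and before the except clause applies, so only RuntimeError is actually caught; catching both exceptions would read except (DataFormatError, RuntimeError). Second, with these argument groups in place, a module built on NLPPredictionModule is driven through PaddleHub's CLI roughly as follows (the module name is illustrative, not part of this commit):

hub run some_text_module --input_text "This movie is wonderful."
hub run some_text_module --input_file test.txt --use_gpu True --batch_size 8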
-class _BERTEmbeddingTask(hub.BaseTask):
+class _TransformerEmbeddingTask(hub.BaseTask):
     def __init__(self,
                  pooled_feature,
                  seq_feature,
...
@@ -33,7 +216,7 @@ class _BERTEmbeddingTask(hub.BaseTask):
                  data_reader,
                  config=None):
         main_program = pooled_feature.block.program
-        super(_BERTEmbeddingTask, self).__init__(
+        super(_TransformerEmbeddingTask, self).__init__(
             main_program=main_program,
             data_reader=data_reader,
             feed_list=feed_list,
...
@@ -57,21 +240,10 @@ class _BERTEmbeddingTask(hub.BaseTask):
         return results

-class BERTModule(hub.Module):
-    def _initialize(self):
-        """
-        Must override this method.
-        Some member variables are required, others are optional.
-        """
-        # required config
-        self.MAX_SEQ_LEN = None
-        self.params_path = None
-        self.vocab_path = None
-        # optional config
-        self.spm_path = None
-        self.word_dict_path = None
-        raise NotImplementedError
+class TransformerModule(NLPBaseModule):
+    """
+    Transformer Module base class, which BERT, ERNIE, RoBERTa and similar models can build on.
+    """
+
     def init_pretraining_params(self, exe, pretraining_params_path,
                                 main_program):
...
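For module authors, the fields that the removed BERTModule._initialize enumerated still describe what a concrete TransformerModule subclass is expected to set up, since _initialize (inherited from NLPBaseModule) must be overridden. A minimal sketch, assuming a hypothetical module whose assets live next to its code (class name, values and paths are illustrative):

import paddlehub as hub

class MyErnieModule(hub.TransformerModule):  # hypothetical subclass
    def _initialize(self):
        # required configuration (mirrors the removed BERTModule._initialize)
        self.MAX_SEQ_LEN = 512
        self.params_path = "assets/params"
        self.vocab_path = "assets/vocab.txt"
        # optional configuration
        self.spm_path = None
        self.word_dict_path = None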
@@ -157,7 +329,6 @@ class BERTModule(hub.Module):
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup_program)
-
         self.init_pretraining_params(
             exe, self.params_path, main_program=startup_program)
...
@@ -176,7 +347,7 @@ class BERTModule(hub.Module):
     def get_embedding(self, texts, use_gpu=False, batch_size=1):
         """
         get pooled_output and sequence_output for input texts.
-        Warnings: this method depends on Paddle Inference Library, it may not work properly in PaddlePaddle < 1.6.2.
+        Warnings: this method depends on Paddle Inference Library, it may not work properly in PaddlePaddle <= 1.6.2.

         Args:
             texts (list): each element is a text sample; each sample includes text_a and text_b, where text_b can be omitted.
...
@@ -220,7 +391,7 @@ class BERTModule(hub.Module):
                 batch_size=batch_size)

             self.emb_job = {}
-            self.emb_job["task"] = _BERTEmbeddingTask(
+            self.emb_job["task"] = _TransformerEmbeddingTask(
                 pooled_feature=pooled_feature,
                 seq_feature=seq_feature,
                 feed_list=feed_list,
...
@@ -233,9 +404,6 @@ class BERTModule(hub.Module):
         return self.emb_job["task"].predict(
             data=texts, return_result=True, accelerate_mode=True)

-    def get_vocab_path(self):
-        return self.vocab_path
-
     def get_spm_path(self):
         if hasattr(self, "spm_path"):
             return self.spm_path
...
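Taken together, a module that builds on TransformerModule exposes get_embedding for feature extraction. A usage sketch (the module name is illustrative; any TransformerModule-based BERT/ERNIE/RoBERTa port would fit, and the exact shape of the returned features is defined by the embedding task's predict call above):

import paddlehub as hub

module = hub.Module(name="ernie")  # assumed to be a TransformerModule subclass

# each sample is a [text_a, text_b] pair; per the docstring, text_b can be omitted
texts = [["What is the weather like today?", "It is sunny and warm."]]

# pooled_output and sequence_output features; as warned above, this path
# relies on the Paddle Inference Library.
results = module.get_embedding(texts=texts, use_gpu=False, batch_size=1)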