Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
c8d399db
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c8d399db
编写于
7月 01, 2021
作者:
K
KP
提交者:
GitHub
7月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cache tokenizer in TransformerModule (#1491)
上级
8c304a76
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
12 deletion
+10
-12
modules/text/text_generation/ernie_gen/template/module.temp
modules/text/text_generation/ernie_gen/template/module.temp
+6
-2
paddlehub/module/nlp_module.py
paddlehub/module/nlp_module.py
+4
-10
未找到文件。
modules/text/text_generation/ernie_gen/template/module.temp
浏览文件 @
c8d399db
...
...
@@ -38,7 +38,7 @@ from .decode import beam_search_infilling
type="nlp/text_generation",
)
class ErnieGen(hub.NLPPredictionModule):
def _
initialize
(self):
def _
_init__
(self):
"""
initialize with the necessary elements
"""
...
...
@@ -66,6 +66,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
results(list): the predict result.
"""
paddle.disable_static()
if texts and isinstance(texts, list) and all(texts) and all(
[isinstance(text, str) for text in texts]):
predicted_data = texts
...
...
@@ -79,7 +81,9 @@ class ErnieGen(hub.NLPPredictionModule):
logger.warning(
"use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
)
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
self.model.eval()
results = []
for text in predicted_data:
...
...
@@ -155,4 +159,4 @@ class ErnieGen(hub.NLPPredictionModule):
results = self.generate(
texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
return results
\ No newline at end of file
return results
paddlehub/module/nlp_module.py
浏览文件 @
c8d399db
...
...
@@ -91,7 +91,6 @@ class InitTrackerMeta(type(nn.Layer)):
help_func (callable, optional): If provided, it would be hooked after
`init_func` and called as `_wrap_init(self, init_func, *init_args, **init_args)`.
Default None.
Returns:
function: the wrapped function
"""
...
...
@@ -142,7 +141,6 @@ class PretrainedModel(nn.Layer):
- `pretrained_init_configuration` (dict): The dict has pretrained model names
as keys, and the values are also dict preserving corresponding configuration
for model initialization.
- `base_model_prefix` (str): represents the the attribute associated to the
base model in derived classes of the same architecture adding layers on
top of the base model.
...
...
@@ -365,14 +363,12 @@ class TextServing(object):
1. seq-cls: sequence classification;
2. token-cls: sequence labeling;
3. None: embedding.
Args:
data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
max_seq_len (:obj:`int`, `optional`, defaults to 128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
batch_size(obj:`int`, defaults to 1): The number of batch.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
Returns:
results(obj:`list`): All the predictions labels.
"""
...
...
@@ -465,11 +461,12 @@ class TransformerModule(RunModule, TextServing):
title_segment_ids
=
[
entry
[
3
]
for
entry
in
batch
]
return
query_input_ids
,
query_segment_ids
,
title_input_ids
,
title_segment_ids
tokenizer
=
self
.
get_tokenizer
()
if
not
hasattr
(
self
,
'tokenizer'
):
self
.
tokenizer
=
self
.
get_tokenizer
()
examples
=
[]
for
texts
in
data
:
encoded_inputs
=
self
.
_convert_text_to_input
(
tokenizer
,
texts
,
max_seq_len
,
split_char
)
encoded_inputs
=
self
.
_convert_text_to_input
(
self
.
tokenizer
,
texts
,
max_seq_len
,
split_char
)
example
=
[]
for
inp
in
encoded_inputs
:
input_ids
=
inp
[
'input_ids'
]
...
...
@@ -538,7 +535,6 @@ class TransformerModule(RunModule, TextServing):
Args:
data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
Returns:
results(obj:`list`): All the tokens and sentences embeddings.
"""
...
...
@@ -556,7 +552,6 @@ class TransformerModule(RunModule, TextServing):
return_prob
:
bool
=
False
):
"""
Predicts the data labels.
Args:
data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
max_seq_len (:obj:`int`, `optional`, defaults to :int:`None`):
...
...
@@ -564,8 +559,7 @@ class TransformerModule(RunModule, TextServing):
split_char(obj:`str`, defaults to '
\002
'): The char used to split input tokens in token-cls task.
batch_size(obj:`int`, defaults to 1): The number of batch.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.
return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.
Returns:
results(obj:`list`): All the predictions labels.
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录