未验证 提交 c8d399db 编写于 作者: K KP 提交者: GitHub

Cache tokenizer in TransformerModule (#1491)

上级 8c304a76
......@@ -38,7 +38,7 @@ from .decode import beam_search_infilling
type="nlp/text_generation",
)
class ErnieGen(hub.NLPPredictionModule):
def _initialize(self):
def __init__(self):
"""
initialize with the necessary elements
"""
......@@ -66,6 +66,8 @@ class ErnieGen(hub.NLPPredictionModule):
Returns:
results(list): the predict result.
"""
paddle.disable_static()
if texts and isinstance(texts, list) and all(texts) and all(
[isinstance(text, str) for text in texts]):
predicted_data = texts
......@@ -79,7 +81,9 @@ class ErnieGen(hub.NLPPredictionModule):
logger.warning(
"use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
)
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
self.model.eval()
results = []
for text in predicted_data:
......@@ -155,4 +159,4 @@ class ErnieGen(hub.NLPPredictionModule):
results = self.generate(
texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
return results
\ No newline at end of file
return results
......@@ -91,7 +91,6 @@ class InitTrackerMeta(type(nn.Layer)):
help_func (callable, optional): If provided, it would be hooked after
`init_func` and called as `_wrap_init(self, init_func, *init_args, **init_kwargs)`.
Default None.
Returns:
function: the wrapped function
"""
......@@ -142,7 +141,6 @@ class PretrainedModel(nn.Layer):
- `pretrained_init_configuration` (dict): The dict has pretrained model names
as keys, and the values are also dict preserving corresponding configuration
for model initialization.
- `base_model_prefix` (str): represents the attribute associated to the
base model in derived classes of the same architecture adding layers on
top of the base model.
......@@ -365,14 +363,12 @@ class TextServing(object):
1. seq-cls: sequence classification;
2. token-cls: sequence labeling;
3. None: embedding.
Args:
data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
max_seq_len (:obj:`int`, `optional`, defaults to 128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
batch_size(obj:`int`, defaults to 1): The number of batch.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
Returns:
results(obj:`list`): All the predictions labels.
"""
......@@ -465,11 +461,12 @@ class TransformerModule(RunModule, TextServing):
title_segment_ids = [entry[3] for entry in batch]
return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids
tokenizer = self.get_tokenizer()
if not hasattr(self, 'tokenizer'):
self.tokenizer = self.get_tokenizer()
examples = []
for texts in data:
encoded_inputs = self._convert_text_to_input(tokenizer, texts, max_seq_len, split_char)
encoded_inputs = self._convert_text_to_input(self.tokenizer, texts, max_seq_len, split_char)
example = []
for inp in encoded_inputs:
input_ids = inp['input_ids']
......@@ -538,7 +535,6 @@ class TransformerModule(RunModule, TextServing):
Args:
data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
Returns:
results(obj:`list`): All the tokens and sentences embeddings.
"""
......@@ -556,7 +552,6 @@ class TransformerModule(RunModule, TextServing):
return_prob: bool = False):
"""
Predicts the data labels.
Args:
data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
max_seq_len (:obj:`int`, `optional`, defaults to :int:`None`):
......@@ -564,8 +559,7 @@ class TransformerModule(RunModule, TextServing):
split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
batch_size(obj:`int`, defaults to 1): The number of batch.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.
return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.
Returns:
results(obj:`list`): All the predictions labels.
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册