Unverified commit c8d399db authored by: K KP committed by: GitHub

Cache tokenizer in TransformerModule (#1491)

Parent 8c304a76
@@ -38,7 +38,7 @@ from .decode import beam_search_infilling
     type="nlp/text_generation",
 )
 class ErnieGen(hub.NLPPredictionModule):
-    def _initialize(self):
+    def __init__(self):
         """
         initialize with the necessary elements
         """
@@ -66,6 +66,8 @@ class ErnieGen(hub.NLPPredictionModule):
         Returns:
             results(list): the predict result.
         """
+        paddle.disable_static()
+
         if texts and isinstance(texts, list) and all(texts) and all(
                 [isinstance(text, str) for text in texts]):
             predicted_data = texts
@@ -79,7 +81,9 @@ class ErnieGen(hub.NLPPredictionModule):
                 logger.warning(
                     "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
                 )
+
         paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
+
         self.model.eval()
         results = []
         for text in predicted_data:
@@ -155,4 +159,4 @@ class ErnieGen(hub.NLPPredictionModule):
         results = self.generate(
             texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
-        return results
\ No newline at end of file
+        return results
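In effect, the module now switches paddle into dynamic-graph mode and selects the device at call time instead of at initialization. A minimal standalone sketch of that setup pattern, assuming paddle 2.x; the helper name is ours, not the module's:

import os

import paddle

def setup_runtime(use_gpu: bool) -> None:
    # Dynamic-graph (imperative) mode, as the rewritten generate() expects.
    paddle.disable_static()
    if use_gpu and os.environ.get('CUDA_VISIBLE_DEVICES') is None:
        # Mirrors the module's warning path: degrade to CPU when no GPU is visible.
        use_gpu = False
    paddle.set_device('gpu' if use_gpu else 'cpu')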
@@ -91,7 +91,6 @@ class InitTrackerMeta(type(nn.Layer)):
             help_func (callable, optional): If provided, it would be hooked after
                 `init_func` and called as `_wrap_init(self, init_func, *init_args, **init_args)`.
                 Default None.
-
         Returns:
             function: the wrapped function
         """
@@ -142,7 +141,6 @@ class PretrainedModel(nn.Layer):
     - `pretrained_init_configuration` (dict): The dict has pretrained model names
       as keys, and the values are also dict preserving corresponding configuration
       for model initialization.
-
     - `base_model_prefix` (str): represents the attribute associated to the
       base model in derived classes of the same architecture adding layers on
       top of the base model.
@@ -365,14 +363,12 @@ class TextServing(object):
         1. seq-cls: sequence classification;
         2. token-cls: sequence labeling;
         3. None: embedding.
-
         Args:
             data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
             max_seq_len (:obj:`int`, `optional`, defaults to 128):
                 If set to a number, will limit the total sequence returned so that it has a maximum length.
             batch_size(obj:`int`, defaults to 1): The number of batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
-
         Returns:
             results(obj:`list`): All the predictions labels.
         """
@@ -465,11 +461,12 @@ class TransformerModule(RunModule, TextServing):
             title_segment_ids = [entry[3] for entry in batch]
             return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids

-        tokenizer = self.get_tokenizer()
+        if not hasattr(self, 'tokenizer'):
+            self.tokenizer = self.get_tokenizer()
         examples = []
         for texts in data:
-            encoded_inputs = self._convert_text_to_input(tokenizer, texts, max_seq_len, split_char)
+            encoded_inputs = self._convert_text_to_input(self.tokenizer, texts, max_seq_len, split_char)
             example = []
             for inp in encoded_inputs:
                 input_ids = inp['input_ids']
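This hunk is the commit's core change: a lazy-initialization cache. The tokenizer is built once on first use and kept on the instance, so repeated prediction calls stop paying the construction cost. A self-contained sketch of the same pattern; the class and method names here are illustrative, not the module's:

class Predictor:
    def get_tokenizer(self):
        # Stand-in for an expensive constructor (loading vocab files, etc.).
        print('building tokenizer')
        return str.split  # any callable tokenizer would do here

    def predict(self, texts):
        # Build the tokenizer on the first call only; later calls reuse it.
        if not hasattr(self, 'tokenizer'):
            self.tokenizer = self.get_tokenizer()
        return [self.tokenizer(t) for t in texts]

p = Predictor()
p.predict(['a b', 'c d'])   # prints 'building tokenizer' once
p.predict(['e f'])          # cache hit: no rebuild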
@@ -538,7 +535,6 @@ class TransformerModule(RunModule, TextServing):
         Args:
             data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
-
         Returns:
             results(obj:`list`): All the tokens and sentences embeddings.
         """
@@ -556,7 +552,6 @@ class TransformerModule(RunModule, TextServing):
                 return_prob: bool = False):
         """
         Predicts the data labels.
-
         Args:
             data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
             max_seq_len (:obj:`int`, `optional`, defaults to :int:`None`):
@@ -564,8 +559,7 @@ class TransformerModule(RunModule, TextServing):
             split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
             batch_size(obj:`int`, defaults to 1): The number of batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
             return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.
-
         Returns:
             results(obj:`list`): All the predictions labels.
         """
...
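With the cache in place, repeated calls on one module instance reuse a single tokenizer instead of rebuilding it per call. A hedged usage sketch: loading via hub.Module(name='ernie_gen') is our assumption, while generate() and its texts/use_gpu/beam_width parameters come from the diff above:

import paddlehub as hub

# Module name assumed for illustration; generate() and its arguments
# appear in the ErnieGen diff above.
module = hub.Module(name='ernie_gen')
results = module.generate(texts=['今天天气不错'], use_gpu=False, beam_width=5)
print(results)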