Unverified commit 673a3cf7, authored by 骑马小猫, committed by GitHub

Add return_prob in text classification module

Parent: 9c1fb388
@@ -124,8 +124,15 @@ please add WeChat above and send "Hub" to the robot, the robot will invite you t
 ## QuickStart
 ```python
-!pip install --upgrade paddlepaddle
-!pip install --upgrade paddlehub
+# install paddlepaddle with gpu
+# !pip install --upgrade paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
+# or install paddlepaddle with cpu
+!pip install --upgrade paddlepaddle -i https://mirror.baidu.com/pypi/simple
+# install paddlehub
+!pip install --upgrade paddlehub -i https://mirror.baidu.com/pypi/simple
 import paddlehub as hub
...
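As a quick sanity check after running the install commands above, the following sketch (not part of this commit) uses PaddlePaddle's built-in `paddle.utils.run_check()` to confirm the install works on CPU or GPU:

```python
# Minimal smoke test for the installs above (a sketch, not part of the commit).
import paddle
import paddlehub as hub

paddle.utils.run_check()   # reports whether PaddlePaddle is installed correctly
print(paddle.__version__)  # confirm the paddlepaddle version
print(hub.__version__)     # confirm the paddlehub version
```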
@@ -28,6 +28,6 @@ if __name__ == '__main__':
         task='seq-cls',
         load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams',
         label_map=label_map)
-    results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+    results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)
     for idx, text in enumerate(data):
         print('Data: {} \t Lable: {}'.format(text[0], results[idx]))
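With `return_prob=True`, `predict` now also returns per-sample probability lists, index-aligned with `results`. A hedged extension of the demo above (the extra `Confidence` column is illustrative, not part of the commit):

```python
# probs[idx] is the per-class probability list for sample idx, so the
# confidence of the predicted label is simply its maximum entry.
results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)
for idx, text in enumerate(data):
    print('Data: {} \t Label: {} \t Confidence: {:.4f}'.format(
        text[0], results[idx], max(probs[idx])))
```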
@@ -18,7 +18,7 @@ import io
 import json
 import os
 import six
-from typing import List, Tuple
+from typing import List, Tuple, Union
 import paddle
 import paddle.nn as nn
@@ -552,7 +552,8 @@ class TransformerModule(RunModule, TextServing):
                 max_seq_len: int = 128,
                 split_char: str = '\002',
                 batch_size: int = 1,
-                use_gpu: bool = False):
+                use_gpu: bool = False,
+                return_prob: bool = False):
         """
         Predicts the data labels.
@@ -563,6 +564,7 @@ class TransformerModule(RunModule, TextServing):
             split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
             batch_size(obj:`int`, defaults to 1): The number of batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
+            return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.

         Returns:
             results(obj:`list`): All the predictions labels.
@@ -579,6 +581,8 @@ class TransformerModule(RunModule, TextServing):
         batches = self._batchify(data, max_seq_len, batch_size, split_char)
         results = []
+        batch_probs = []
         self.eval()
         for batch in batches:
             if self.task == 'text-matching':
@@ -589,32 +593,38 @@ class TransformerModule(RunModule, TextServing):
                 title_segment_ids = paddle.to_tensor(title_segment_ids)
                 probs = self(query_input_ids=query_input_ids, query_token_type_ids=query_segment_ids, \
                              title_input_ids=title_input_ids, title_token_type_ids=title_segment_ids)
                 idx = paddle.argmax(probs, axis=1).numpy()
                 idx = idx.tolist()
                 labels = [self.label_map[i] for i in idx]
-                results.extend(labels)
             else:
                 input_ids, segment_ids = batch
                 input_ids = paddle.to_tensor(input_ids)
                 segment_ids = paddle.to_tensor(segment_ids)
                 if self.task == 'seq-cls':
                     probs = self(input_ids, segment_ids)
                     idx = paddle.argmax(probs, axis=1).numpy()
                     idx = idx.tolist()
                     labels = [self.label_map[i] for i in idx]
-                    results.extend(labels)
                 elif self.task == 'token-cls':
                     probs = self(input_ids, segment_ids)
                     batch_ids = paddle.argmax(probs, axis=2).numpy()  # (batch_size, max_seq_len)
                     batch_ids = batch_ids.tolist()
-                    token_labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
-                    results.extend(token_labels)
+                    # token labels
+                    labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
                 elif self.task == None:
                     sequence_output, pooled_output = self(input_ids, segment_ids)
                     results.append(
                         [pooled_output.squeeze(0).numpy().tolist(),
                          sequence_output.squeeze(0).numpy().tolist()])
+            if self.task:
+                # save probs only when return prob
+                if return_prob:
+                    batch_probs.extend(probs.numpy().tolist())
+                results.extend(labels)
+        if self.task and return_prob:
+            return results, batch_probs
         return results
...
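Because `batch_probs` is extended in the same loop iteration that extends `results`, the two lists stay index-aligned across batches. For `token-cls`, each returned probability entry is a nested `max_seq_len × num_classes` list (inferred from the `argmax(probs, axis=2)` above); a hedged usage sketch:

```python
# Per-token confidences for a token-cls module (sketch; shapes inferred
# from the diff, not verified against every model).
token_labels, token_probs = model.predict(data, return_prob=True)
for labels_per_text, dists_per_text in zip(token_labels, token_probs):
    confidences = [max(dist) for dist in dists_per_text]  # one confidence per token
    print(list(zip(labels_per_text, confidences)))
```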