Unverified commit 673a3cf7, authored by 骑马小猫, committed by GitHub

Add return_prob in text classification module

Parent 9c1fb388
@@ -124,8 +124,15 @@ please add WeChat above and send "Hub" to the robot, the robot will invite you t
 ## QuickStart
 ```python
-!pip install --upgrade paddlepaddle
-!pip install --upgrade paddlehub
+# install paddlepaddle with gpu
+# !pip install --upgrade paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
+# or install paddlepaddle with cpu
+!pip install --upgrade paddlepaddle -i https://mirror.baidu.com/pypi/simple
+# install paddlehub
+!pip install --upgrade paddlehub -i https://mirror.baidu.com/pypi/simple
 import paddlehub as hub
......
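With the dependencies installed, the README's next step is loading a module. A minimal sketch, assuming a PaddleHub 2.x environment; the `ernie` module name and `seq-cls` task follow the demo script updated later in this commit:

```python
import paddlehub as hub

# Assumption: PaddleHub 2.x with the 'ernie' module available; the task
# name 'seq-cls' mirrors the text-classification demo in this commit.
model = hub.Module(name='ernie', task='seq-cls')
```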
@@ -28,6 +28,6 @@ if __name__ == '__main__':
         task='seq-cls',
         load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams',
         label_map=label_map)
-    results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)
+    results, probs = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False, return_prob=True)
     for idx, text in enumerate(data):
         print('Data: {} \t Label: {}'.format(text[0], results[idx]))
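Since `predict` now also returns probabilities, the demo's print loop could surface a confidence score as well. A small sketch, assuming `probs[idx]` holds the per-class probability list for sample `idx`, as the implementation below suggests:

```python
for idx, text in enumerate(data):
    # probs[idx] is assumed to be a list of per-class probabilities;
    # the predicted label's confidence is its maximum entry.
    print('Data: {}\tLabel: {}\tConfidence: {:.4f}'.format(
        text[0], results[idx], max(probs[idx])))
```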
@@ -18,7 +18,7 @@ import io
 import json
 import os
 import six
-from typing import List, Tuple
+from typing import List, Tuple, Union
 import paddle
 import paddle.nn as nn
@@ -552,7 +552,8 @@ class TransformerModule(RunModule, TextServing):
                 max_seq_len: int = 128,
                 split_char: str = '\002',
                 batch_size: int = 1,
-                use_gpu: bool = False):
+                use_gpu: bool = False,
+                return_prob: bool = False):
         """
         Predicts the data labels.
@@ -563,6 +564,7 @@
             split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
             batch_size(obj:`int`, defaults to 1): The number of batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
+            return_prob(obj:`bool`, defaults to `False`): Whether to return label probabilities.
         Returns:
             results(obj:`list`): All the predictions labels.
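The new flag changes the shape of the return value. A sketch of the two call forms implied by this diff, with variable names following the demo script above:

```python
# Default behaviour is unchanged: only the labels come back.
labels = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=False)

# With return_prob=True, a (labels, probabilities) pair is returned instead.
labels, probs = model.predict(data, max_seq_len=50, batch_size=1,
                              use_gpu=False, return_prob=True)
```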
@@ -579,6 +581,8 @@
         batches = self._batchify(data, max_seq_len, batch_size, split_char)
         results = []
+        batch_probs = []
         self.eval()
         for batch in batches:
             if self.task == 'text-matching':
@@ -589,32 +593,38 @@
                 title_segment_ids = paddle.to_tensor(title_segment_ids)
                 probs = self(query_input_ids=query_input_ids, query_token_type_ids=query_segment_ids, \
                              title_input_ids=title_input_ids, title_token_type_ids=title_segment_ids)
                 idx = paddle.argmax(probs, axis=1).numpy()
                 idx = idx.tolist()
                 labels = [self.label_map[i] for i in idx]
-                results.extend(labels)
             else:
                 input_ids, segment_ids = batch
                 input_ids = paddle.to_tensor(input_ids)
                 segment_ids = paddle.to_tensor(segment_ids)
                 if self.task == 'seq-cls':
                     probs = self(input_ids, segment_ids)
                     idx = paddle.argmax(probs, axis=1).numpy()
                     idx = idx.tolist()
                     labels = [self.label_map[i] for i in idx]
-                    results.extend(labels)
                 elif self.task == 'token-cls':
                     probs = self(input_ids, segment_ids)
                     batch_ids = paddle.argmax(probs, axis=2).numpy()  # (batch_size, max_seq_len)
                     batch_ids = batch_ids.tolist()
-                    token_labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
-                    results.extend(token_labels)
+                    # token labels
+                    labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
                 elif self.task == None:
                     sequence_output, pooled_output = self(input_ids, segment_ids)
                     results.append(
                         [pooled_output.squeeze(0).numpy().tolist(),
                          sequence_output.squeeze(0).numpy().tolist()])
+            if self.task:
+                # save probs only when return_prob is True
+                if return_prob:
+                    batch_probs.extend(probs.numpy().tolist())
+                results.extend(labels)
+        if self.task and return_prob:
+            return results, batch_probs
         return results
......
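Because the labels are computed with `paddle.argmax` over the same `probs` that get collected into `batch_probs`, a caller can sanity-check the pairing. A hypothetical cross-check for the `seq-cls` case, assuming `label_map` maps class indices to label strings as in the demo:

```python
import numpy as np

def check_consistency(results, probs, label_map):
    # For seq-cls, each returned label should equal label_map[argmax(prob_vec)],
    # mirroring the paddle.argmax(probs, axis=1) call inside predict().
    for label, prob_vec in zip(results, probs):
        assert label == label_map[int(np.argmax(prob_vec))]
```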