未验证 提交 6bbe4dea 编写于 作者: K KP 提交者: GitHub

Add text processing of token-cls task in predict method (#1199)

* Reset metric value at training and validation step

* Add text processing of token-cls task in predict method
上级 5c4f756f
......@@ -23,7 +23,7 @@ from paddlehub.env import DATA_HOME
from paddlehub.text.bert_tokenizer import BertTokenizer
from paddlehub.text.tokenizer import CustomTokenizer
from paddlehub.utils.log import logger
from paddlehub.utils.utils import download
from paddlehub.utils.utils import download, reseg_token_label
from paddlehub.utils.xarfile import is_xarfile, unarchive
......@@ -309,7 +309,8 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
"""
records = []
for example in examples:
tokens, labels = self._reseg_token_label(
tokens, labels = reseg_token_label(
tokenizer=self.tokenizer,
tokens=example.text_a.split(self.split_char),
labels=example.label.split(self.split_char))
record = self.tokenizer.encode(
......@@ -339,42 +340,6 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
records.append(record)
return records
def _reseg_token_label(
self, tokens: List[str], labels: List[str] = None) -> Tuple[List[str], List[str]] or List[str]:
if labels:
if len(tokens) != len(labels):
raise ValueError(
"The length of tokens must be same with labels")
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = self.tokenizer(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
if len(ret_tokens) != len(ret_labels):
raise ValueError(
"The length of ret_tokens can't match with labels")
return ret_tokens, ret_labels
else:
ret_tokens = []
for token in tokens:
sub_token = self.tokenizer(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
if len(sub_token) < 2:
continue
return ret_tokens, None
def __getitem__(self, idx):
record = self.records[idx]
if 'label' in record.keys():
......
......@@ -30,6 +30,7 @@ from paddle.utils.download import get_path_from_url
from paddlehub.module.module import serving, RunModule, runnable
from paddlehub.utils.log import logger
from paddlehub.utils.utils import reseg_token_label
__all__ = [
'PretrainedModel',
......@@ -411,8 +412,12 @@ class TransformerModule(RunModule, TextServing):
'token-cls',
]
def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int):
def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int, split_char: str):
pad_to_max_seq_len = False if self.task is None else True
if self.task == 'token-cls': # Extra processing of token-cls task
tokens = text[0].split(split_char)
text[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
if len(text) == 1:
encoded_inputs = tokenizer.encode(text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
elif len(text) == 2:
......@@ -422,7 +427,7 @@ class TransformerModule(RunModule, TextServing):
'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text))
return encoded_inputs
def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
def _parse_batch(batch):
input_ids = [entry[0] for entry in batch]
segment_ids = [entry[1] for entry in batch]
......@@ -431,7 +436,7 @@ class TransformerModule(RunModule, TextServing):
tokenizer = self.get_tokenizer()
examples = []
for text in data:
encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len)
encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len, split_char)
examples.append((encoded_inputs['input_ids'], encoded_inputs['segment_ids']))
# Separates data into some batches.
......@@ -459,6 +464,7 @@ class TransformerModule(RunModule, TextServing):
predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
elif self.task == 'token-cls':
predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
self.metric.reset()
return {'loss': avg_loss, 'metrics': metric}
def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
......@@ -475,6 +481,7 @@ class TransformerModule(RunModule, TextServing):
predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
elif self.task == 'token-cls':
predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
self.metric.reset()
return {'metrics': metric}
def get_embedding(self, data: List[List[str]], use_gpu=False):
......@@ -499,6 +506,7 @@ class TransformerModule(RunModule, TextServing):
self,
data: List[List[str]],
max_seq_len: int = 128,
split_char: str = '\002',
batch_size: int = 1,
use_gpu: bool = False
):
......@@ -509,6 +517,7 @@ class TransformerModule(RunModule, TextServing):
data (obj:`List(List(str))`): The processed data, each element of which is a list holding a single text or a pair of texts.
max_seq_len (:obj:`int`, `optional`, defaults to 128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
batch_size(obj:`int`, defaults to 1): The batch size.
use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
......@@ -526,7 +535,7 @@ class TransformerModule(RunModule, TextServing):
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
batches = self._batchify(data, max_seq_len, batch_size)
batches = self._batchify(data, max_seq_len, batch_size, split_char)
results = []
self.eval()
for batch in batches:
......
......@@ -27,7 +27,7 @@ import time
import tempfile
import traceback
import types
from typing import Generator
from typing import Generator, List
from urllib.parse import urlparse
import numpy as np
......@@ -327,3 +327,43 @@ def mkdir(path: str):
"""The same as the shell command `mkdir -p`."""
if not os.path.exists(path):
os.makedirs(path)
def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
    '''
    Convert segments and labels of sequence labeling samples into tokens
    based on the vocab of tokenizer.

    Args:
        tokenizer: A callable that is invoked with a single token string and
            returns a sized collection of sub-tokens for it.
        tokens: The original tokens of the sample.
        labels: One label per entry of ``tokens``, or ``None`` when the sample
            carries no labels.

    Returns:
        A ``(ret_tokens, ret_labels)`` pair. ``ret_tokens`` is the flat list of
        sub-tokens. ``ret_labels`` is the aligned list of labels, or ``None``
        when ``labels`` is ``None``.

    Raises:
        ValueError: If ``labels`` is given and its length differs from the
            length of ``tokens``.
    '''
    # Compare against None rather than relying on truthiness: an empty label
    # list is still "labels were supplied" and must pass the length check.
    has_labels = labels is not None
    if has_labels and len(tokens) != len(labels):
        raise ValueError(
            "The length of tokens must be same with labels")

    ret_tokens = []
    ret_labels = [] if has_labels else None
    for index, token in enumerate(tokens):
        sub_token = tokenizer(token)
        # Tokens for which the tokenizer yields nothing are dropped, together
        # with their label.
        if len(sub_token) == 0:
            continue
        ret_tokens.extend(sub_token)
        if not has_labels:
            continue

        label = labels[index]
        ret_labels.append(label)
        # Only the first sub-token keeps a "B-" label; the remaining pieces of
        # the same token get the matching "I-" label. Other labels are
        # repeated unchanged.
        sub_label = "I-" + label[2:] if label.startswith("B-") else label
        ret_labels.extend([sub_label] * (len(sub_token) - 1))
    return ret_tokens, ret_labels
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册