Unverified · Commit 6bbe4dea authored by KP, committed by GitHub

Add text processing of token-cls task in predict method (#1199)

* Reset metric value at training and validation step

* Add text processing of token-cls task in predict method
Parent 5c4f756f
@@ -23,7 +23,7 @@ from paddlehub.env import DATA_HOME
 from paddlehub.text.bert_tokenizer import BertTokenizer
 from paddlehub.text.tokenizer import CustomTokenizer
 from paddlehub.utils.log import logger
-from paddlehub.utils.utils import download
+from paddlehub.utils.utils import download, reseg_token_label
 from paddlehub.utils.xarfile import is_xarfile, unarchive
@@ -309,7 +309,8 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
         """
         records = []
         for example in examples:
-            tokens, labels = self._reseg_token_label(
+            tokens, labels = reseg_token_label(
+                tokenizer=self.tokenizer,
                 tokens=example.text_a.split(self.split_char),
                 labels=example.label.split(self.split_char))
             record = self.tokenizer.encode(
@@ -339,42 +340,6 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
             records.append(record)
         return records
 
-    def _reseg_token_label(
-            self, tokens: List[str], labels: List[str] = None) -> Tuple[List[str], List[str]] or List[str]:
-        if labels:
-            if len(tokens) != len(labels):
-                raise ValueError(
-                    "The length of tokens must be same with labels")
-            ret_tokens = []
-            ret_labels = []
-            for token, label in zip(tokens, labels):
-                sub_token = self.tokenizer(token)
-                if len(sub_token) == 0:
-                    continue
-                ret_tokens.extend(sub_token)
-                ret_labels.append(label)
-                if len(sub_token) < 2:
-                    continue
-                sub_label = label
-                if label.startswith("B-"):
-                    sub_label = "I-" + label[2:]
-                ret_labels.extend([sub_label] * (len(sub_token) - 1))
-            if len(ret_tokens) != len(ret_labels):
-                raise ValueError(
-                    "The length of ret_tokens can't match with labels")
-            return ret_tokens, ret_labels
-        else:
-            ret_tokens = []
-            for token in tokens:
-                sub_token = self.tokenizer(token)
-                if len(sub_token) == 0:
-                    continue
-                ret_tokens.extend(sub_token)
-                if len(sub_token) < 2:
-                    continue
-            return ret_tokens, None
-
     def __getitem__(self, idx):
         record = self.records[idx]
         if 'label' in record.keys():
...
@@ -30,6 +30,7 @@ from paddle.utils.download import get_path_from_url
 from paddlehub.module.module import serving, RunModule, runnable
 from paddlehub.utils.log import logger
+from paddlehub.utils.utils import reseg_token_label
 
 __all__ = [
     'PretrainedModel',
@@ -411,8 +412,12 @@ class TransformerModule(RunModule, TextServing):
         'token-cls',
     ]
 
-    def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int):
+    def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int, split_char: str):
         pad_to_max_seq_len = False if self.task is None else True
+        if self.task == 'token-cls':  # Extra processing of token-cls task
+            tokens = text[0].split(split_char)
+            text[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
         if len(text) == 1:
             encoded_inputs = tokenizer.encode(text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
         elif len(text) == 2:
@@ -422,7 +427,7 @@ class TransformerModule(RunModule, TextServing):
                 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text))
         return encoded_inputs
 
-    def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
+    def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
         def _parse_batch(batch):
             input_ids = [entry[0] for entry in batch]
             segment_ids = [entry[1] for entry in batch]
@@ -431,7 +436,7 @@ class TransformerModule(RunModule, TextServing):
         tokenizer = self.get_tokenizer()
         examples = []
         for text in data:
-            encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len)
+            encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len, split_char)
             examples.append((encoded_inputs['input_ids'], encoded_inputs['segment_ids']))
 
         # Seperates data into some batches.
@@ -459,6 +464,7 @@ class TransformerModule(RunModule, TextServing):
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+        self.metric.reset()
         return {'loss': avg_loss, 'metrics': metric}
 
     def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
@@ -475,6 +481,7 @@ class TransformerModule(RunModule, TextServing):
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+        self.metric.reset()
         return {'metrics': metric}
 
     def get_embedding(self, data: List[List[str]], use_gpu=False):
@@ -499,6 +506,7 @@ class TransformerModule(RunModule, TextServing):
             self,
             data: List[List[str]],
             max_seq_len: int = 128,
+            split_char: str = '\002',
             batch_size: int = 1,
             use_gpu: bool = False
     ):
@@ -509,6 +517,7 @@ class TransformerModule(RunModule, TextServing):
             data (obj:`List(List(str))`): The processed data whose each element is the list of a single text or a pair of texts.
             max_seq_len (:obj:`int`, `optional`, defaults to :int:`None`):
                 If set to a number, will limit the total sequence returned so that it has a maximum length.
+            split_char(obj:`str`, defaults to '\002'): The char used to split input tokens in token-cls task.
             batch_size(obj:`int`, defaults to 1): The number of batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use gpu to run or not.
@@ -526,7 +535,7 @@ class TransformerModule(RunModule, TextServing):
         paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
 
-        batches = self._batchify(data, max_seq_len, batch_size)
+        batches = self._batchify(data, max_seq_len, batch_size, split_char)
         results = []
         self.eval()
         for batch in batches:
...
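Below is a minimal usage sketch of how the new `split_char` argument reaches `predict()` for a token-cls model. The module name, label map, and input text are illustrative assumptions, not part of this commit:

    import paddlehub as hub

    # Hypothetical token-cls module and label map, for illustration only.
    model = hub.Module(name='ernie_tiny', task='token-cls', label_map={0: 'B-PER', 1: 'I-PER', 2: 'O'})

    # Each sample is a single pre-segmented text whose tokens are joined by split_char ('\002' by default).
    data = [['this\002is\002a\002test']]
    results = model.predict(data, max_seq_len=128, split_char='\002', batch_size=1, use_gpu=False)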
@@ -27,7 +27,7 @@ import time
 import tempfile
 import traceback
 import types
-from typing import Generator
+from typing import Generator, List
 from urllib.parse import urlparse
 
 import numpy as np
@@ -327,3 +327,43 @@ def mkdir(path: str):
     """The same as the shell command `mkdir -p`."""
     if not os.path.exists(path):
         os.makedirs(path)
+
+
+def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
+    '''
+    Convert segments and labels of sequence labeling samples into tokens
+    based on the vocab of tokenizer.
+    '''
+    if labels:
+        if len(tokens) != len(labels):
+            raise ValueError(
+                "The length of tokens must be same with labels")
+        ret_tokens = []
+        ret_labels = []
+        for token, label in zip(tokens, labels):
+            sub_token = tokenizer(token)
+            if len(sub_token) == 0:
+                continue
+            ret_tokens.extend(sub_token)
+            ret_labels.append(label)
+            if len(sub_token) < 2:
+                continue
+            sub_label = label
+            if label.startswith("B-"):
+                sub_label = "I-" + label[2:]
+            ret_labels.extend([sub_label] * (len(sub_token) - 1))
+        if len(ret_tokens) != len(ret_labels):
+            raise ValueError(
+                "The length of ret_tokens can't match with labels")
+        return ret_tokens, ret_labels
+    else:
+        ret_tokens = []
+        for token in tokens:
+            sub_token = tokenizer(token)
+            if len(sub_token) == 0:
+                continue
+            ret_tokens.extend(sub_token)
+            if len(sub_token) < 2:
+                continue
+        return ret_tokens, None
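A small, self-contained sketch of how `reseg_token_label` spreads a "B-" label over word pieces. The stub tokenizer below is hypothetical and stands in for a real wordpiece tokenizer; the import assumes a PaddleHub build that includes this commit:

    from typing import List

    from paddlehub.utils.utils import reseg_token_label  # available after this commit

    def fake_tokenizer(token: str) -> List[str]:
        # Stand-in for a wordpiece tokenizer, for illustration only.
        pieces = {'Washington': ['Wash', '##ing', '##ton']}
        return pieces.get(token, [token])

    tokens, labels = reseg_token_label(
        tokenizer=fake_tokenizer,
        tokens=['Washington', 'is', 'nice'],
        labels=['B-LOC', 'O', 'O'])
    # tokens -> ['Wash', '##ing', '##ton', 'is', 'nice']
    # labels -> ['B-LOC', 'I-LOC', 'I-LOC', 'O', 'O']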