From 6bbe4dea9c1d02724ab8493bd916abf985996c17 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Wed, 20 Jan 2021 11:28:37 +0800
Subject: [PATCH] Add text processing of token-cls task in predict method
 (#1199)

* Reset metric value at training and validation step

* Add text processing of token-cls task in predict method
---
 paddlehub/datasets/base_nlp_dataset.py | 41 ++-----------------------
 paddlehub/module/nlp_module.py         | 17 ++++++++---
 paddlehub/utils/utils.py               | 42 +++++++++++++++++++++++++-
 3 files changed, 57 insertions(+), 43 deletions(-)

diff --git a/paddlehub/datasets/base_nlp_dataset.py b/paddlehub/datasets/base_nlp_dataset.py
index 1c9ae13a..bee1aa4a 100644
--- a/paddlehub/datasets/base_nlp_dataset.py
+++ b/paddlehub/datasets/base_nlp_dataset.py
@@ -23,7 +23,7 @@ from paddlehub.env import DATA_HOME
 from paddlehub.text.bert_tokenizer import BertTokenizer
 from paddlehub.text.tokenizer import CustomTokenizer
 from paddlehub.utils.log import logger
-from paddlehub.utils.utils import download
+from paddlehub.utils.utils import download, reseg_token_label
 from paddlehub.utils.xarfile import is_xarfile, unarchive
 
 
@@ -309,7 +309,8 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
         """
         records = []
         for example in examples:
-            tokens, labels = self._reseg_token_label(
+            tokens, labels = reseg_token_label(
+                tokenizer=self.tokenizer,
                 tokens=example.text_a.split(self.split_char),
                 labels=example.label.split(self.split_char))
             record = self.tokenizer.encode(
@@ -339,42 +340,6 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
             records.append(record)
         return records
 
-    def _reseg_token_label(
-            self, tokens: List[str], labels: List[str] = None) -> Tuple[List[str], List[str]] or List[str]:
-        if labels:
-            if len(tokens) != len(labels):
-                raise ValueError(
-                    "The length of tokens must be same with labels")
-            ret_tokens = []
-            ret_labels = []
-            for token, label in zip(tokens, labels):
-                sub_token = self.tokenizer(token)
-                if len(sub_token) == 0:
-                    continue
-                ret_tokens.extend(sub_token)
-                ret_labels.append(label)
-                if len(sub_token) < 2:
-                    continue
-                sub_label = label
-                if label.startswith("B-"):
-                    sub_label = "I-" + label[2:]
-                ret_labels.extend([sub_label] * (len(sub_token) - 1))
-
-            if len(ret_tokens) != len(ret_labels):
-                raise ValueError(
-                    "The length of ret_tokens can't match with labels")
-            return ret_tokens, ret_labels
-        else:
-            ret_tokens = []
-            for token in tokens:
-                sub_token = self.tokenizer(token)
-                if len(sub_token) == 0:
-                    continue
-                ret_tokens.extend(sub_token)
-                if len(sub_token) < 2:
-                    continue
-            return ret_tokens, None
-
     def __getitem__(self, idx):
         record = self.records[idx]
         if 'label' in record.keys():
diff --git a/paddlehub/module/nlp_module.py b/paddlehub/module/nlp_module.py
index ebebed70..80306ffc 100644
--- a/paddlehub/module/nlp_module.py
+++ b/paddlehub/module/nlp_module.py
@@ -30,6 +30,7 @@ from paddle.utils.download import get_path_from_url
 
 from paddlehub.module.module import serving, RunModule, runnable
 from paddlehub.utils.log import logger
+from paddlehub.utils.utils import reseg_token_label
 
 __all__ = [
     'PretrainedModel',
@@ -411,8 +412,12 @@ class TransformerModule(RunModule, TextServing):
             'token-cls',
         ]
 
-    def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int):
+    def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int, split_char: str):
         pad_to_max_seq_len = False if self.task is None else True
+        if self.task == 'token-cls':  # Extra processing for the token-cls task
+            tokens = text[0].split(split_char)
+            text[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
+
         if len(text) == 1:
             encoded_inputs = tokenizer.encode(text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
         elif len(text) == 2:
@@ -422,7 +427,7 @@ class TransformerModule(RunModule, TextServing):
                 'The input text must have one or two sequences, but got %d. Please check your inputs.' % len(text))
         return encoded_inputs
 
-    def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
+    def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
         def _parse_batch(batch):
             input_ids = [entry[0] for entry in batch]
             segment_ids = [entry[1] for entry in batch]
@@ -431,7 +436,7 @@ class TransformerModule(RunModule, TextServing):
         tokenizer = self.get_tokenizer()
         examples = []
         for text in data:
-            encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len)
+            encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len, split_char)
             examples.append((encoded_inputs['input_ids'], encoded_inputs['segment_ids']))
 
         # Separates data into batches.
@@ -459,6 +464,7 @@ class TransformerModule(RunModule, TextServing):
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+        self.metric.reset()
         return {'loss': avg_loss, 'metrics': metric}
 
     def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
@@ -475,6 +481,7 @@ class TransformerModule(RunModule, TextServing):
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+        self.metric.reset()
         return {'metrics': metric}
 
     def get_embedding(self, data: List[List[str]], use_gpu=False):
@@ -499,6 +506,7 @@ class TransformerModule(RunModule, TextServing):
         self,
         data: List[List[str]],
         max_seq_len: int = 128,
+        split_char: str = '\002',
         batch_size: int = 1,
         use_gpu: bool = False
     ):
@@ -509,6 +517,7 @@ class TransformerModule(RunModule, TextServing):
             data (obj:`List(List(str))`): The processed data, each element of which is a list containing a single text or a pair of texts.
             max_seq_len (:obj:`int`, `optional`, defaults to `128`): If set to a number, will limit the total sequence
                 returned so that it has a maximum length.
+            split_char(obj:`str`, defaults to '\002'): The character used to split input tokens in the token-cls task.
             batch_size(obj:`int`, defaults to 1): The number of samples in each batch.
             use_gpu(obj:`bool`, defaults to `False`): Whether to use the GPU to run or not.
 
@@ -526,7 +535,7 @@ class TransformerModule(RunModule, TextServing):
 
         paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
 
-        batches = self._batchify(data, max_seq_len, batch_size)
+        batches = self._batchify(data, max_seq_len, batch_size, split_char)
         results = []
         self.eval()
         for batch in batches:
diff --git a/paddlehub/utils/utils.py b/paddlehub/utils/utils.py
index 1839b3dd..0aa6337c 100644
--- a/paddlehub/utils/utils.py
+++ b/paddlehub/utils/utils.py
@@ -27,7 +27,7 @@ import time
 import tempfile
 import traceback
 import types
-from typing import Generator
+from typing import Generator, List
 from urllib.parse import urlparse
 
 import numpy as np
@@ -327,3 +327,43 @@ def mkdir(path: str):
     """The same as the shell command `mkdir -p`."""
     if not os.path.exists(path):
         os.makedirs(path)
+
+
+def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
+    '''
+    Re-segment the tokens of a sequence labeling sample with the vocab of the
+    given tokenizer and re-align the labels to the resulting sub-tokens.
+    '''
+    if labels:
+        if len(tokens) != len(labels):
+            raise ValueError(
+                "The length of tokens must be the same as the length of labels")
+        ret_tokens = []
+        ret_labels = []
+        for token, label in zip(tokens, labels):
+            sub_token = tokenizer(token)
+            if len(sub_token) == 0:
+                continue
+            ret_tokens.extend(sub_token)
+            ret_labels.append(label)
+            if len(sub_token) < 2:
+                continue
+            sub_label = label
+            if label.startswith("B-"):
+                sub_label = "I-" + label[2:]
+            ret_labels.extend([sub_label] * (len(sub_token) - 1))
+
+        if len(ret_tokens) != len(ret_labels):
+            raise ValueError(
+                "The length of ret_tokens does not match the length of ret_labels")
+        return ret_tokens, ret_labels
+    else:
+        ret_tokens = []
+        for token in tokens:
+            sub_token = tokenizer(token)
+            if len(sub_token) == 0:
+                continue
+            ret_tokens.extend(sub_token)
+            if len(sub_token) < 2:
+                continue
+        return ret_tokens, None
--
GitLab
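For reference, the re-alignment rule implemented by reseg_token_label can be exercised standalone. A minimal sketch, assuming a toy character-level tokenizer in place of the module's real wordpiece tokenizer; the helper names below are illustrative and not part of the patch:

from typing import Callable, List, Optional, Tuple


def char_tokenizer(token: str) -> List[str]:
    # Toy stand-in for a wordpiece tokenizer: one sub-token per character.
    # Assumption for illustration only.
    return list(token)


def reseg_token_label_demo(tokenizer: Callable[[str], List[str]],
                           tokens: List[str],
                           labels: Optional[List[str]] = None
                           ) -> Tuple[List[str], Optional[List[str]]]:
    # Condensed restatement of the rule in reseg_token_label: each input token
    # is re-tokenized; its label stays on the first sub-token, and a "B-" label
    # degrades to the matching "I-" label on every following sub-token.
    if labels is None:
        return [t for token in tokens for t in tokenizer(token)], None
    if len(tokens) != len(labels):
        raise ValueError("The length of tokens must be the same as the length of labels")
    ret_tokens: List[str] = []
    ret_labels: List[str] = []
    for token, label in zip(tokens, labels):
        sub_token = tokenizer(token)
        if not sub_token:
            continue
        ret_tokens.extend(sub_token)
        ret_labels.append(label)
        sub_label = "I-" + label[2:] if label.startswith("B-") else label
        ret_labels.extend([sub_label] * (len(sub_token) - 1))
    return ret_tokens, ret_labels


tokens, labels = reseg_token_label_demo(
    char_tokenizer, tokens=["Beijing", "is", "big"], labels=["B-LOC", "O", "O"])
print(tokens[:3], labels[:3])  # ['B', 'e', 'i'] ['B-LOC', 'I-LOC', 'I-LOC']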
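With this patch applied, a token-cls prediction call joins the pre-split tokens with split_char and lets predict redo the training-time preprocessing. A usage sketch, assuming the ernie_tiny module and a toy label map in the style of the public sequence labeling demo; both are assumptions here, not part of the patch:

import paddlehub as hub

# Assumed label set and module name, for illustration only.
label_list = ["B-LOC", "I-LOC", "B-PER", "I-PER", "O"]
label_map = {idx: label for idx, label in enumerate(label_list)}
module = hub.Module(name='ernie_tiny', task='token-cls', label_map=label_map)

split_char = '\002'
data = [[split_char.join(['Beijing', 'is', 'big'])]]
# predict() now splits each input on split_char and runs reseg_token_label
# on the pieces before encoding, matching the dataset-side preprocessing.
results = module.predict(data=data, max_seq_len=128, split_char=split_char,
                         batch_size=1, use_gpu=False)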