Unverified commit f4d9b46b, authored by Steffy-zxf, committed by GitHub

Fix the compatibility error caused by the upgrade of PretrainedTokenizer

Parent 71d0cc9d
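At a glance, the fix gates every tokenizer call on the installed paddlenlp version: releases up to 2.0.0rc2 keep the old tokenizer.encode(...) API, newer releases call the tokenizer object directly, and from 2.0.0rc5 the encoded record exposes token_type_ids instead of segment_ids. A minimal standalone sketch of the pattern applied throughout the diff below (the ErnieTokenizer and the 'ernie-1.0' checkpoint are illustrative choices, not part of this commit):

    import paddlenlp
    from packaging.version import Version
    from paddlenlp.transformers import ErnieTokenizer

    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')  # illustrative checkpoint

    if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
        # Old API: encode() returns a dict with a 'segment_ids' entry.
        record = tokenizer.encode(text='hello world', max_seq_len=128)
        token_type_ids = record['segment_ids']
    else:
        # New API: the tokenizer is callable; padding and length reporting are
        # requested explicitly and the key is renamed to 'token_type_ids'.
        record = tokenizer(text='hello world', max_seq_len=128,
                           pad_to_max_seq_len=True, return_length=True)
        token_type_ids = record['token_type_ids']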
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Optional, Union, Tuple
 import csv
 import io
 import os
+from typing import Dict, List, Optional, Union, Tuple
 
 import numpy as np
 import paddle
+import paddlenlp
+from packaging.version import Version
 from paddlehub.env import DATA_HOME
 from paddlenlp.transformers import PretrainedTokenizer
@@ -27,7 +29,6 @@ from paddlehub.utils.utils import download, reseg_token_label, pad_sequence, tru
 from paddlehub.utils.xarfile import is_xarfile, unarchive
 
-
 class InputExample(object):
     """
     The input data structure of Transformer modules (BERT, ERNIE and so on).
@@ -233,7 +234,16 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
         records = []
         for example in examples:
             if isinstance(self.tokenizer, PretrainedTokenizer):
-                record = self.tokenizer.encode(text=example.text_a, text_pair=example.text_b, max_seq_len=self.max_seq_len)
+                if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                    record = self.tokenizer.encode(
+                        text=example.text_a, text_pair=example.text_b, max_seq_len=self.max_seq_len)
+                else:
+                    record = self.tokenizer(
+                        text=example.text_a,
+                        text_pair=example.text_b,
+                        max_seq_len=self.max_seq_len,
+                        pad_to_max_seq_len=True,
+                        return_length=True)
             elif isinstance(self.tokenizer, JiebaTokenizer):
                 pad_token = self.tokenizer.vocab.pad_token
@@ -246,7 +256,9 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
                 ids = pad_sequence(ids, self.max_seq_len, pad_token_id)
                 record = {'text': ids, 'seq_len': seq_len}
             else:
-                raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
+                raise RuntimeError(
+                    "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
+                    .format(type(self.tokenizer)))
 
             if not record:
                 logger.info(
@@ -260,17 +272,26 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
     def __getitem__(self, idx):
         record = self.records[idx]
         if isinstance(self.tokenizer, PretrainedTokenizer):
+            input_ids = np.array(record['input_ids'])
+            if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
+                token_type_ids = np.array(record['token_type_ids'])
+            else:
+                token_type_ids = record['segment_ids']
             if 'label' in record.keys():
-                return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['label'], dtype=np.int64)
+                return input_ids, token_type_ids, np.array(record['label'], dtype=np.int64)
             else:
-                return np.array(record['input_ids']), np.array(record['segment_ids'])
+                return input_ids, token_type_ids
         elif isinstance(self.tokenizer, JiebaTokenizer):
             if 'label' in record.keys():
                 return np.array(record['text']), np.array(record['label'], dtype=np.int64)
             else:
                 return np.array(record['text'])
         else:
-            raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
+            raise RuntimeError(
+                "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".
+                format(type(self.tokenizer)))
 
     def __len__(self):
         return len(self.records)
@@ -303,6 +324,7 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
         is_file_with_header(:obj:bool, `optional`, default to :obj: False) :
             Whether or not the file is with the header introduction.
     """
+
     def __init__(self,
                  base_path: str,
                  tokenizer: Union[PretrainedTokenizer, JiebaTokenizer],
@@ -311,7 +333,7 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
                  data_file: str = None,
                  label_file: str = None,
                  label_list: list = None,
-                 split_char: str ="\002",
+                 split_char: str = "\002",
                  no_entity_label: str = "O",
                  ignore_label: int = -100,
                  is_file_with_header: bool = False):
@@ -365,7 +387,15 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
                 pad_token = self.tokenizer.pad_token
                 tokens, labels = reseg_token_label(tokenizer=self.tokenizer, tokens=tokens, labels=labels)
-                record = self.tokenizer.encode(text=tokens, max_seq_len=self.max_seq_len)
+                if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                    record = self.tokenizer.encode(text=tokens, max_seq_len=self.max_seq_len)
+                else:
+                    record = self.tokenizer(
+                        text=tokens,
+                        max_seq_len=self.max_seq_len,
+                        pad_to_max_seq_len=True,
+                        is_split_into_words=True,
+                        return_length=True)
             elif isinstance(self.tokenizer, JiebaTokenizer):
                 pad_token = self.tokenizer.vocab.pad_token
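Note: return_length=True is what keeps a seq_len entry in the record built by the new API, matching the field the old encode() already produced; __getitem__ below reads record['seq_len'] on both paths.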
@@ -379,12 +409,13 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
                 record = {'text': ids, 'seq_len': seq_len}
             else:
-                raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
+                raise RuntimeError(
+                    "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
+                    .format(type(self.tokenizer)))
 
             if not record:
                 logger.info(
-                    "The text %s has been dropped as it has no words in the vocab after tokenization."
-                    % example.text_a)
+                    "The text %s has been dropped as it has no words in the vocab after tokenization." % example.text_a)
                 continue
 
             # convert labels into record
@@ -395,37 +426,46 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
             elif isinstance(self.tokenizer, JiebaTokenizer):
                 tokens_with_specical_token = [self.tokenizer.vocab.to_tokens(id_) for id_ in record['text']]
             else:
-                raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
+                raise RuntimeError(
+                    "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer"
+                    .format(type(self.tokenizer)))
 
             tokens_index = 0
             for token in tokens_with_specical_token:
-                if tokens_index < len(
-                        tokens) and token == tokens[tokens_index]:
-                    record["label"].append(
-                        self.label_list.index(labels[tokens_index]))
+                if tokens_index < len(tokens) and token == tokens[tokens_index]:
+                    record["label"].append(self.label_list.index(labels[tokens_index]))
                     tokens_index += 1
                 elif token in [pad_token]:
                     record["label"].append(self.ignore_label)  # label of special token
                 else:
-                    record["label"].append(
-                        self.label_list.index(self.no_entity_label))
+                    record["label"].append(self.label_list.index(self.no_entity_label))
             records.append(record)
         return records
     def __getitem__(self, idx):
         record = self.records[idx]
         if isinstance(self.tokenizer, PretrainedTokenizer):
+            input_ids = np.array(record['input_ids'])
+            seq_lens = np.array(record['seq_len'])
+            if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
+                token_type_ids = np.array(record['token_type_ids'])
+            else:
+                token_type_ids = np.array(record['segment_ids'])
             if 'label' in record.keys():
-                return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['seq_len']), np.array(record['label'], dtype=np.int64)
+                return input_ids, token_type_ids, seq_lens, np.array(record['label'], dtype=np.int64)
             else:
-                return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['seq_len'])
+                return input_ids, token_type_ids, seq_lens
         elif isinstance(self.tokenizer, JiebaTokenizer):
             if 'label' in record.keys():
                 return np.array(record['text']), np.array(record['seq_len']), np.array(record['label'], dtype=np.int64)
             else:
                 return np.array(record['text']), np.array(record['seq_len'])
         else:
-            raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
+            raise RuntimeError(
+                "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".
+                format(type(self.tokenizer)))
 
     def __len__(self):
         return len(self.records)
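Both dataset classes resolve the renamed record key the same way when items are read back. Pulled out of the classes for clarity, the lookup amounts to this (a sketch distilled from the diff, not code from the commit):

    import numpy as np
    import paddlenlp
    from packaging.version import Version

    def get_token_type_ids(record: dict) -> np.ndarray:
        # paddlenlp >= 2.0.0rc5 emits 'token_type_ids'; older releases used 'segment_ids'.
        if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
            return np.array(record['token_type_ids'])
        return np.array(record['segment_ids'])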
@@ -11,9 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# FIXME(zhangxuefei): remove this file after paddlenlp is released.
-
 import copy
 import functools
 import inspect
@@ -25,6 +22,7 @@ from typing import List, Tuple
 import paddle
 import paddle.nn as nn
+from packaging.version import Version
 from paddle.dataset.common import DATA_HOME
 from paddle.utils.download import get_path_from_url
 from paddlehub.module.module import serving, RunModule, runnable
@@ -32,11 +30,11 @@ from paddlehub.module.module import serving, RunModule, runnable
 from paddlehub.utils.log import logger
 from paddlehub.utils.utils import reseg_token_label
+import paddlenlp
 from paddlenlp.embeddings.token_embedding import EMBEDDING_HOME, EMBEDDING_URL_ROOT
 from paddlenlp.data import JiebaTokenizer
 from paddlehub.compat.module.nlp_module import DataFormatError
-
 
 __all__ = [
     'PretrainedModel',
     'register_base_model',
@@ -357,14 +355,9 @@ class TextServing(object):
     """
    A base class for text model which supports serving.
     """
+
     @serving
-    def predict_method(
-        self,
-        data: List[List[str]],
-        max_seq_len: int = 128,
-        batch_size: int = 1,
-        use_gpu: bool = False
-    ):
+    def predict_method(self, data: List[List[str]], max_seq_len: int = 128, batch_size: int = 1, use_gpu: bool = False):
         """
         Run predict method as a service.
         Serving as a task which is specified from serving config.
@@ -391,20 +384,16 @@ class TextServing(object):
         if self.task == 'token-cls':
             # remove labels of [CLS] token and pad tokens
-            results = [
-                token_labels[1:len(data[i][0])+1] for i, token_labels in enumerate(results)
-            ]
+            results = [token_labels[1:len(data[i][0]) + 1] for i, token_labels in enumerate(results)]
             return results
         elif self.task is None:  # embedding service
             results = self.get_embedding(data, use_gpu)
             return results
         else:  # unknown service
-            logger.error(
-                f'Unknown task {self.task}, current tasks supported:\n'
-                '1. seq-cls: sequence classification service;\n'
-                '2. token-cls: sequence labeling service;\n'
-                '3. None: embedding service'
-            )
+            logger.error(f'Unknown task {self.task}, current tasks supported:\n'
+                         '1. seq-cls: sequence classification service;\n'
+                         '2. token-cls: sequence labeling service;\n'
+                         '3. None: embedding service')
             return
@@ -422,11 +411,33 @@ class TransformerModule(RunModule, TextServing):
         if self.task == 'token-cls':  # Extra processing of token-cls task
             tokens = text[0].split(split_char)
             text[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
+            is_split_into_words = True
+        else:
+            is_split_into_words = False
         if len(text) == 1:
-            encoded_inputs = tokenizer.encode(text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
+            if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                encoded_inputs = tokenizer.encode(
+                    text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
+            else:
+                encoded_inputs = tokenizer(
+                    text=text[0],
+                    max_seq_len=max_seq_len,
+                    pad_to_max_seq_len=True,
+                    is_split_into_words=is_split_into_words,
+                    return_length=True)
         elif len(text) == 2:
-            encoded_inputs = tokenizer.encode(text[0], text_pair=text[1], max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
+            if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                encoded_inputs = tokenizer.encode(
+                    text[0], text_pair=text[1], max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
+            else:
+                encoded_inputs = tokenizer(
+                    text=text[0],
+                    text_pair=text[1],
+                    max_seq_len=max_seq_len,
+                    pad_to_max_seq_len=True,
+                    is_split_into_words=is_split_into_words,
+                    return_length=True)
         else:
             raise RuntimeError(
                 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text))
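The new is_split_into_words flag matters here: for the token-cls task, reseg_token_label has already split text[0] into tokens, so the direct tokenizer call must treat its input as a list of words rather than a raw string to be re-tokenized.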
@@ -442,7 +453,14 @@ class TransformerModule(RunModule, TextServing):
         examples = []
         for text in data:
             encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len, split_char)
-            examples.append((encoded_inputs['input_ids'], encoded_inputs['segment_ids']))
+            input_ids = encoded_inputs['input_ids']
+            if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
+                token_type_ids = encoded_inputs['token_type_ids']
+            else:
+                token_type_ids = encoded_inputs['segment_ids']
+            examples.append((input_ids, token_type_ids))
 
         # Seperates data into some batches.
         one_batch = []
@@ -468,7 +486,8 @@ class TransformerModule(RunModule, TextServing):
         if self.task == 'seq-cls':
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
         elif self.task == 'token-cls':
-            predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+            predictions, avg_loss, metric = self(
+                input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
         self.metric.reset()
         return {'loss': avg_loss, 'metrics': metric}
@@ -485,7 +504,8 @@ class TransformerModule(RunModule, TextServing):
         if self.task == 'seq-cls':
             predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], labels=batch[2])
         elif self.task == 'token-cls':
-            predictions, avg_loss, metric = self(input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+            predictions, avg_loss, metric = self(
+                input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
         self.metric.reset()
         return {'metrics': metric}
@@ -502,20 +522,14 @@ class TransformerModule(RunModule, TextServing):
         if self.task is not None:
             raise RuntimeError("The get_embedding method is only valid when task is None, but got task %s" % self.task)
 
-        return self.predict(
-            data=data,
-            use_gpu=use_gpu
-        )
+        return self.predict(data=data, use_gpu=use_gpu)
 
-    def predict(
-        self,
-        data: List[List[str]],
-        max_seq_len: int = 128,
-        split_char: str = '\002',
-        batch_size: int = 1,
-        use_gpu: bool = False
-    ):
+    def predict(self,
+                data: List[List[str]],
+                max_seq_len: int = 128,
+                split_char: str = '\002',
+                batch_size: int = 1,
+                use_gpu: bool = False):
         """
         Predicts the data labels.
@@ -532,12 +546,10 @@ class TransformerModule(RunModule, TextServing):
         """
         if self.task not in self._tasks_supported \
                 and self.task is not None:  # None for getting embedding
-            raise RuntimeError(
-                f'Unknown task {self.task}, current tasks supported:\n'
-                '1. seq-cls: sequence classification;\n'
-                '2. token-cls: sequence labeling;\n'
-                '3. None: embedding'
-            )
+            raise RuntimeError(f'Unknown task {self.task}, current tasks supported:\n'
+                               '1. seq-cls: sequence classification;\n'
+                               '2. token-cls: sequence labeling;\n'
+                               '3. None: embedding')
 
         paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
@@ -563,10 +575,7 @@ class TransformerModule(RunModule, TextServing):
                 results.extend(token_labels)
             elif self.task == None:
                 sequence_output, pooled_output = self(input_ids, segment_ids)
-                results.append([
-                    pooled_output.squeeze(0).numpy().tolist(),
-                    sequence_output.squeeze(0).numpy().tolist()
-                ])
+                results.append([pooled_output.squeeze(0).numpy().tolist(), sequence_output.squeeze(0).numpy().tolist()])
 
         return results
@@ -575,6 +584,7 @@ class EmbeddingServing(object):
     """
     A base class for embedding model which supports serving.
     """
+
     @serving
     def calc_similarity(self, data: List[List[str]]):
         """
@@ -593,8 +603,7 @@ class EmbeddingServing(object):
             for word in word_pair:
                 if self.get_idx_from_word(word) == \
                         self.get_idx_from_word(self.vocab.unk_token):
-                    raise RuntimeError(
-                        f'Word "{word}" is not in vocab. Please check your inputs.')
+                    raise RuntimeError(f'Word "{word}" is not in vocab. Please check your inputs.')
             results.append(str(self.cosine_sim(*word_pair)))
         return results
@@ -627,5 +636,5 @@ class EmbeddingModule(RunModule, EmbeddingServing):
         """
         if self.embedding_name.endswith('.en'):  # English
             raise NotImplementedError  # TODO: (chenxiaojie) add tokenizer of English embedding
-        else:   # Chinese
+        else:  # Chinese
             return JiebaTokenizer(self.vocab)
@@ -336,12 +336,11 @@ def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
     '''
     if labels:
         if len(tokens) != len(labels):
-            raise ValueError(
-                "The length of tokens must be same with labels")
+            raise ValueError("The length of tokens must be same with labels")
         ret_tokens = []
         ret_labels = []
         for token, label in zip(tokens, labels):
-            sub_token = tokenizer(token)
+            sub_token = tokenizer._tokenize(token)
             if len(sub_token) == 0:
                 continue
             ret_tokens.extend(sub_token)
@@ -354,13 +353,12 @@ def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
             ret_labels.extend([sub_label] * (len(sub_token) - 1))
 
         if len(ret_tokens) != len(ret_labels):
-            raise ValueError(
-                "The length of ret_tokens can't match with labels")
+            raise ValueError("The length of ret_tokens can't match with labels")
         return ret_tokens, ret_labels
     else:
         ret_tokens = []
         for token in tokens:
-            sub_token = tokenizer(token)
+            sub_token = tokenizer._tokenize(token)
             if len(sub_token) == 0:
                 continue
             ret_tokens.extend(sub_token)
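With the upgraded PretrainedTokenizer, calling the tokenizer object now performs full encoding rather than plain sub-word splitting, which is why reseg_token_label switches to the internal _tokenize() to split a single word. An illustration of the difference, assuming a 2.0.0rc5-era tokenizer (the checkpoint name is again illustrative):

    from paddlenlp.transformers import ErnieTokenizer

    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
    print(tokenizer._tokenize('playing'))  # sub-token strings, e.g. ['play', '##ing']
    print(sorted(tokenizer('playing')))    # encoded dict keys: 'input_ids', 'token_type_ids', ...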
@@ -376,7 +374,7 @@ def pad_sequence(ids: List[int], max_seq_len: int, pad_token_id: int):
     assert len(ids) <= max_seq_len, \
         f'The input length {len(ids)} is greater than max_seq_len {max_seq_len}. '\
         'Please check the input list and max_seq_len if you really want to pad a sequence.'
-    return ids[:] + [pad_token_id]*(max_seq_len-len(ids))
+    return ids[:] + [pad_token_id] * (max_seq_len - len(ids))
 
 def trunc_sequence(ids: List[int], max_seq_len: int):
...
@@ -16,4 +16,4 @@ tqdm
 visualdl >= 2.0.0
 # gunicorn not support windows
 gunicorn >= 19.10.0; sys_platform != "win32"
-paddlenlp >= 2.0.0b2
+paddlenlp >= 2.0.0rc5
\ No newline at end of file
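Raising the floor from paddlenlp >= 2.0.0b2 to >= 2.0.0rc5 means fresh installs always take the new-API branches above; the runtime Version(...) checks only remain for environments that already have an older paddlenlp pinned.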