Unverified commit 045e4e22, authored by KP, committed by GitHub

Add embedding finetune demo (#1204)

* Add embedding seq-cls finetune demo and update api

* Update docs of pad_sequence and trunc_sequence
Parent 5832b1a1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlenlp as nlp
from paddlenlp.embeddings import TokenEmbedding
from paddlenlp.data import JiebaTokenizer
from paddlehub.utils.log import logger
from paddlehub.utils.utils import pad_sequence, trunc_sequence
class BoWModel(nn.Layer):
"""
This class implements the Bag of Words Classification Network model to classify texts.
At a high level, the model embeds the tokens with a word embedding. Then, we encode
these representations with a `BoWEncoder`. Lastly, we take the output of the encoder
to create a final representation, which is passed through some feed-forward layers
to output logits (`output_layer`).
Args:
num_classes (obj:`int`, optional, defaults to 2): The number of classes to predict.
embedder (obj:`TokenEmbedding`): The embedding layer that maps token ids to dense vectors.
tokenizer (obj:`JiebaTokenizer`, optional): The tokenizer used to encode raw text in `predict`.
hidden_size (obj:`int`, optional, defaults to 128): The hidden size of the first fully connected layer.
fc_hidden_size (obj:`int`, optional, defaults to 96): The hidden size of the second fully connected layer.
load_checkpoint (obj:`str`, optional): The path of a checkpoint to load parameters from.
label_map (obj:`dict`, optional): The mapping from predicted indices to label strings, used in `predict`.
"""
def __init__(self,
num_classes: int = 2,
embedder: TokenEmbedding = None,
tokenizer: JiebaTokenizer = None,
hidden_size: int = 128,
fc_hidden_size: int = 96,
load_checkpoint: str = None,
label_map: dict = None):
super().__init__()
self.embedder = embedder
self.tokenizer = tokenizer
self.label_map = label_map
emb_dim = self.embedder.embedding_dim
self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size)
self.fc2 = nn.Linear(hidden_size, fc_hidden_size)
self.dropout = nn.Dropout(p=0.3, axis=1)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
self.criterion = nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()
if load_checkpoint is not None and os.path.isfile(load_checkpoint):
state_dict = paddle.load(load_checkpoint)
self.set_state_dict(state_dict)
logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
"""
One step for training, which should be called as forward computation.
Args:
batch(:obj:List[paddle.Tensor]): One batch of data, containing the token ids and the labels
that the model needs.
batch_idx(int): The index of the batch.
Returns:
results(:obj: Dict) : The model outputs, such as loss and metrics.
"""
_, avg_loss, metric = self(ids=batch[0], labels=batch[1])
self.metric.reset()
return {'loss': avg_loss, 'metrics': metric}
def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
"""
One step for validation, which should be called as forward computation.
Args:
batch(:obj:List[paddle.Tensor]): One batch of data, containing the token ids and the labels
that the model needs.
batch_idx(int): The index of the batch.
Returns:
results(:obj: Dict) : The model outputs, such as metrics.
"""
_, _, metric = self(ids=batch[0], labels=batch[1])
self.metric.reset()
return {'metrics': metric}
def forward(self, ids: paddle.Tensor, labels: paddle.Tensor = None):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(ids)
# Shape: (batch_size, embedding_dim)
summed = self.bow_encoder(embedded_text)
summed = self.dropout(summed)
encoded_text = paddle.tanh(summed)
# Shape: (batch_size, hidden_size)
fc1_out = paddle.tanh(self.fc1(encoded_text))
# Shape: (batch_size, fc_hidden_size)
fc2_out = paddle.tanh(self.fc2(fc1_out))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc2_out)
probs = F.softmax(logits, axis=1)
if labels is not None:
loss = self.criterion(logits, labels)
correct = self.metric.compute(probs, labels)
acc = self.metric.update(correct)
return probs, loss, {'acc': acc}
else:
return probs
def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
examples = []
for item in data:
ids = self.tokenizer.encode(sentence=item[0])
if len(ids) > max_seq_len:
ids = trunc_sequence(ids, max_seq_len)
else:
pad_token = self.tokenizer.vocab.pad_token
pad_token_id = self.tokenizer.vocab.to_indices(pad_token)
ids = pad_sequence(ids, max_seq_len, pad_token_id)
examples.append(ids)
# Separates data into batches.
one_batch = []
for example in examples:
one_batch.append(example)
if len(one_batch) == batch_size:
yield one_batch
one_batch = []
if one_batch:
# The last batch, whose size may be less than the configured batch_size.
yield one_batch
def predict(
self,
data: List[List[str]],
max_seq_len: int = 128,
batch_size: int = 1,
use_gpu: bool = False,
return_result: bool = True,
):
paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
batches = self._batchify(data, max_seq_len, batch_size)
results = []
self.eval()
for batch in batches:
ids = paddle.to_tensor(batch)
probs = self(ids)
idx = paddle.argmax(probs, axis=1).numpy()
if return_result:
idx = idx.tolist()
labels = [self.label_map[i] for i in idx]
results.extend(labels)
else:
results.extend(probs.numpy())
return results
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddlehub as hub
from paddlenlp.data import JiebaTokenizer
from model import BoWModel
import ast
import argparse
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint", type=str, default='./checkpoint/best_model/model.pdparams', help="Model checkpoint")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning, input should be True or False")
args = parser.parse_args()
if __name__ == '__main__':
# Data to be predicted
data = [
["这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"],
["交通方便;环境很好;服务态度很好 房间较小"],
["还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。"],
["前台接待太差,酒店有A B楼之分,本人check-in后,前台未告诉B楼在何处,并且B楼无明显指示;房间太小,根本不像4星级设施,下次不会再选择入住此店啦"],
["19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"],
]
label_map = {0: 'negative', 1: 'positive'}
embedder = hub.Module(name=args.hub_embedding_name)
tokenizer = embedder.get_tokenizer()
model = BoWModel(
embedder=embedder,
tokenizer=tokenizer,
load_checkpoint=args.checkpoint,
label_map=label_map)
results = model.predict(data, max_seq_len=args.max_seq_len, batch_size=args.batch_size, use_gpu=args.use_gpu, return_result=False)
for idx, text in enumerate(data):
print('Data: {} \t Label: {}'.format(text[0], results[idx]))
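# A minimal sketch, not part of the demo above: with return_result=True, predict()
# maps the argmax indices through label_map and returns label strings instead of
# probability arrays. It reuses the model, data and label_map defined above.
labels = model.predict(data, max_seq_len=args.max_seq_len, batch_size=args.batch_size, use_gpu=args.use_gpu, return_result=True)
for text, label in zip(data, labels):
    print('Data: {} \t Label: {}'.format(text[0], label))  # e.g. 'negative' or 'positive'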
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddlehub as hub
from paddlehub.datasets import ChnSentiCorp
from paddlenlp.data import JiebaTokenizer
from model import BoWModel
import ast
import argparse
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="")
parser.add_argument("--num_epoch", type=int, default=10, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate used to train with warmup.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to model checkpoint")
parser.add_argument("--save_interval", type=int, default=5, help="Save checkpoint every n epoch.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning, input should be True or False")
args = parser.parse_args()
if __name__ == '__main__':
embedder = hub.Module(name=args.hub_embedding_name)
tokenizer = embedder.get_tokenizer()
train_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='train')
dev_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='dev')
test_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='test')
model = BoWModel(embedder=embedder)
optimizer = paddle.optimizer.AdamW(
learning_rate=args.learning_rate, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
trainer.train(
train_dataset,
epochs=args.num_epoch,
batch_size=args.batch_size,
eval_dataset=dev_dataset,
save_interval=args.save_interval,
)
trainer.evaluate(test_dataset, batch_size=args.batch_size)
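# A minimal sketch, not part of the demo above, to make the data flow explicit:
# with a JiebaTokenizer, each ChnSentiCorp example is encoded into fixed-length
# token ids plus an int64 label (see the TextClassificationDataset changes below),
# so a batch reaches BoWModel.training_step as batch[0] = ids and batch[1] = labels.
# The exact batched shapes are an assumption based on that record layout.
sample_ids, sample_label = train_dataset[0]     # ids padded/truncated to max_seq_len, plus the label
print(sample_ids.shape, sample_label)
loader = paddle.io.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
ids, labels = next(iter(loader))                # what training_step receives as batch[0] and batch[1]
print(ids.shape, labels.shape)                  # roughly (batch_size, max_seq_len) and (batch_size,) or (batch_size, 1)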
@@ -15,6 +15,7 @@
from typing import List
from paddlenlp.embeddings import TokenEmbedding
from paddlehub.module.module import moduleinfo, serving
from paddlehub.module.nlp_module import EmbeddingModule
@moduleinfo(
@@ -23,33 +24,13 @@ from paddlehub.module.module import moduleinfo, serving
summary="",
author="paddlepaddle",
author_email="",
type="nlp/semantic_model")
type="nlp/semantic_model",
meta=EmbeddingModule)
class Embedding(TokenEmbedding):
"""
Embedding model
"""
def __init__(self, *args, **kwargs):
super(Embedding, self).__init__(embedding_name="w2v.baidu_encyclopedia.target.word-word.dim300", *args, **kwargs)
embedding_name = 'w2v.baidu_encyclopedia.target.word-word.dim300'
@serving
def calc_similarity(self, data: List[List[str]]):
"""
Calculate similarities of giving word pairs.
"""
results = []
for word_pair in data:
if len(word_pair) != 2:
raise RuntimeError(
f'The input must have two words, but got {len(word_pair)}. Please check your inputs.')
if not isinstance(word_pair[0], str) or not isinstance(word_pair[1], str):
raise RuntimeError(
f'The types of text pair must be (str, str), but got'
f' ({type(word_pair[0]).__name__}, {type(word_pair[1]).__name__}). Please check your inputs.')
for word in word_pair:
if self.get_idx_from_word(word) == \
self.get_idx_from_word(self.vocab.unk_token):
raise RuntimeError(
f'Word "{word}" is not in vocab. Please check your inputs.')
results.append(str(self.cosine_sim(*word_pair)))
return results
def __init__(self, *args, **kwargs):
super(Embedding, self).__init__(embedding_name=self.embedding_name, *args, **kwargs)
\ No newline at end of file
@@ -20,13 +20,14 @@ import numpy as np
import paddle
from paddlehub.env import DATA_HOME
from paddlehub.text.bert_tokenizer import BertTokenizer
from paddlehub.text.tokenizer import CustomTokenizer
from paddlenlp.transformers import PretrainedTokenizer
from paddlenlp.data import JiebaTokenizer
from paddlehub.utils.log import logger
from paddlehub.utils.utils import download, reseg_token_label
from paddlehub.utils.utils import download, reseg_token_label, pad_sequence, trunc_sequence
from paddlehub.utils.xarfile import is_xarfile, unarchive
class InputExample(object):
"""
The input data structure of Transformer modules (BERT, ERNIE and so on).
@@ -72,7 +73,7 @@ class BaseNLPDataset(object):
def __init__(self,
base_path: str,
tokenizer: Union[BertTokenizer, CustomTokenizer],
tokenizer: Union[PretrainedTokenizer, JiebaTokenizer],
max_seq_len: Optional[int] = 128,
mode: Optional[str] = "train",
data_file: Optional[str] = None,
@@ -81,7 +82,7 @@ class BaseNLPDataset(object):
"""
Args:
base_path (:obj:`str`): The directory to the whole dataset.
tokenizer (:obj:`BertTokenizer` or :obj:`CustomTokenizer`):
tokenizer (:obj:`PretrainedTokenizer` or :obj:`JiebaTokenizer`):
It tokenizes the text and encodes the data as the model needs.
max_seq_len (:obj:`int`, `optional`, defaults to :128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
@@ -159,7 +160,7 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
def __init__(self,
base_path: str,
tokenizer: Union[BertTokenizer, CustomTokenizer],
tokenizer: Union[PretrainedTokenizer, JiebaTokenizer],
max_seq_len: int = 128,
mode: str = "train",
data_file: str = None,
@@ -169,7 +170,7 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
"""
Args:
base_path (:obj:`str`): The directory to the whole dataset.
tokenizer (:obj:`BertTokenizer` or :obj:`CustomTokenizer`):
tokenizer (:obj:`PretrainedTokenizer` or :obj:`JiebaTokenizer`):
It tokenizes the text and encodes the data as the model needs.
max_seq_len (:obj:`int`, `optional`, defaults to :128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
@@ -231,9 +232,22 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
"""
records = []
for example in examples:
if isinstance(self.tokenizer, PretrainedTokenizer):
record = self.tokenizer.encode(text=example.text_a, text_pair=example.text_b, max_seq_len=self.max_seq_len)
# CustomTokenizer will tokenize the text firstly and then lookup words in the vocab
# When all words are not found in the vocab, the text will be dropped.
elif isinstance(self.tokenizer, JiebaTokenizer):
pad_token = self.tokenizer.vocab.pad_token
ids = self.tokenizer.encode(sentence=example.text_a)
seq_len = min(len(ids), self.max_seq_len)
if len(ids) > self.max_seq_len:
ids = trunc_sequence(ids, self.max_seq_len)
else:
pad_token_id = self.tokenizer.vocab.to_indices(pad_token)
ids = pad_sequence(ids, self.max_seq_len, pad_token_id)
record = {'text': ids, 'seq_len': seq_len}
else:
raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
if not record:
logger.info(
"The text %s has been dropped as it has no words in the vocab after tokenization." % example.text_a)
@@ -245,19 +259,53 @@ class TextClassificationDataset(BaseNLPDataset, paddle.io.Dataset):
def __getitem__(self, idx):
record = self.records[idx]
if isinstance(self.tokenizer, PretrainedTokenizer):
if 'label' in record.keys():
return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['label'], dtype=np.int64)
else:
return np.array(record['input_ids']), np.array(record['segment_ids'])
elif isinstance(self.tokenizer, JiebaTokenizer):
if 'label' in record.keys():
return np.array(record['text']), np.array(record['label'], dtype=np.int64)
else:
return np.array(record['text'])
else:
raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
def __len__(self):
return len(self.records)
class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
"""
Args:
base_path (:obj:`str`): The directory to the whole dataset.
tokenizer (:obj:`PretrainedTokenizer` or :obj:`JiebaTokenizer`):
It tokenizes the text and encodes the data as the model needs.
max_seq_len (:obj:`int`, `optional`, defaults to :128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train, test or dev).
data_file(:obj:`str`, `optional`, defaults to :obj:`None`):
The data file name, which is relative to the base_path.
label_file(:obj:`str`, `optional`, defaults to :obj:`None`):
The label file name, which is relative to the base_path.
It contains all labels of the dataset, one label per line.
label_list(:obj:`List[str]`, `optional`, defaults to :obj:`None`):
The list of all labels of the dataset
split_char(:obj:`str`, `optional`, defaults to :obj:`\002`):
The symbol used to split chars in text and labels
no_entity_label(:obj:`str`, `optional`, defaults to :obj:`O`):
The label used to mark no entities
ignore_label(:obj:`int`, `optional`, defaults to :-100):
If one token's label == ignore_label, it will be ignored when
calculating loss
is_file_with_header(:obj:bool, `optional`, defaults to :obj:`False`):
Whether or not the file is with the header introduction.
"""
def __init__(self,
base_path: str,
tokenizer: Union[BertTokenizer, CustomTokenizer],
tokenizer: Union[PretrainedTokenizer, JiebaTokenizer],
max_seq_len: int = 128,
mode: str = "train",
data_file: str = None,
@@ -309,22 +357,46 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
"""
records = []
for example in examples:
tokens, labels = reseg_token_label(
tokenizer=self.tokenizer,
tokens=example.text_a.split(self.split_char),
labels=example.label.split(self.split_char))
record = self.tokenizer.encode(
text=tokens, max_seq_len=self.max_seq_len)
# CustomTokenizer will tokenize the text firstly and then lookup words in the vocab
# When all words are not found in the vocab, the text will be dropped.
tokens = example.text_a.split(self.split_char)
labels = example.label.split(self.split_char)
# convert tokens into record
if isinstance(self.tokenizer, PretrainedTokenizer):
pad_token = self.tokenizer.pad_token
tokens, labels = reseg_token_label(tokenizer=self.tokenizer, tokens=tokens, labels=labels)
record = self.tokenizer.encode(text=tokens, max_seq_len=self.max_seq_len)
elif isinstance(self.tokenizer, JiebaTokenizer):
pad_token = self.tokenizer.vocab.pad_token
ids = [self.tokenizer.vocab.to_indices(token) for token in tokens]
seq_len = min(len(ids), self.max_seq_len)
if len(ids) > self.max_seq_len:
ids = trunc_sequence(ids, self.max_seq_len)
else:
pad_token_id = self.tokenizer.vocab.to_indices(pad_token)
ids = pad_sequence(ids, self.max_seq_len, pad_token_id)
record = {'text': ids, 'seq_len': seq_len}
else:
raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
if not record:
logger.info(
"The text %s has been dropped as it has no words in the vocab after tokenization."
% example.text_a)
continue
# convert labels into record
if labels:
record["label"] = []
if isinstance(self.tokenizer, PretrainedTokenizer):
tokens_with_specical_token = self.tokenizer.convert_ids_to_tokens(record['input_ids'])
elif isinstance(self.tokenizer, JiebaTokenizer):
tokens_with_specical_token = [self.tokenizer.vocab.to_tokens(id_) for id_ in record['text']]
else:
raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
tokens_index = 0
for token in tokens_with_specical_token:
if tokens_index < len(
@@ -332,7 +404,7 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
record["label"].append(
self.label_list.index(labels[tokens_index]))
tokens_index += 1
elif token in [self.tokenizer.pad_token]:
elif token in [pad_token]:
record["label"].append(self.ignore_label) # label of special token
else:
record["label"].append(
@@ -342,10 +414,18 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
def __getitem__(self, idx):
record = self.records[idx]
if isinstance(self.tokenizer, PretrainedTokenizer):
if 'label' in record.keys():
return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['seq_len']), np.array(record['label'], dtype=np.int64)
else:
return np.array(record['input_ids']), np.array(record['segment_ids']), np.array(record['seq_len'])
elif isinstance(self.tokenizer, JiebaTokenizer):
if 'label' in record.keys():
return np.array(record['text']), np.array(record['seq_len']), np.array(record['label'], dtype=np.int64)
else:
return np.array(record['text']), np.array(record['seq_len'])
else:
raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer or JiebaTokenizer".format(type(self.tokenizer)))
def __len__(self):
return len(self.records)
@@ -32,6 +32,9 @@ from paddlehub.module.module import serving, RunModule, runnable
from paddlehub.utils.log import logger
from paddlehub.utils.utils import reseg_token_label
from paddlenlp.embeddings.token_embedding import EMBEDDING_HOME, EMBEDDING_URL_ROOT
from paddlenlp.data import JiebaTokenizer
__all__ = [
'PretrainedModel',
'register_base_model',
@@ -510,6 +513,7 @@ class TransformerModule(RunModule, TextServing):
batch_size: int = 1,
use_gpu: bool = False
):
"""
Predicts the data labels.
@@ -563,3 +567,63 @@ class TransformerModule(RunModule, TextServing):
])
return results
class EmbeddingServing(object):
"""
A base class for embedding model which supports serving.
"""
@serving
def calc_similarity(self, data: List[List[str]]):
"""
Calculate similarities of the given word pairs.
"""
results = []
for word_pair in data:
if len(word_pair) != 2:
raise RuntimeError(
f'The input must have two words, but got {len(word_pair)}. Please check your inputs.')
if not isinstance(word_pair[0], str) or not isinstance(word_pair[1], str):
raise RuntimeError(
f'The types of text pair must be (str, str), but got'
f' ({type(word_pair[0]).__name__}, {type(word_pair[1]).__name__}). Please check your inputs.')
for word in word_pair:
if self.get_idx_from_word(word) == \
self.get_idx_from_word(self.vocab.unk_token):
raise RuntimeError(
f'Word "{word}" is not in vocab. Please check your inputs.')
results.append(str(self.cosine_sim(*word_pair)))
return results
class EmbeddingModule(RunModule, EmbeddingServing):
"""
The base class for Embedding models.
"""
base_url = 'https://paddlenlp.bj.bcebos.com/models/embeddings/'
def _download_vocab(self):
"""
Download vocab from url
"""
url = EMBEDDING_URL_ROOT + '/' + f'vocab.{self.embedding_name}'
get_path_from_url(url, EMBEDDING_HOME)
def get_vocab_path(self):
"""
Get local vocab path
"""
vocab_path = os.path.join(EMBEDDING_HOME, f'vocab.{self.embedding_name}')
if not os.path.exists(vocab_path):
self._download_vocab()
return vocab_path
def get_tokenizer(self, *args, **kwargs):
"""
Get tokenizer of embedding module
"""
if self.embedding_name.endswith('.en'): # English
raise NotImplementedError # TODO: (chenxiaojie) add tokenizer of English embedding
else: # Chinese
return JiebaTokenizer(self.vocab)
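# A minimal usage sketch of the classes above, not part of nlp_module.py itself.
# Assumptions: the module registered on PaddleHub keeps the underscored name used by
# the demo scripts, and the meta=EmbeddingModule wiring exposes get_tokenizer and
# calc_similarity on the loaded module.
import paddlehub as hub
embedder = hub.Module(name='w2v_baidu_encyclopedia_target_word-word_dim300')
tokenizer = embedder.get_tokenizer()                  # JiebaTokenizer built on the embedding vocab
print(tokenizer.encode(sentence='交通方便;环境很好'))    # token ids for a Chinese sentence
print(embedder.calc_similarity([['苹果', '香蕉']]))      # cosine similarity of a word pair (both words must be in vocab)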
@@ -367,3 +367,23 @@ def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None):
if len(sub_token) < 2:
continue
return ret_tokens, None
def pad_sequence(ids: List[int], max_seq_len: int, pad_token_id: int):
'''
Pads a sequence to max_seq_len
'''
assert len(ids) <= max_seq_len, \
f'The input length {len(ids)} is greater than max_seq_len {max_seq_len}. '\
'Please check the input list and max_seq_len if you really want to pad a sequence.'
return ids[:] + [pad_token_id]*(max_seq_len-len(ids))
def trunc_sequence(ids: List[int], max_seq_len: int):
'''
Truncates a sequence to max_seq_len
'''
assert len(ids) >= max_seq_len, \
f'The input length {len(ids)} is less than max_seq_len {max_seq_len}. ' \
'Please check the input list and max_seq_len if you really want to truncate a sequence.'
return ids[:max_seq_len]
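# Illustration of the two helpers above; the ids are made-up values used only to pin down the behavior.
assert pad_sequence([11, 12, 13], max_seq_len=5, pad_token_id=0) == [11, 12, 13, 0, 0]
assert trunc_sequence([11, 12, 13, 14, 15, 16], max_seq_len=5) == [11, 12, 13, 14, 15]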