From 21fc5cb54960df34040f1bc92c2dbddb78b1fa55 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Thu, 11 Mar 2021 09:57:25 +0800
Subject: [PATCH] Add dataset, module task, and demo of text-matching (#1281)

* Add dataset, module task, and demo of text-matching
---
 demo/text_matching/predict.py                 |  34 ++++
 demo/text_matching/train.py                   |  51 ++++++
 .../text/language_model/ernie_tiny/module.py  |  61 +++++++-
 paddlehub/datasets/__init__.py                |   1 +
 paddlehub/datasets/base_nlp_dataset.py        | 117 ++++++++++++++
 paddlehub/datasets/lcqmc.py                   |  60 +++++++
 paddlehub/module/nlp_module.py                | 148 +++++++++++-------
 7 files changed, 411 insertions(+), 61 deletions(-)
 create mode 100644 demo/text_matching/predict.py
 create mode 100644 demo/text_matching/train.py
 create mode 100644 paddlehub/datasets/lcqmc.py

diff --git a/demo/text_matching/predict.py b/demo/text_matching/predict.py
new file mode 100644
index 00000000..6fe6a42a
--- /dev/null
+++ b/demo/text_matching/predict.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddlehub as hub
+
+if __name__ == '__main__':
+    data = [
+        ['这个表情叫什么', '这个猫的表情叫什么'],
+        ['什么是智能手环', '智能手环有什么用'],
+        ['介绍几本好看的都市异能小说,要完结的!', '求一本好看点的都市异能小说,要完结的'],
+        ['一只蜜蜂落在日历上(打一成语)', '一只蜜蜂停在日历上(猜一成语)'],
+        ['一盒香烟不拆开能存放多久?', '一条没拆封的香烟能存放多久。'],
+    ]
+    label_map = {0: 'dissimilar', 1: 'similar'}
+
+    model = hub.Module(
+        name='ernie_tiny',
+        version='2.0.2',
+        task='text-matching',
+        load_checkpoint='./checkpoint/best_model/model.pdparams',
+        label_map=label_map)
+    results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=True)
+    for idx, texts in enumerate(data):
+        print('TextA: {}\tTextB: {}\tLabel: {}'.format(texts[0], texts[1], results[idx]))
diff --git a/demo/text_matching/train.py b/demo/text_matching/train.py
new file mode 100644
index 00000000..7770b3c0
--- /dev/null
+++ b/demo/text_matching/train.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import ast
+
+import paddle
+import paddlehub as hub
+from paddlehub.datasets import LCQMC
+
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs for fine-tuning.")
+parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for fine-tuning; input should be True or False.")
+parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
+parser.add_argument("--max_seq_len", type=int, default=64, help="Maximum number of words in a sequence.")
+parser.add_argument("--batch_size", type=int, default=128, help="Number of examples per batch for training.")
+parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory in which to save model checkpoints.")
+parser.add_argument("--save_interval", type=int, default=2, help="Save a checkpoint every n epochs.")
+
+args = parser.parse_args()
+
+if __name__ == '__main__':
+    model = hub.Module(name='ernie_tiny', version='2.0.2', task='text-matching')
+    tokenizer = model.get_tokenizer()
+
+    train_dataset = LCQMC(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='train')
+    dev_dataset = LCQMC(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='dev')
+    test_dataset = LCQMC(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='test')
+
+    optimizer = paddle.optimizer.AdamW(
+        learning_rate=args.learning_rate, parameters=model.parameters())
+    trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
+    trainer.train(
+        train_dataset,
+        epochs=args.num_epoch,
+        batch_size=args.batch_size,
+        eval_dataset=dev_dataset,
+        save_interval=args.save_interval,
+    )
+    trainer.evaluate(test_dataset, batch_size=args.batch_size)
diff --git a/modules/text/language_model/ernie_tiny/module.py b/modules/text/language_model/ernie_tiny/module.py
index d309ac47..a85bc320 100644
--- a/modules/text/language_model/ernie_tiny/module.py
+++ b/modules/text/language_model/ernie_tiny/module.py
@@ -29,7 +29,7 @@ from paddlehub.utils.log import logger
 
 @moduleinfo(
     name="ernie_tiny",
-    version="2.0.1",
+    version="2.0.2",
     summary="Baidu's ERNIE-tiny, Enhanced Representation through kNowledge IntEgration, tiny version, max_seq_len=512",
     author="paddlepaddle",
     author_email="",
@@ -71,6 +71,12 @@ class ErnieTiny(nn.Layer):
             self.metric = ChunkEvaluator(
                 label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
             )
+        elif task == 'text-matching':
+            self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-tiny', **kwargs)
+            self.dropout = paddle.nn.Dropout(0.1)
+            self.classifier = paddle.nn.Linear(self.model.config['hidden_size'] * 3, 2)
+            self.criterion = paddle.nn.loss.CrossEntropyLoss()
+            self.metric = paddle.metric.Accuracy()
         elif task is None:
             self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-tiny', **kwargs)
         else:
@@ -84,8 +90,28 @@ class ErnieTiny(nn.Layer):
             self.set_state_dict(state_dict)
             logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
-        result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+    def forward(self,
+                input_ids=None,
+                token_type_ids=None,
+                position_ids=None,
+                attention_mask=None,
+                query_input_ids=None,
+                query_token_type_ids=None,
+                query_position_ids=None,
+                query_attention_mask=None,
+                title_input_ids=None,
+                title_token_type_ids=None,
+                title_position_ids=None,
+                title_attention_mask=None,
+                seq_lengths=None,
+                labels=None):
+
+        if self.task != 'text-matching':
+            result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+        else:
+            query_result = self.model(query_input_ids, query_token_type_ids, query_position_ids, query_attention_mask)
+            title_result = self.model(title_input_ids, title_token_type_ids, title_position_ids, title_attention_mask)
+
         if self.task == 'seq-cls':
             logits = result
             probs = F.softmax(logits, axis=1)
@@ -108,6 +134,35 @@ class ErnieTiny(nn.Layer):
                 _, _, f1_score = map(float, self.metric.accumulate())
                 return token_level_probs, loss, {'f1_score': f1_score}
             return token_level_probs
+        elif self.task == 'text-matching':
+            query_token_embedding, _ = query_result
+            query_token_embedding = self.dropout(query_token_embedding)
+            query_attention_mask = paddle.unsqueeze(
+                (query_input_ids != self.model.pad_token_id).astype(self.model.pooler.dense.weight.dtype), axis=2)
+            query_token_embedding = query_token_embedding * query_attention_mask
+            query_sum_embedding = paddle.sum(query_token_embedding, axis=1)
+            query_sum_mask = paddle.sum(query_attention_mask, axis=1)
+            query_mean = query_sum_embedding / query_sum_mask
+
+            title_token_embedding, _ = title_result
+            title_token_embedding = self.dropout(title_token_embedding)
+            title_attention_mask = paddle.unsqueeze(
+                (title_input_ids != self.model.pad_token_id).astype(self.model.pooler.dense.weight.dtype), axis=2)
+            title_token_embedding = title_token_embedding * title_attention_mask
+            title_sum_embedding = paddle.sum(title_token_embedding, axis=1)
+            title_sum_mask = paddle.sum(title_attention_mask, axis=1)
+            title_mean = title_sum_embedding / title_sum_mask
+
+            sub = paddle.abs(paddle.subtract(query_mean, title_mean))
+            projection = paddle.concat([query_mean, title_mean, sub], axis=-1)
+            logits = self.classifier(projection)
+            probs = F.softmax(logits, axis=1)
+            if labels is not None:
+                loss = self.criterion(logits, labels)
+                correct = self.metric.compute(probs, labels)
+                acc = self.metric.update(correct)
+                return probs, loss, {'acc': acc}
+            return probs
         else:
             sequence_output, pooled_output = result
             return sequence_output, pooled_output
diff --git a/paddlehub/datasets/__init__.py b/paddlehub/datasets/__init__.py
index 12447233..4f097c2f 100644
--- a/paddlehub/datasets/__init__.py
+++ b/paddlehub/datasets/__init__.py
@@ -17,3 +17,4 @@ from paddlehub.datasets.flowers import Flowers
 from paddlehub.datasets.minicoco import MiniCOCO
 from paddlehub.datasets.chnsenticorp import ChnSentiCorp
 from paddlehub.datasets.msra_ner import MSRA_NER
+from paddlehub.datasets.lcqmc import LCQMC
diff --git a/paddlehub/datasets/base_nlp_dataset.py b/paddlehub/datasets/base_nlp_dataset.py
index c9e425ad..a98fd93e 100644
--- a/paddlehub/datasets/base_nlp_dataset.py
+++ b/paddlehub/datasets/base_nlp_dataset.py
@@ -469,3 +469,120 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
 
     def __len__(self):
         return len(self.records)
+
+
+class TextMatchingDataset(BaseNLPDataset, paddle.io.Dataset):
+    """
+    The dataset class that fits all text matching datasets.
+    """
+
+    def __init__(self,
+                 base_path: str,
+                 tokenizer: PretrainedTokenizer,
+                 max_seq_len: int = 128,
+                 mode: str = "train",
+                 data_file: str = None,
+                 label_file: str = None,
+                 label_list: list = None,
+                 is_file_with_header: bool = False):
+        """
+        Args:
+            base_path (:obj:`str`): The directory to the whole dataset.
+            tokenizer (:obj:`PretrainedTokenizer`):
+                It tokenizes the text and encodes the data as the model needs.
+            max_seq_len (:obj:`int`, `optional`, defaults to :obj:`128`):
+                If set to a number, will limit the total sequence returned so that it has a maximum length.
+            mode (:obj:`str`, `optional`, defaults to `train`):
+                It identifies the dataset mode (train, test or dev).
+            data_file (:obj:`str`, `optional`, defaults to :obj:`None`):
+                The data file name, which is relative to the base_path.
+            label_file (:obj:`str`, `optional`, defaults to :obj:`None`):
+                The label file name, which is relative to the base_path.
+                It contains all labels of the dataset, one label per line.
+            label_list (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
+                The list of all labels of the dataset.
+            is_file_with_header (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not the file has a header line.
+        """
+        super(TextMatchingDataset, self).__init__(
+            base_path=base_path,
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len,
+            mode=mode,
+            data_file=data_file,
+            label_file=label_file,
+            label_list=label_list)
+        self.examples = self._read_file(self.data_file, is_file_with_header)
+
+        self.records = self._convert_examples_to_records(self.examples)
+
+    def _read_file(self, input_file, is_file_with_header: bool = False) -> List[InputExample]:
+        """
+        Reads a tab-separated value file.
+        Args:
+            input_file (:obj:`str`): The file to be read.
+            is_file_with_header (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not the file has a header line.
+        Returns:
+            examples (:obj:`List[InputExample]`): All the input data.
+        """
+        if not os.path.exists(input_file):
+            raise RuntimeError("The file {} is not found.".format(input_file))
+        else:
+            with io.open(input_file, "r", encoding="UTF-8") as f:
+                reader = csv.reader(f, delimiter="\t", quotechar=None)
+                examples = []
+                seq_id = 0
+                header = next(reader) if is_file_with_header else None
+                for line in reader:
+                    example = InputExample(guid=seq_id, text_a=line[0], text_b=line[1], label=line[2])
+                    seq_id += 1
+                    examples.append(example)
+                return examples
+
+    def _convert_examples_to_records(self, examples: List[InputExample]) -> List[dict]:
+        """
+        Converts all examples to records which the model needs.
+        Args:
+            examples (:obj:`List[InputExample]`): All data examples returned by _read_file.
+        Returns:
+            records (:obj:`List[dict]`): All records which the model needs.
+        """
+        records = []
+        for example in examples:
+            if isinstance(self.tokenizer, PretrainedTokenizer):
+                record_a = self.tokenizer(text=example.text_a, max_seq_len=self.max_seq_len, \
+                    pad_to_max_seq_len=True, return_length=True)
+                record_b = self.tokenizer(text=example.text_b, max_seq_len=self.max_seq_len, \
+                    pad_to_max_seq_len=True, return_length=True)
+                record = {'text_a': record_a, 'text_b': record_b}
+            else:
+                raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer".format(type(self.tokenizer)))
+
+            if not record:
+                logger.info(
+                    "The text %s has been dropped as it has no words in the vocab after tokenization." % example.text_a)
+                continue
+            if example.label:
+                record['label'] = self.label_map[example.label]
+            records.append(record)
+        return records
+
+    def __getitem__(self, idx):
+        record = self.records[idx]
+        if isinstance(self.tokenizer, PretrainedTokenizer):
+            query_input_ids = np.array(record['text_a']['input_ids'])
+            query_token_type_ids = np.array(record['text_a']['token_type_ids'])
+            title_input_ids = np.array(record['text_b']['input_ids'])
+            title_token_type_ids = np.array(record['text_b']['token_type_ids'])
+
+            if 'label' in record.keys():
+                return query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids, \
+                    np.array(record['label'], dtype=np.int64)
+            else:
+                return query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids
+        else:
+            raise RuntimeError("Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer".format(type(self.tokenizer)))
+
+    def __len__(self):
+        return len(self.records)
diff --git a/paddlehub/datasets/lcqmc.py b/paddlehub/datasets/lcqmc.py
new file mode 100644
index 00000000..10919a52
--- /dev/null
+++ b/paddlehub/datasets/lcqmc.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+import os
+
+from paddlehub.env import DATA_HOME
+from paddlehub.utils.download import download_data
+from paddlehub.datasets.base_nlp_dataset import TextMatchingDataset
+from paddlehub.text.bert_tokenizer import BertTokenizer
+from paddlehub.text.tokenizer import CustomTokenizer
+
+
+@download_data(url="https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz")
+class LCQMC(TextMatchingDataset):
+    label_list = ['0', '1']
+
+    def __init__(
+            self,
+            tokenizer: Union[BertTokenizer, CustomTokenizer],
+            max_seq_len: int = 128,
+            mode: str = 'train',
+    ):
+        base_path = os.path.join(DATA_HOME, "lcqmc")
+
+        if mode == 'train':
+            data_file = 'train.tsv'
+        elif mode == 'test':
+            data_file = 'test.tsv'
+        else:
+            data_file = 'dev.tsv'
+        super().__init__(
+            base_path=base_path,
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len,
+            mode=mode,
+            data_file=data_file,
+            label_file=None,
+            label_list=self.label_list,
+            is_file_with_header=True,
+        )
+
+
+if __name__ == "__main__":
+    import paddlehub as hub
+    model = hub.Module(name='ernie_tiny')
+    tokenizer = model.get_tokenizer()
+
+    ds = LCQMC(tokenizer=tokenizer, max_seq_len=128, mode='dev')
\ No newline at end of file
diff --git a/paddlehub/module/nlp_module.py b/paddlehub/module/nlp_module.py
index 5973076e..e30ecd01 100644
--- a/paddlehub/module/nlp_module.py
+++ b/paddlehub/module/nlp_module.py
@@ -404,63 +404,74 @@ class TransformerModule(RunModule, TextServing):
     _tasks_supported = [
         'seq-cls',
         'token-cls',
+        'text-matching',
     ]
 
-    def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int, split_char: str):
+    def _convert_text_to_input(self, tokenizer, texts: List[str], max_seq_len: int, split_char: str):
         pad_to_max_seq_len = False if self.task is None else True
         if self.task == 'token-cls':
             # Extra processing of token-cls task
-            tokens = text[0].split(split_char)
-            text[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
+            tokens = texts[0].split(split_char)
+            texts[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
             is_split_into_words = True
         else:
             is_split_into_words = False
-        if len(text) == 1:
-            if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
-                encoded_inputs = tokenizer.encode(
-                    text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
-            else:
-                encoded_inputs = tokenizer(
-                    text=text[0],
-                    max_seq_len=max_seq_len,
-                    pad_to_max_seq_len=True,
-                    is_split_into_words=is_split_into_words,
-                    return_length=True)
-        elif len(text) == 2:
-            if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
-                encoded_inputs = tokenizer.encode(
-                    text[0], text_pair=text[1], max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
-            else:
-                encoded_inputs = tokenizer(
-                    text=text[0],
-                    text_pair=text[1],
-                    max_seq_len=max_seq_len,
-                    pad_to_max_seq_len=True,
-                    is_split_into_words=is_split_into_words,
-                    return_length=True)
+        encoded_inputs = []
+        if self.task == 'text-matching':
+            if len(texts) != 2:
+                raise RuntimeError(
+                    'The input texts must have two sequences, but got %d. Please check your inputs.' % len(texts))
+            encoded_inputs.append(tokenizer(text=texts[0], text_pair=None, max_seq_len=max_seq_len, \
+                pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+            encoded_inputs.append(tokenizer(text=texts[1], text_pair=None, max_seq_len=max_seq_len, \
+                pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
         else:
-            raise RuntimeError(
-                'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text))
+            if len(texts) == 1:
+                if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                    encoded_inputs.append(tokenizer.encode(texts[0], text_pair=None, \
+                        max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len))
+                else:
+                    encoded_inputs.append(tokenizer(text=texts[0], max_seq_len=max_seq_len, \
+                        pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+            elif len(texts) == 2:
+                if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                    encoded_inputs.append(tokenizer.encode(texts[0], text_pair=texts[1], \
+                        max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len))
+                else:
+                    encoded_inputs.append(tokenizer(text=texts[0], text_pair=texts[1], max_seq_len=max_seq_len, \
+                        pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+            else:
+                raise RuntimeError(
+                    'The input texts must have one or two sequences, but got %d. Please check your inputs.' % len(texts))
 
         return encoded_inputs
 
     def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
         def _parse_batch(batch):
-            input_ids = [entry[0] for entry in batch]
-            segment_ids = [entry[1] for entry in batch]
-            return input_ids, segment_ids
+            if self.task != 'text-matching':
+                input_ids = [entry[0] for entry in batch]
+                segment_ids = [entry[1] for entry in batch]
+                return input_ids, segment_ids
+            else:
+                query_input_ids = [entry[0] for entry in batch]
+                query_segment_ids = [entry[1] for entry in batch]
+                title_input_ids = [entry[2] for entry in batch]
+                title_segment_ids = [entry[3] for entry in batch]
+                return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids
 
         tokenizer = self.get_tokenizer()
 
         examples = []
-        for text in data:
-            encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len, split_char)
-            input_ids = encoded_inputs['input_ids']
-            if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
-                token_type_ids = encoded_inputs['token_type_ids']
-            else:
-                token_type_ids = encoded_inputs['segment_ids']
-
-            examples.append((input_ids, token_type_ids))
+        for texts in data:
+            encoded_inputs = self._convert_text_to_input(tokenizer, texts, max_seq_len, split_char)
+            example = []
+            for inp in encoded_inputs:
+                input_ids = inp['input_ids']
+                if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
+                    token_type_ids = inp['token_type_ids']
+                else:
+                    token_type_ids = inp['segment_ids']
+                example.extend((input_ids, token_type_ids))
+            examples.append(example)
 
         # Seperates data into some batches.
         one_batch = []
@@ -488,6 +499,9 @@ class TransformerModule(RunModule, TextServing):
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(
                 input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+        elif self.task == 'text-matching':
+            predictions, avg_loss, metric = self(query_input_ids=batch[0], query_token_type_ids=batch[1], \
+                title_input_ids=batch[2], title_token_type_ids=batch[3], labels=batch[4])
         self.metric.reset()
         return {'loss': avg_loss, 'metrics': metric}
 
@@ -506,7 +520,9 @@ class TransformerModule(RunModule, TextServing):
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(
                 input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
-        self.metric.reset()
+        elif self.task == 'text-matching':
+            predictions, avg_loss, metric = self(query_input_ids=batch[0], query_token_type_ids=batch[1], \
+                title_input_ids=batch[2], title_token_type_ids=batch[3], labels=batch[4])
         return {'metrics': metric}
 
     def get_embedding(self, data: List[List[str]], use_gpu=False):
@@ -549,7 +565,8 @@ class TransformerModule(RunModule, TextServing):
             raise RuntimeError(f'Unknown task {self.task}, current tasks supported:\n'
                                '1. seq-cls: sequence classification;\n'
                                '2. token-cls: sequence labeling;\n'
-                               '3. None: embedding')
+                               '3. text-matching: text matching;\n'
+                               '4. None: embedding')
 
         paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
 
@@ -557,26 +574,41 @@ class TransformerModule(RunModule, TextServing):
         results = []
         self.eval()
         for batch in batches:
-            input_ids, segment_ids = batch
-            input_ids = paddle.to_tensor(input_ids)
-            segment_ids = paddle.to_tensor(segment_ids)
-
-            if self.task == 'seq-cls':
-                probs = self(input_ids, segment_ids)
+            if self.task == 'text-matching':
+                query_input_ids, query_segment_ids, title_input_ids, title_segment_ids = batch
+                query_input_ids = paddle.to_tensor(query_input_ids)
+                query_segment_ids = paddle.to_tensor(query_segment_ids)
+                title_input_ids = paddle.to_tensor(title_input_ids)
+                title_segment_ids = paddle.to_tensor(title_segment_ids)
+                probs = self(query_input_ids=query_input_ids, query_token_type_ids=query_segment_ids, \
+                    title_input_ids=title_input_ids, title_token_type_ids=title_segment_ids)
                 idx = paddle.argmax(probs, axis=1).numpy()
                 idx = idx.tolist()
                 labels = [self.label_map[i] for i in idx]
                 results.extend(labels)
-            elif self.task == 'token-cls':
-                probs = self(input_ids, segment_ids)
-                batch_ids = paddle.argmax(probs, axis=2).numpy()  # (batch_size, max_seq_len)
-                batch_ids = batch_ids.tolist()
-                token_labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
-                results.extend(token_labels)
-            elif self.task == None:
-                sequence_output, pooled_output = self(input_ids, segment_ids)
-                results.append([pooled_output.squeeze(0).numpy().tolist(), sequence_output.squeeze(0).numpy().tolist()])
-
+            else:
+                input_ids, segment_ids = batch
+                input_ids = paddle.to_tensor(input_ids)
+                segment_ids = paddle.to_tensor(segment_ids)
+
+                if self.task == 'seq-cls':
+                    probs = self(input_ids, segment_ids)
+                    idx = paddle.argmax(probs, axis=1).numpy()
+                    idx = idx.tolist()
+                    labels = [self.label_map[i] for i in idx]
+                    results.extend(labels)
+                elif self.task == 'token-cls':
+                    probs = self(input_ids, segment_ids)
+                    batch_ids = paddle.argmax(probs, axis=2).numpy()  # (batch_size, max_seq_len)
+                    batch_ids = batch_ids.tolist()
+                    token_labels = [[self.label_map[i] for i in token_ids] for token_ids in batch_ids]
+                    results.extend(token_labels)
+                elif self.task is None:
+                    sequence_output, pooled_output = self(input_ids, segment_ids)
+                    results.append([
+                        pooled_output.squeeze(0).numpy().tolist(),
+                        sequence_output.squeeze(0).numpy().tolist()
+                    ])
        return results
-- 
GitLab
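
Note on the text-matching head added in ernie_tiny/module.py: it follows the Sentence-BERT style of pair composition, mean-pooling each tower's token embeddings over non-pad positions and classifying the concatenation [u, v, |u - v|]. Below is a minimal standalone sketch of that pooling and composition; the random tensors stand in for the two ERNIE-tiny encoder outputs, and the batch size, sequence length, hidden size, pad id, and token ids are all illustrative assumptions, not values taken from the patch.

    import paddle
    import paddle.nn.functional as F

    paddle.seed(0)
    batch, seq_len, hidden, pad_id = 2, 8, 768, 0  # illustrative sizes; pad id assumed to be 0

    # Stand-ins for the query/title token embeddings returned by the shared encoder.
    query_emb = paddle.randn([batch, seq_len, hidden])
    title_emb = paddle.randn([batch, seq_len, hidden])
    query_ids = paddle.to_tensor([[1, 5, 9, 2, 0, 0, 0, 0], [1, 7, 2, 0, 0, 0, 0, 0]])
    title_ids = paddle.to_tensor([[1, 5, 8, 3, 2, 0, 0, 0], [1, 4, 2, 0, 0, 0, 0, 0]])

    def mean_pool(token_emb, input_ids):
        # Zero out pad positions, then average over the real tokens only.
        mask = paddle.unsqueeze((input_ids != pad_id).astype(token_emb.dtype), axis=2)
        summed = paddle.sum(token_emb * mask, axis=1)  # (batch, hidden)
        count = paddle.sum(mask, axis=1)               # (batch, 1), broadcasts below
        return summed / count

    u = mean_pool(query_emb, query_ids)
    v = mean_pool(title_emb, title_ids)
    features = paddle.concat([u, v, paddle.abs(u - v)], axis=-1)  # (batch, 3 * hidden)

    classifier = paddle.nn.Linear(3 * hidden, 2)
    probs = F.softmax(classifier(features), axis=1)
    print(probs.shape)  # [2, 2]

The |u - v| feature gives the classifier an explicit element-wise distance between the two sentence vectors, which is why the Linear layer in the patch takes hidden_size * 3 inputs.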
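
For reference, _batchify in nlp_module.py emits 2-field examples (input_ids, token_type_ids) for single-tower tasks and 4-field examples (query and title ids/types) for text-matching, and _parse_batch regroups them column-wise. A toy illustration of that regrouping, using hypothetical token ids rather than real tokenizer output:

    # Mirrors _parse_batch: column-wise regrouping of the per-example tuples.
    def parse(batch, text_matching):
        if not text_matching:
            return [e[0] for e in batch], [e[1] for e in batch]
        return ([e[0] for e in batch], [e[1] for e in batch],
                [e[2] for e in batch], [e[3] for e in batch])

    single_tower_batch = [([1, 5, 9, 2], [0, 0, 0, 0])]                    # (input_ids, token_type_ids)
    matching_batch = [([1, 5, 9, 2], [0, 0, 0, 0], [1, 7, 3, 2], [0, 0, 0, 0])]

    print(parse(single_tower_batch, text_matching=False))  # 2 columns
    print(parse(matching_batch, text_matching=True))       # 4 columns; a label would ride along as a 5th field in training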