Unverified commit 21fc5cb5, authored by KP, committed by GitHub

Add dataset, module task, and demo of text-matching (#1281)

* Add dataset, module task, and demo of text-matching
Parent ab3a163e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddlehub as hub
if __name__ == '__main__':
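    # Each element is a [text_a, text_b] pair; the fine-tuned model predicts whether
    # the two sentences express the same meaning.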
    data = [
        ['这个表情叫什么', '这个猫的表情叫什么'],
        ['什么是智能手环', '智能手环有什么用'],
        ['介绍几本好看的都市异能小说,要完结的!', '求一本好看点的都市异能小说,要完结的'],
        ['一只蜜蜂落在日历上(打一成语)', '一只蜜蜂停在日历上(猜一成语)'],
        ['一盒香烟不拆开能存放多久?', '一条没拆封的香烟能存放多久。'],
    ]
    label_map = {0: 'dissimilar', 1: 'similar'}
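    # Maps the model's output index to a readable tag. In LCQMC, label 1 marks a pair
    # of questions with the same intent, so index 1 is rendered as 'similar'.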
    model = hub.Module(
        name='ernie_tiny',
        version='2.0.2',
        task='text-matching',
        load_checkpoint='./checkpoint/best_model/model.pdparams',
        label_map=label_map)
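    # predict returns one mapped label per input pair; max_seq_len=50 covers the short
    # questions above, and longer inputs are truncated to that length.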
    results = model.predict(data, max_seq_len=50, batch_size=1, use_gpu=True)
    for idx, texts in enumerate(data):
        print('TextA: {}\tTextB: {}\tLabel: {}'.format(texts[0], texts[1], results[idx]))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast

import paddle
import paddlehub as hub
from paddlehub.datasets import LCQMC

parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=10, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning, input should be True or False")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=128, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to model checkpoint")
parser.add_argument("--save_interval", type=int, default=2, help="Save checkpoint every n epoch.")
args = parser.parse_args()
if __name__ == '__main__':
    model = hub.Module(name='ernie_tiny', version='2.0.2', task='text-matching')
    tokenizer = model.get_tokenizer()
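    # LCQMC ships train/dev/test splits. Each record keeps the two sentences tokenized
    # separately, because the text-matching head encodes them independently.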
    train_dataset = LCQMC(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='train')
    dev_dataset = LCQMC(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='dev')
    test_dataset = LCQMC(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='test')
    optimizer = paddle.optimizer.AdamW(learning_rate=args.learning_rate, parameters=model.parameters())
    trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
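    # With an eval_dataset, the trainer also keeps the best-scoring checkpoint under
    # <checkpoint_dir>/best_model, which is the file the predict demo loads.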
    trainer.train(
        train_dataset,
        epochs=args.num_epoch,
        batch_size=args.batch_size,
        eval_dataset=dev_dataset,
        save_interval=args.save_interval,
    )
    trainer.evaluate(test_dataset, batch_size=args.batch_size)
@@ -29,7 +29,7 @@ from paddlehub.utils.log import logger
 @moduleinfo(
     name="ernie_tiny",
-    version="2.0.1",
+    version="2.0.2",
     summary="Baidu's ERNIE-tiny, Enhanced Representation through kNowledge IntEgration, tiny version, max_seq_len=512",
     author="paddlepaddle",
     author_email="",
@@ -71,6 +71,12 @@ class ErnieTiny(nn.Layer):
             self.metric = ChunkEvaluator(
                 label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
             )
+        elif task == 'text-matching':
+            self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-tiny', **kwargs)
+            self.dropout = paddle.nn.Dropout(0.1)
+            self.classifier = paddle.nn.Linear(self.model.config['hidden_size'] * 3, 2)
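+            # hidden_size * 3 because the classifier consumes the concatenation
+            # [query_mean, title_mean, |query_mean - title_mean|] built in forward().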
+            self.criterion = paddle.nn.loss.CrossEntropyLoss()
+            self.metric = paddle.metric.Accuracy()
         elif task is None:
             self.model = ErnieModel.from_pretrained(pretrained_model_name_or_path='ernie-tiny', **kwargs)
         else:
@@ -84,8 +90,28 @@ class ErnieTiny(nn.Layer):
         self.set_state_dict(state_dict)
         logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
+    def forward(self,
+                input_ids=None,
+                token_type_ids=None,
+                position_ids=None,
+                attention_mask=None,
+                query_input_ids=None,
+                query_token_type_ids=None,
+                query_position_ids=None,
+                query_attention_mask=None,
+                title_input_ids=None,
+                title_token_type_ids=None,
+                title_position_ids=None,
+                title_attention_mask=None,
+                seq_lengths=None,
+                labels=None):
-        result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+        if self.task != 'text-matching':
+            result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
+        else:
+            query_result = self.model(query_input_ids, query_token_type_ids, query_position_ids, query_attention_mask)
+            title_result = self.model(title_input_ids, title_token_type_ids, title_position_ids, title_attention_mask)
         if self.task == 'seq-cls':
             logits = result
             probs = F.softmax(logits, axis=1)
@@ -108,6 +134,35 @@ class ErnieTiny(nn.Layer):
                 _, _, f1_score = map(float, self.metric.accumulate())
                 return token_level_probs, loss, {'f1_score': f1_score}
             return token_level_probs
+        elif self.task == 'text-matching':
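+            # Masked mean pooling: zero out embeddings at padded positions, sum over the
+            # sequence axis, and divide by the count of real tokens to get one vector per sentence.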
+            query_token_embedding, _ = query_result
+            query_token_embedding = self.dropout(query_token_embedding)
+            query_attention_mask = paddle.unsqueeze(
+                (query_input_ids != self.model.pad_token_id).astype(self.model.pooler.dense.weight.dtype), axis=2)
+            query_token_embedding = query_token_embedding * query_attention_mask
+            query_sum_embedding = paddle.sum(query_token_embedding, axis=1)
+            query_sum_mask = paddle.sum(query_attention_mask, axis=1)
+            query_mean = query_sum_embedding / query_sum_mask
+            title_token_embedding, _ = title_result
+            title_token_embedding = self.dropout(title_token_embedding)
+            title_attention_mask = paddle.unsqueeze(
+                (title_input_ids != self.model.pad_token_id).astype(self.model.pooler.dense.weight.dtype), axis=2)
+            title_token_embedding = title_token_embedding * title_attention_mask
+            title_sum_embedding = paddle.sum(title_token_embedding, axis=1)
+            title_sum_mask = paddle.sum(title_attention_mask, axis=1)
+            title_mean = title_sum_embedding / title_sum_mask
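+            # Sentence-BERT-style pairwise feature: concatenate the two sentence vectors
+            # with their element-wise absolute difference before the binary classifier.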
+            sub = paddle.abs(paddle.subtract(query_mean, title_mean))
+            projection = paddle.concat([query_mean, title_mean, sub], axis=-1)
+            logits = self.classifier(projection)
+            probs = F.softmax(logits)
+            if labels is not None:
+                loss = self.criterion(logits, labels)
+                correct = self.metric.compute(probs, labels)
+                acc = self.metric.update(correct)
+                return probs, loss, {'acc': acc}
+            return probs
         else:
             sequence_output, pooled_output = result
             return sequence_output, pooled_output
@@ -17,3 +17,4 @@ from paddlehub.datasets.flowers import Flowers
 from paddlehub.datasets.minicoco import MiniCOCO
 from paddlehub.datasets.chnsenticorp import ChnSentiCorp
 from paddlehub.datasets.msra_ner import MSRA_NER
+from paddlehub.datasets.lcqmc import LCQMC
@@ -469,3 +469,120 @@ class SeqLabelingDataset(BaseNLPDataset, paddle.io.Dataset):
     def __len__(self):
         return len(self.records)
+
+class TextMatchingDataset(BaseNLPDataset, paddle.io.Dataset):
+    """
+    The dataset class which fits all text-matching datasets.
+    """
+    def __init__(self,
+                 base_path: str,
+                 tokenizer: PretrainedTokenizer,
+                 max_seq_len: int = 128,
+                 mode: str = "train",
+                 data_file: str = None,
+                 label_file: str = None,
+                 label_list: list = None,
+                 is_file_with_header: bool = False):
"""
Ags:
base_path (:obj:`str`): The directory to the whole dataset.
tokenizer (:obj:`PretrainedTokenizer`):
It tokenizes the text and encodes the data as model needed.
max_seq_len (:obj:`int`, `optional`, defaults to :128):
If set to a number, will limit the total sequence returned so that it has a maximum length.
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train, test or dev).
data_file(:obj:`str`, `optional`, defaults to :obj:`None`):
The data file name, which is relative to the base_path.
label_file(:obj:`str`, `optional`, defaults to :obj:`None`):
The label file name, which is relative to the base_path.
It is all labels of the dataset, one line one label.
label_list(:obj:`List[str]`, `optional`, defaults to :obj:`None`):
The list of all labels of the dataset
is_file_with_header(:obj:bool, `optional`, default to :obj: False) :
Whether or not the file is with the header introduction.
"""
+        super(TextMatchingDataset, self).__init__(
+            base_path=base_path,
+            tokenizer=tokenizer,
+            max_seq_len=max_seq_len,
+            mode=mode,
+            data_file=data_file,
+            label_file=label_file,
+            label_list=label_list)
+        self.examples = self._read_file(self.data_file, is_file_with_header)
+        self.records = self._convert_examples_to_records(self.examples)
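+        # File reading and tokenization happen once at construction time; __getitem__
+        # below only indexes the cached records.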
+
+    def _read_file(self, input_file, is_file_with_header: bool = False) -> List[InputExample]:
+        """
+        Reads a tab-separated values file.
+
+        Args:
+            input_file (:obj:`str`): The file to be read.
+            is_file_with_header (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not the file starts with a header line.
+
+        Returns:
+            examples (:obj:`List[InputExample]`): All the input data.
+        """
+        if not os.path.exists(input_file):
+            raise RuntimeError("The file {} is not found.".format(input_file))
+        else:
+            with io.open(input_file, "r", encoding="UTF-8") as f:
+                reader = csv.reader(f, delimiter="\t", quotechar=None)
+                examples = []
+                seq_id = 0
+                header = next(reader) if is_file_with_header else None
+                for line in reader:
+                    example = InputExample(guid=seq_id, text_a=line[0], text_b=line[1], label=line[2])
+                    seq_id += 1
+                    examples.append(example)
+                return examples
+
+    def _convert_examples_to_records(self, examples: List[InputExample]) -> List[dict]:
+        """
+        Converts all examples to records which the model needs.
+
+        Args:
+            examples (:obj:`List[InputExample]`): All data examples returned by _read_file.
+
+        Returns:
+            records (:obj:`List[dict]`): All records which the model needs.
+        """
+        records = []
+        for example in examples:
+            if isinstance(self.tokenizer, PretrainedTokenizer):
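+                # text_a and text_b are encoded separately (text_pair is not used):
+                # the matching model runs the encoder over each sentence on its own.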
+                record_a = self.tokenizer(text=example.text_a, max_seq_len=self.max_seq_len, \
+                    pad_to_max_seq_len=True, return_length=True)
+                record_b = self.tokenizer(text=example.text_b, max_seq_len=self.max_seq_len, \
+                    pad_to_max_seq_len=True, return_length=True)
+                record = {'text_a': record_a, 'text_b': record_b}
+            else:
+                raise RuntimeError(
+                    "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer".format(
+                        type(self.tokenizer)))
+            if not record:
+                logger.info(
+                    "The text %s has been dropped as it has no words in the vocab after tokenization." % example.text_a)
+                continue
+            if example.label:
+                record['label'] = self.label_map[example.label]
+            records.append(record)
+        return records
+
+    def __getitem__(self, idx):
+        record = self.records[idx]
+        if isinstance(self.tokenizer, PretrainedTokenizer):
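+            # Return plain numpy arrays so Paddle's default collate function can stack
+            # them into batches.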
+            query_input_ids = np.array(record['text_a']['input_ids'])
+            query_token_type_ids = np.array(record['text_a']['token_type_ids'])
+            title_input_ids = np.array(record['text_b']['input_ids'])
+            title_token_type_ids = np.array(record['text_b']['token_type_ids'])
+            if 'label' in record.keys():
+                return query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids, \
+                    np.array(record['label'], dtype=np.int64)
+            else:
+                return query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids
+        else:
+            raise RuntimeError(
+                "Unknown type of self.tokenizer: {}, it must be an instance of PretrainedTokenizer".format(
+                    type(self.tokenizer)))
+
+    def __len__(self):
+        return len(self.records)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union

from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
from paddlehub.datasets.base_nlp_dataset import TextMatchingDataset
from paddlehub.text.bert_tokenizer import BertTokenizer
from paddlehub.text.tokenizer import CustomTokenizer
@download_data(url="https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz")
class LCQMC(TextMatchingDataset):
    label_list = ['0', '1']
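    # LCQMC (Large-scale Chinese Question Matching Corpus): question pairs labeled
    # '1' when the two questions share the same intent and '0' otherwise.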
    def __init__(
            self,
            tokenizer: Union[BertTokenizer, CustomTokenizer],
            max_seq_len: int = 128,
            mode: str = 'train',
    ):
        base_path = os.path.join(DATA_HOME, "lcqmc")

        if mode == 'train':
            data_file = 'train.tsv'
        elif mode == 'test':
            data_file = 'test.tsv'
        else:
            data_file = 'dev.tsv'

        super().__init__(
            base_path=base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_file=None,
            label_list=self.label_list,
            is_file_with_header=True,
        )
if __name__ == "__main__":
import paddlehub as hub
model = hub.Module(name='ernie_tiny')
tokenizer = model.get_tokenizer()
ds = LCQMC(tokenizer=tokenizer, max_seq_len=128, mode='dev')
\ No newline at end of file
@@ -404,40 +404,42 @@ class TransformerModule(RunModule, TextServing):
     _tasks_supported = [
         'seq-cls',
         'token-cls',
+        'text-matching',
     ]
 
-    def _convert_text_to_input(self, tokenizer, text: List[str], max_seq_len: int, split_char: str):
+    def _convert_text_to_input(self, tokenizer, texts: List[str], max_seq_len: int, split_char: str):
         pad_to_max_seq_len = False if self.task is None else True
         if self.task == 'token-cls':  # Extra processing of token-cls task
-            tokens = text[0].split(split_char)
-            text[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
+            tokens = texts[0].split(split_char)
+            texts[0], _ = reseg_token_label(tokenizer=tokenizer, tokens=tokens)
             is_split_into_words = True
         else:
             is_split_into_words = False
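+        # For text matching, the two input sentences are encoded as two independent
+        # sequences (text_pair stays None) instead of being packed into one sequence pair.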
-        if len(text) == 1:
-            if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
-                encoded_inputs = tokenizer.encode(
-                    text[0], text_pair=None, max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
-            else:
-                encoded_inputs = tokenizer(
-                    text=text[0],
-                    max_seq_len=max_seq_len,
-                    pad_to_max_seq_len=True,
-                    is_split_into_words=is_split_into_words,
-                    return_length=True)
-        elif len(text) == 2:
-            if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
-                encoded_inputs = tokenizer.encode(
-                    text[0], text_pair=text[1], max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len)
-            else:
-                encoded_inputs = tokenizer(
-                    text=text[0],
-                    text_pair=text[1],
-                    max_seq_len=max_seq_len,
-                    pad_to_max_seq_len=True,
-                    is_split_into_words=is_split_into_words,
-                    return_length=True)
-        else:
-            raise RuntimeError(
-                'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text))
+        encoded_inputs = []
+        if self.task == 'text-matching':
+            if len(texts) != 2:
+                raise RuntimeError(
+                    'The input texts must have two sequences, but got %d. Please check your inputs.' % len(texts))
+            encoded_inputs.append(tokenizer(text=texts[0], text_pair=None, max_seq_len=max_seq_len, \
+                pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+            encoded_inputs.append(tokenizer(text=texts[1], text_pair=None, max_seq_len=max_seq_len, \
+                pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+        else:
+            if len(texts) == 1:
+                if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                    encoded_inputs.append(tokenizer.encode(texts[0], text_pair=None, \
+                        max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len))
+                else:
+                    encoded_inputs.append(tokenizer(text=texts[0], max_seq_len=max_seq_len, \
+                        pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+            elif len(texts) == 2:
+                if Version(paddlenlp.__version__) <= Version('2.0.0rc2'):
+                    encoded_inputs.append(tokenizer.encode(texts[0], text_pair=texts[1], \
+                        max_seq_len=max_seq_len, pad_to_max_seq_len=pad_to_max_seq_len))
+                else:
+                    encoded_inputs.append(tokenizer(text=texts[0], text_pair=texts[1], max_seq_len=max_seq_len, \
+                        pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
+            else:
+                raise RuntimeError(
+                    'The input texts must have one or two sequences, but got %d. Please check your inputs.' % len(texts))
@@ -445,22 +447,31 @@
     def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
         def _parse_batch(batch):
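+            # A text-matching example carries four lists (query/title input_ids and
+            # segment_ids); every other task carries the usual two.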
-            input_ids = [entry[0] for entry in batch]
-            segment_ids = [entry[1] for entry in batch]
-            return input_ids, segment_ids
+            if self.task != 'text-matching':
+                input_ids = [entry[0] for entry in batch]
+                segment_ids = [entry[1] for entry in batch]
+                return input_ids, segment_ids
+            else:
+                query_input_ids = [entry[0] for entry in batch]
+                query_segment_ids = [entry[1] for entry in batch]
+                title_input_ids = [entry[2] for entry in batch]
+                title_segment_ids = [entry[3] for entry in batch]
+                return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids
         tokenizer = self.get_tokenizer()
         examples = []
-        for text in data:
-            encoded_inputs = self._convert_text_to_input(tokenizer, text, max_seq_len, split_char)
-            input_ids = encoded_inputs['input_ids']
-            if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
-                token_type_ids = encoded_inputs['token_type_ids']
-            else:
-                token_type_ids = encoded_inputs['segment_ids']
-            examples.append((input_ids, token_type_ids))
+        for texts in data:
+            encoded_inputs = self._convert_text_to_input(tokenizer, texts, max_seq_len, split_char)
+            example = []
+            for inp in encoded_inputs:
+                input_ids = inp['input_ids']
+                if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
+                    token_type_ids = inp['token_type_ids']
+                else:
+                    token_type_ids = inp['segment_ids']
+                example.extend((input_ids, token_type_ids))
+            examples.append(example)
         # Separates data into some batches.
         one_batch = []
@@ -488,6 +499,9 @@
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(
                 input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
+        elif self.task == 'text-matching':
+            predictions, avg_loss, metric = self(query_input_ids=batch[0], query_token_type_ids=batch[1], \
+                title_input_ids=batch[2], title_token_type_ids=batch[3], labels=batch[4])
         self.metric.reset()
         return {'loss': avg_loss, 'metrics': metric}
@@ -506,7 +520,9 @@
         elif self.task == 'token-cls':
             predictions, avg_loss, metric = self(
                 input_ids=batch[0], token_type_ids=batch[1], seq_lengths=batch[2], labels=batch[3])
-        self.metric.reset()
+        elif self.task == 'text-matching':
+            predictions, avg_loss, metric = self(query_input_ids=batch[0], query_token_type_ids=batch[1], \
+                title_input_ids=batch[2], title_token_type_ids=batch[3], labels=batch[4])
         return {'metrics': metric}
     def get_embedding(self, data: List[List[str]], use_gpu=False):
@@ -549,7 +565,8 @@
             raise RuntimeError(f'Unknown task {self.task}, current tasks supported:\n'
                                '1. seq-cls: sequence classification;\n'
                                '2. token-cls: sequence labeling;\n'
-                               '3. None: embedding')
+                               '3. text-matching: text matching;\n'
+                               '4. None: embedding')
         paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')
@@ -557,6 +574,19 @@
         results = []
         self.eval()
         for batch in batches:
+            if self.task == 'text-matching':
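+                # A matching batch holds two encoded sequences; run the pairwise head and
+                # map each argmax index to a readable label via label_map.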
+                query_input_ids, query_segment_ids, title_input_ids, title_segment_ids = batch
+                query_input_ids = paddle.to_tensor(query_input_ids)
+                query_segment_ids = paddle.to_tensor(query_segment_ids)
+                title_input_ids = paddle.to_tensor(title_input_ids)
+                title_segment_ids = paddle.to_tensor(title_segment_ids)
+                probs = self(query_input_ids=query_input_ids, query_token_type_ids=query_segment_ids, \
+                    title_input_ids=title_input_ids, title_token_type_ids=title_segment_ids)
+                idx = paddle.argmax(probs, axis=1).numpy()
+                idx = idx.tolist()
+                labels = [self.label_map[i] for i in idx]
+                results.extend(labels)
+            else:
-            input_ids, segment_ids = batch
-            input_ids = paddle.to_tensor(input_ids)
-            segment_ids = paddle.to_tensor(segment_ids)
+                input_ids, segment_ids = batch
+                input_ids = paddle.to_tensor(input_ids)
+                segment_ids = paddle.to_tensor(segment_ids)
@@ -575,8 +605,10 @@
                 results.extend(token_labels)
             elif self.task == None:
                 sequence_output, pooled_output = self(input_ids, segment_ids)
-                results.append([pooled_output.squeeze(0).numpy().tolist(), sequence_output.squeeze(0).numpy().tolist()])
+                results.append([
+                    pooled_output.squeeze(0).numpy().tolist(),
+                    sequence_output.squeeze(0).numpy().tolist()
+                ])
         return results