Unverified commit 84ffe966, authored by S Steffy-zxf, committed by GitHub

add model api for paddlenlp (#4998)

* add model api

* update codes

* update codes

* update codes
Parent 62eedcf2
......@@ -3,7 +3,7 @@ from functools import partial
from paddle.io import DistributedBatchSampler, DataLoader
from paddle.static import InputSpec
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer, ErnieForSequenceClassification, ErnieTokenizer
import numpy as np
import paddle
import paddlenlp
......@@ -18,10 +18,10 @@ def convert_example(example, tokenizer, max_seq_length=128):
return input_ids, segment_ids, label
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')
train_ds, = paddlenlp.datasets.ChnSentiCorp.get_datasets(['train'])
paddle.set_device('gpu')
train_ds = paddlenlp.datasets.ChnSentiCorp.get_datasets(['train'])
label_list = train_ds.get_labels()
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
trans_func = partial(convert_example, tokenizer=tokenizer)
train_ds = train_ds.apply(trans_func)
batchify_fn = lambda samples, fn=Tuple(
......@@ -35,6 +35,8 @@ train_loader = DataLoader(
collate_fn=batchify_fn,
return_list=True)
model = paddlenlp.models.Ernie(
'ernie-1.0', task='seq-cls', num_classes=len(label_list))
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
optimizer = paddle.optimizer.AdamW(
......
......@@ -3,7 +3,7 @@ from functools import partial
from paddle.io import DistributedBatchSampler, DataLoader
from paddle.static import InputSpec
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from paddlenlp.transformers import ErnieTokenizer
import numpy as np
import paddle
import paddlenlp
......@@ -18,11 +18,11 @@ def convert_example(example, tokenizer, max_seq_length=128):
return input_ids, segment_ids, label
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')
paddle.set_device('gpu')
train_ds, dev_ds = paddlenlp.datasets.ChnSentiCorp.get_datasets(
['train', 'dev'])
label_list = train_ds.get_labels()
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
trans_func = partial(convert_example, tokenizer=tokenizer)
train_ds = train_ds.apply(trans_func)
dev_ds = dev_ds.apply(trans_func)
......@@ -43,6 +43,8 @@ dev_loader = DataLoader(
collate_fn=batchify_fn,
return_list=True)
model = paddlenlp.models.Ernie(
'ernie-1.0', task='seq-cls', num_classes=len(label_list))
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
optimizer = paddle.optimizer.AdamW(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class RunConfig(object):
"""
Running Config Setting.
Args:
save_dir (obj:`str`): The directory to save checkpoint during training.
use_gpu (obj:`bool`, optinal, defaults to obj:`False`): Whether use GPU for training, input should be True or False.
lr (obj:`float`, optinal, defaults to 5e-4): Learning rate used to train.
batch_size (obj:`int`, optinal, defaults to 1): Total examples' number of a batch.
epochs (obj:`int`, optinal, defaults to 1): Number of epoches for training.
log_freq (obj:`int`, optinal, defaults to 10): The frequency, in number of steps, the training logs are printed.
eval_freq (obj:`int`, optinal, defaults to 1): The frequency, in number of epochs, an evalutation is performed.
save_freq (obj:`int`, optinal, defaults to 1): The frequency, in number of epochs, to save checkpoints.
"""
def __init__(self,
save_dir,
use_gpu=0,
lr=5e-4,
batch_size=1,
epochs=1,
log_freq=10,
eval_freq=1,
save_freq=1):
self.save_dir = save_dir
self.use_gpu = use_gpu
self.batch_size = batch_size
self.lr = lr
self.epochs = epochs
self.log_freq = log_freq
self.eval_freq = eval_freq
self.save_freq = save_freq
self._place = paddle.set_device(
"gpu") if self.use_gpu else paddle.set_device("cpu")
def get_save_dir(self):
return self.save_dir
def get_running_place(self):
return self._place
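# A minimal usage sketch of RunConfig (the save_dir value below is only an assumed example path):
if __name__ == "__main__":
    config = RunConfig(save_dir='./checkpoints', use_gpu=False, lr=5e-4, batch_size=64, epochs=3)
    # The constructor already calls paddle.set_device according to use_gpu.
    print(config.get_save_dir())
    print(config.get_running_place())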
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import csv
import io
import os
import paddle
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = {}
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n").split("\t")[0]
vocab[token] = index
return vocab
def convert_tokens_to_ids(tokens, vocab):
""" Converts a token string (or a sequence of tokens) in a single integer id
(or a sequence of ids), using the vocabulary.
"""
ids = []
for token in tokens:
wid = vocab.get(token, None)
        if wid is not None:  # skip only out-of-vocabulary tokens; an id of 0 is still valid
ids.append(wid)
return ids
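# A minimal sketch of the two helpers above (the vocab path is the assumed one used in __main__ below):
# vocab = load_vocab('./senta_data/word_dict.txt')
# ids = convert_tokens_to_ids(['这部', '电影', '不错'], vocab)  # out-of-vocabulary tokens are skipped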
class ChnSentiCorp(paddle.io.Dataset):
"""
    ChnSentiCorp is a dataset for Chinese sentiment classification,
    which was published by Tan Songbo at the ICT of the Chinese Academy of Sciences.
    Args:
        base_path (:obj:`str`): The dataset directory, which contains train.tsv, dev.tsv and test.tsv.
        vocab (:obj:`dict`): The vocabulary used to convert tokens to ids.
        mode (:obj:`str`, `optional`, defaults to `train`):
            It identifies the dataset split (train, test or dev).
"""
def __init__(self, base_path, vocab, mode='train'):
if mode == 'train':
data_file = 'train.tsv'
elif mode == 'test':
data_file = 'test.tsv'
else:
data_file = 'dev.tsv'
self.data_file = os.path.join(base_path, data_file)
self.label_list = ["0", "1"]
self.label_map = {
item: index
for index, item in enumerate(self.label_list)
}
self.vocab = vocab
self.raw_examples = self._read_file(self.data_file)
def _read_file(self, input_file):
"""
        Reads a tab-separated values (tsv) file.
Args:
input_file (:obj:`str`) : The file to be read.
Returns:
examples (:obj:`list`): All the input data.
"""
if not os.path.exists(input_file):
raise RuntimeError("The file {} is not found.".format(input_file))
else:
Example = namedtuple('Example', ['text', 'label', 'seq_len'])
with io.open(input_file, "r", encoding="UTF-8") as f:
reader = csv.reader(f, delimiter="\t", quotechar=None)
examples = []
header = next(reader)
for line in reader:
tokens = line[0].strip().split(' ')
ids = convert_tokens_to_ids(tokens, self.vocab)
example = Example(
text=ids,
label=self.label_map[line[1]],
seq_len=len(ids))
examples.append(example)
return examples
def __getitem__(self, idx):
return self.raw_examples[idx]
def __len__(self):
return len(self.raw_examples)
if __name__ == "__main__":
vocab = load_vocab('./senta_data/word_dict.txt')
train_dataset = ChnSentiCorp(
base_path='./senta_data', vocab=vocab, mode='train')
dev_dataset = ChnSentiCorp(
base_path='./senta_data', vocab=vocab, mode='dev')
test_dataset = ChnSentiCorp(
base_path='./senta_data', vocab=vocab, mode='test')
index = 0
for example in train_dataset:
print("%s \t %s \t %s" % (example.text, example.label, example.seq_len))
index += 1
if index > 3:
break
......@@ -11,187 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import argparse
import os
import random
import time
import jieba
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlenlp as ppnlp
from config import RunConfig
from data import ChnSentiCorp, convert_tokens_to_ids, load_vocab
from model import BoWModel, LSTMModel, GRUModel, RNNModel, BiLSTMAttentionModel, TextCNNModel
from model import SelfAttention, SelfInteractiveAttention
from utils import load_vocab, generate_batch, preprocess_prediction_data
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=eval, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./word_dict.txt", help="The path to vocabulary.")
parser.add_argument('--network_name', type=str, default="bilstm_attn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
parser.add_argument("--params_path", type=str, default=None, required=True, help="The path of model parameter to be loaded.")
parser.add_argument('--network_name', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
parser.add_argument("--params_path", type=str, default='./chekpoints/final.pdparams', help="The path of model parameter to be loaded.")
args = parser.parse_args()
# yapf: enable
def pad_texts_to_max_seq_len(texts, max_seq_len, pad_token_id=0):
"""
    Pads each text to the max sequence length if it is shorter than that length;
    otherwise, the text is truncated.
    Args:
        texts(obj:`list`): Texts, each of which contains a sequence of word ids.
        max_seq_len(obj:`int`): Max sequence length.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
"""
for index, text in enumerate(texts):
seq_len = len(text)
if seq_len < max_seq_len:
padded_tokens = [pad_token_id for _ in range(max_seq_len - seq_len)]
new_text = text + padded_tokens
texts[index] = new_text
elif seq_len > max_seq_len:
new_text = text[:max_seq_len]
texts[index] = new_text
def generate_batch(batch, pad_token_id=0, return_label=True):
"""
Generates a batch whose text will be padded to the max sequence length in the batch.
Args:
batch(obj:`List[Example]`) : One batch, which contains texts, labels and the true sequence lengths.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
batch(:obj:`Tuple[list]`): The batch data which contains texts, seq_lens and labels.
"""
seq_lens = [entry.seq_len for entry in batch]
batch_max_seq_len = max(seq_lens)
texts = [entry.text for entry in batch]
pad_texts_to_max_seq_len(texts, batch_max_seq_len, pad_token_id)
if return_label:
labels = [[entry.label] for entry in batch]
return texts, seq_lens, labels
else:
return texts, seq_lens
def create_model(vocab_size, num_labels, network_name='bilstm', pad_token_id=0):
"""
    Creates a model for text classification, such as BoW, LSTM/BiLSTM or GRU/BiGRU.
    Args:
        vocab_size(obj:`int`): The vocabulary size.
        num_labels(obj:`int`): The number of labels.
        network_name(obj:`str`, optional, defaults to `bilstm`): The network to use.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
model(obj:`paddle.nn.Layer`): A model.
"""
if network_name == 'bow':
network = BoWModel(
vocab_size, num_labels=num_labels, padding_idx=pad_token_id)
elif network_name == 'bilstm':
        # direction choice: forward, backward, bidirectional
network = LSTMModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'bigru':
        # direction choice: forward, backward, bidirectional
network = GRUModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'birnn':
        # direction choice: forward, backward, bidirectional
network = RNNModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'lstm':
        # direction choice: forward, backward, bidirectional
network = LSTMModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'gru':
        # direction choice: forward, backward, bidirectional
network = GRUModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'rnn':
        # direction choice: forward, backward, bidirectional
network = RNNModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'bilstm_attn':
lstm_hidden_size = 196
attention = SelfInteractiveAttention(hidden_size=2 * lstm_hidden_size)
network = BiLSTMAttentionModel(
attention_layer=attention,
vocab_size=vocab_size,
lstm_hidden_size=lstm_hidden_size,
num_labels=num_labels,
padding_idx=pad_token_id)
elif network_name == 'textcnn':
network = TextCNNModel(
vocab_size=vocab_size,
padding_idx=pad_token_id,
num_labels=num_labels)
else:
raise ValueError(
"Unknown network: %s, it must be one of bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn."
% network_name)
model = paddle.Model(network)
return model
def preprocess_prediction_data(data):
"""
    Processes the prediction data into the same format used for training.
    Args:
        data (obj:`List[str]`): The prediction data, each element of which is a tokenized text.
    Returns:
        examples (obj:`List[Example]`): The processed data, each element of which is an Example (namedtuple) object.
            An Example object contains `text` (word ids) and `seq_len` (sequence length).
"""
Example = namedtuple('Example', ['text', 'seq_len'])
examples = []
for text in data:
tokens = " ".join(jieba.cut(text)).split(' ')
ids = convert_tokens_to_ids(tokens, vocab)
example = Example(text=ids, seq_len=len(ids))
examples.append(example)
return examples
def predict(model, data, label_map, collate_fn, batch_size=1):
def predict(model, data, label_map, collate_fn, batch_size=1, pad_token_id=0):
"""
Predicts the data labels.
......@@ -203,6 +41,7 @@ def predict(model, data, label_map, collate_fn, batch_size=1):
collate_fn(obj: `callable`): function to generate mini-batch data by merging
the sample list.
        batch_size(obj:`int`, defaults to 1): The size of a mini-batch.
pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
        results(obj:`list`): All the predicted labels.
......@@ -221,14 +60,13 @@ def predict(model, data, label_map, collate_fn, batch_size=1):
batches.append(one_batch)
results = []
model.network.eval()
model.eval()
for batch in batches:
texts, seq_lens = collate_fn(
batch, pad_token_id=pad_token_id, return_label=False)
texts = paddle.to_tensor(texts)
seq_lens = paddle.to_tensor(seq_lens)
logits = model.network(texts, seq_lens)
probs = F.softmax(logits, axis=-1)
probs = model(texts, seq_lens)
idx = paddle.argmax(probs, axis=1).numpy()
idx = idx.tolist()
labels = [label_map[i] for i in idx]
......@@ -237,27 +75,20 @@ def predict(model, data, label_map, collate_fn, batch_size=1):
if __name__ == "__main__":
paddle.set_device("gpu") if args.use_gpu else paddle.set_device("cpu")
# Loads vocab.
vocab = load_vocab(args.vocab_path)
if '[PAD]' not in vocab:
pad_token_id = len(vocab)
vocab['[PAD]'] = pad_token_id
else:
pad_token_id = vocab['[PAD]']
label_map = {0: 'negative', 1: 'positive'}
paddle.set_device("gpu") if args.use_gpu else paddle.set_device("cpu")
    # Constructs the network.
model = create_model(
len(vocab),
num_labels=len(label_map),
network_name=args.network_name.lower(),
padding_idx=pad_token_id)
model = ppnlp.models.Senta(
network_name=args.network_name,
vocab_size=len(vocab),
num_classes=len(label_map))
# Loads model parameters.
model.load(args.params_path)
state_dict = paddle.load(args.params_path)
model.set_dict(state_dict)
print("Loaded parameters from %s" % args.params_path)
    # First pre-process the prediction data, then run prediction.
......@@ -266,7 +97,8 @@ if __name__ == "__main__":
'怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片',
'作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。',
]
examples = preprocess_prediction_data(data)
examples = preprocess_prediction_data(data, vocab)
results = predict(
model,
examples,
......@@ -275,4 +107,4 @@ if __name__ == "__main__":
collate_fn=generate_batch)
for idx, text in enumerate(data):
print('Data: {} \t Lable: {}'.format(text, results[idx]))
print('Data: {} \t Label: {}'.format(text, results[idx]))
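# Example invocation (a sketch; assumes this script is saved as predict.py and that other flags keep their defaults):
# python predict.py --vocab_path ./word_dict.txt --network_name bilstm --params_path ./chekpoints/final.pdparams --use_gpu False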
......@@ -14,80 +14,27 @@
from functools import partial
import argparse
import os
import random
import time
import jieba
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.datasets import ChnSentiCorp
from config import RunConfig
from data import ChnSentiCorp, convert_tokens_to_ids, load_vocab
from model import BoWModel, LSTMModel, GRUModel, RNNModel, BiLSTMAttentionModel, TextCNNModel
from model import SelfAttention, SelfInteractiveAttention
from utils import load_vocab, generate_batch, convert_example
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--epochs", type=int, default=3, help="Number of epoches for training.")
parser.add_argument('--use_gpu', type=eval, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument('--use_gpu', type=eval, default=True, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate used to train.")
parser.add_argument("--save_dir", type=str, default='chekpoints/', help="Directory to save model checkpoint")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./word_dict.txt", help="The directory to dataset.")
parser.add_argument('--network_name', type=str, default="bilstm_attn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
parser.add_argument('--network_name', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
args = parser.parse_args()
# yapf: enable
def pad_texts_to_max_seq_len(texts, max_seq_len, pad_token_id=0):
"""
    Pads each text to the max sequence length if it is shorter than that length;
    otherwise, the text is truncated.
    Args:
        texts(obj:`list`): Texts, each of which contains a sequence of word ids.
        max_seq_len(obj:`int`): Max sequence length.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
"""
for index, text in enumerate(texts):
seq_len = len(text)
if seq_len < max_seq_len:
padded_tokens = [pad_token_id for _ in range(max_seq_len - seq_len)]
new_text = text + padded_tokens
texts[index] = new_text
elif seq_len > max_seq_len:
new_text = text[:max_seq_len]
texts[index] = new_text
def generate_batch(batch, pad_token_id=0, return_label=True):
"""
Generates a batch whose text will be padded to the max sequence length in the batch.
Args:
batch(obj:`List[Example]`) : One batch, which contains texts, labels and the true sequence lengths.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
batch(:obj:`Tuple[list]`): The batch data which contains texts, seq_lens and labels.
"""
seq_lens = [entry[1] for entry in batch]
batch_max_seq_len = max(seq_lens)
texts = [entry[0] for entry in batch]
pad_texts_to_max_seq_len(texts, batch_max_seq_len, pad_token_id)
if return_label:
labels = [[entry[-1]] for entry in batch]
return texts, seq_lens, labels
else:
return texts, seq_lens
def create_dataloader(dataset,
trans_fn=None,
mode='train',
......@@ -113,172 +60,38 @@ def create_dataloader(dataset,
if mode == 'train' and use_gpu:
sampler = paddle.io.DistributedBatchSampler(
dataset=dataset, batch_size=batch_size, shuffle=True)
dataloader = paddle.io.DataLoader(
dataset,
batch_sampler=sampler,
return_list=True,
collate_fn=lambda batch: generate_batch(batch,
pad_token_id=pad_token_id))
else:
shuffle = True if mode == 'train' else False
sampler = paddle.io.BatchSampler(
dataset=dataset, batch_size=batch_size, shuffle=shuffle)
dataloader = paddle.io.DataLoader(
dataset,
batch_sampler=sampler,
return_list=True,
collate_fn=lambda batch: generate_batch(batch,
pad_token_id=pad_token_id))
dataloader = paddle.io.DataLoader(
dataset,
batch_sampler=sampler,
return_list=True,
collate_fn=lambda batch: generate_batch(batch, pad_token_id=pad_token_id))
return dataloader
def create_model(vocab_size, num_labels, network_name='bilstm', pad_token_id=0):
"""
    Creates a model for text classification, such as BoW, LSTM/BiLSTM or GRU/BiGRU.
    Args:
        vocab_size(obj:`int`): The vocabulary size.
        num_labels(obj:`int`): The number of labels.
        network_name(obj:`str`, optional, defaults to `bilstm`): The network to use.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
model(obj:`paddle.nn.Layer`): A model.
"""
if network_name == 'bow':
network = BoWModel(
vocab_size, num_labels=num_labels, padding_idx=pad_token_id)
elif network_name == 'bilstm':
        # direction choice: forward, backward, bidirectional
network = LSTMModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'bigru':
        # direction choice: forward, backward, bidirectional
network = GRUModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'birnn':
        # direction choice: forward, backward, bidirectional
network = RNNModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'lstm':
        # direction choice: forward, backward, bidirectional
network = LSTMModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'gru':
        # direction choice: forward, backward, bidirectional
network = GRUModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'rnn':
        # direction choice: forward, backward, bidirectional
network = RNNModel(
vocab_size=vocab_size,
num_labels=num_labels,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'bilstm_attn':
lstm_hidden_size = 196
attention = SelfInteractiveAttention(hidden_size=2 * lstm_hidden_size)
network = BiLSTMAttentionModel(
attention_layer=attention,
vocab_size=vocab_size,
lstm_hidden_size=lstm_hidden_size,
num_labels=num_labels,
padding_idx=pad_token_id)
elif network_name == 'textcnn':
network = TextCNNModel(
vocab_size=vocab_size,
padding_idx=pad_token_id,
num_labels=num_labels)
else:
raise ValueError(
"Unknown network: %s, it must be one of bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn."
% network_name)
model = paddle.Model(network)
return model
def convert_example(example, vocab, unk_token_id=1, is_test=False):
"""
Builds model inputs from a sequence for sequence classification tasks.
    It uses `jieba.cut` to tokenize the text.
    Args:
        example(obj:`list[str]`): List of input data, containing the text and, if present, the label.
        vocab(obj:`dict`): The vocabulary.
        unk_token_id(obj:`int`, defaults to 1): The unknown token id.
        is_test(obj:`bool`, defaults to `False`): Whether the example contains a label or not.
    Returns:
        input_ids(obj:`list[int]`): The list of token ids.
valid_length(obj:`int`): The input sequence valid length.
label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
"""
# tokenize raw text and convert the token to ids
# tokens_raw = ' '.join(jieba.cut(example[0]).split(' ')
input_ids = []
for token in jieba.cut(example[0]):
token_id = vocab.get(token, unk_token_id)
input_ids.append(token_id)
valid_length = len(input_ids)
if not is_test:
label = np.array(example[-1], dtype="int64")
return input_ids, valid_length, label
else:
return input_ids, valid_length
if __name__ == "__main__":
# Running config setting.
config = RunConfig(
save_dir=args.save_dir,
use_gpu=args.use_gpu,
lr=args.lr,
batch_size=args.batch_size,
epochs=args.epochs)
paddle.set_device('gpu') if args.use_gpu else paddle.set_device('cpu')
# Loads vocab.
if not os.path.exists(args.vocab_path):
raise RuntimeError('The vocab_path can not be found in the path %s' %
args.vocab_path)
vocab = load_vocab(args.vocab_path)
if '[PAD]' not in vocab:
pad_token_id = len(vocab)
vocab['[PAD]'] = pad_token_id
else:
pad_token_id = vocab['[PAD]']
# Loads dataset.
train_dataset, dev_dataset, test_dataset = ppnlp.datasets.ChnSentiCorp.get_datasets(
train_ds, dev_ds, test_ds = ChnSentiCorp.get_datasets(
['train', 'dev', 'test'])
    # Constructs the network.
model = create_model(
len(vocab),
len(train_dataset.get_labels()),
network_name=args.network_name.lower(),
pad_token_id=pad_token_id)
label_list = train_ds.get_labels()
model = ppnlp.models.Senta(
network_name=args.network_name,
vocab_size=len(vocab),
num_classes=len(label_list))
model = paddle.Model(model)
# Reads data and generates mini-batches.
trans_fn = partial(
......@@ -287,30 +100,21 @@ if __name__ == "__main__":
unk_token_id=vocab['[UNK]'],
is_test=False)
train_loader = create_dataloader(
train_dataset,
trans_fn=trans_fn,
batch_size=config.batch_size,
mode='train',
pad_token_id=pad_token_id)
train_ds, trans_fn=trans_fn, batch_size=args.batch_size, mode='train')
dev_loader = create_dataloader(
dev_dataset,
dev_ds,
trans_fn=trans_fn,
batch_size=config.batch_size,
mode='validation',
pad_token_id=pad_token_id)
batch_size=args.batch_size,
mode='validation')
test_loader = create_dataloader(
test_dataset,
trans_fn=trans_fn,
batch_size=config.batch_size,
mode='test',
pad_token_id=pad_token_id)
test_ds, trans_fn=trans_fn, batch_size=args.batch_size, mode='test')
optimizer = paddle.optimizer.Adam(
parameters=model.parameters(), learning_rate=config.lr)
parameters=model.parameters(), learning_rate=args.lr)
# Defines loss and metric.
criterion = paddle.nn.CrossEntropyLoss()
metric = paddle.metric.Accuracy(name="acc_accumulation")
metric = paddle.metric.Accuracy()
model.prepare(optimizer, criterion, metric)
......@@ -322,12 +126,9 @@ if __name__ == "__main__":
# Starts training and evaluating.
model.fit(train_loader,
dev_loader,
epochs=config.epochs,
eval_freq=config.eval_freq,
log_freq=config.eval_freq,
save_dir=args.save_dir,
save_freq=config.save_freq)
epochs=args.epochs,
save_dir=args.save_dir)
# Finally tests model.
results = model.evaluate(test_loader)
print("Finally test acc: %.5f" % results['acc_accumulation'])
print("Finally test acc: %.5f" % results['acc'])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jieba
import numpy as np
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = {}
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n").split("\t")[0]
vocab[token] = index
return vocab
def convert_ids_to_tokens(wids, inversed_vocab):
""" Converts a token string (or a sequence of tokens) in a single integer id
(or a sequence of ids), using the vocabulary.
"""
tokens = []
for wid in wids:
wstr = inversed_vocab.get(wid, None)
if wstr:
tokens.append(wstr)
return tokens
def convert_tokens_to_ids(tokens, vocab):
""" Converts a token id (or a sequence of id) in a token string
(or a sequence of tokens), using the vocabulary.
"""
ids = []
unk_id = vocab.get('[UNK]', None)
for token in tokens:
wid = vocab.get(token, unk_id)
        if wid is not None:  # keep ids that are 0; skip only tokens with no id at all
ids.append(wid)
return ids
def convert_example(example, vocab, unk_token_id=1, is_test=False):
"""
Builds model inputs from a sequence for sequence classification tasks.
    It uses `jieba.cut` to tokenize the text.
    Args:
        example(obj:`list[str]`): List of input data, containing the text and, if present, the label.
        vocab(obj:`dict`): The vocabulary.
        unk_token_id(obj:`int`, defaults to 1): The unknown token id.
        is_test(obj:`bool`, defaults to `False`): Whether the example contains a label or not.
    Returns:
        input_ids(obj:`list[int]`): The list of token ids.
valid_length(obj:`int`): The input sequence valid length.
label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
"""
input_ids = []
for token in jieba.cut(example[0]):
token_id = vocab.get(token, unk_token_id)
input_ids.append(token_id)
valid_length = len(input_ids)
if not is_test:
label = np.array(example[-1], dtype="int64")
return input_ids, valid_length, label
else:
return input_ids, valid_length
def pad_texts_to_max_seq_len(texts, max_seq_len, pad_token_id=0):
"""
    Pads each text to the max sequence length if it is shorter than that length;
    otherwise, the text is truncated.
    Args:
        texts(obj:`list`): Texts, each of which contains a sequence of word ids.
        max_seq_len(obj:`int`): Max sequence length.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
"""
for index, text in enumerate(texts):
seq_len = len(text)
if seq_len < max_seq_len:
padded_tokens = [pad_token_id for _ in range(max_seq_len - seq_len)]
new_text = text + padded_tokens
texts[index] = new_text
elif seq_len > max_seq_len:
new_text = text[:max_seq_len]
texts[index] = new_text
def generate_batch(batch, pad_token_id=0, return_label=True):
"""
Generates a batch whose text will be padded to the max sequence length in the batch.
Args:
batch(obj:`List[Example]`) : One batch, which contains texts, labels and the true sequence lengths.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
batch(:obj:`Tuple[list]`): The batch data which contains texts, seq_lens and labels.
"""
seq_lens = [entry[1] for entry in batch]
batch_max_seq_len = max(seq_lens)
texts = [entry[0] for entry in batch]
pad_texts_to_max_seq_len(texts, batch_max_seq_len, pad_token_id)
if return_label:
labels = [[entry[-1]] for entry in batch]
return texts, seq_lens, labels
else:
return texts, seq_lens
def preprocess_prediction_data(data, vocab):
"""
    Processes the prediction data into the same format used for training.
    Args:
        data (obj:`List[str]`): The prediction data, each element of which is a raw text string.
        vocab (obj:`dict`): The vocabulary.
    Returns:
        examples (obj:`list`): The processed data, each element of which is a list containing the word ids and the sequence length of one text.
"""
examples = []
for text in data:
tokens = " ".join(jieba.cut(text)).split(' ')
ids = convert_tokens_to_ids(tokens, vocab)
examples.append([ids, len(ids)])
return examples
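# A minimal sketch of how the helpers above fit together (the vocab path is an assumed example):
# vocab = load_vocab('./word_dict.txt')
# examples = preprocess_prediction_data(['这家酒店环境不错', '味道很一般'], vocab)
# texts, seq_lens = generate_batch(examples, pad_token_id=0, return_label=False)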
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import paddle
import paddlenlp as ppnlp
from utils import load_vocab, generate_batch, preprocess_prediction_data
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=eval, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./data/term2id.dict", help="The path to vocabulary.")
parser.add_argument('--network_name', type=str, default="lstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
parser.add_argument("--params_path", type=str, default='./chekpoints/final.pdparams', help="The path of model parameter to be loaded.")
args = parser.parse_args()
# yapf: enable
def predict(model, data, label_map, collate_fn, batch_size=1, pad_token_id=0):
"""
Predicts the data labels.
Args:
model (obj:`paddle.nn.Layer`): A model to classify texts.
        data (obj:`list`): The processed data, each element of which is a list containing
            the query ids, the title ids and their sequence lengths.
label_map(obj:`dict`): The label id (key) to label str (value) map.
collate_fn(obj: `callable`): function to generate mini-batch data by merging
the sample list.
        batch_size(obj:`int`, defaults to 1): The size of a mini-batch.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
    Returns:
        results(obj:`list`): All the predicted labels.
"""
    # Separates the data into batches.
batches = []
one_batch = []
for example in data:
one_batch.append(example)
if len(one_batch) == batch_size:
batches.append(one_batch)
one_batch = []
if one_batch:
# The last batch whose size is less than the config batch_size setting.
batches.append(one_batch)
results = []
model.eval()
for batch in batches:
queries, titles, query_seq_lens, title_seq_lens = collate_fn(
batch, pad_token_id=pad_token_id, return_label=False)
queries = paddle.to_tensor(queries)
titles = paddle.to_tensor(titles)
query_seq_lens = paddle.to_tensor(query_seq_lens)
title_seq_lens = paddle.to_tensor(title_seq_lens)
probs = model(queries, titles, query_seq_lens, title_seq_lens)
idx = paddle.argmax(probs, axis=1).numpy()
idx = idx.tolist()
labels = [label_map[i] for i in idx]
results.extend(labels)
return results
if __name__ == "__main__":
paddle.set_device("gpu") if args.use_gpu else paddle.set_device("cpu")
# Loads vocab.
vocab = load_vocab(args.vocab_path)
label_map = {0: 'dissimilar', 1: 'similar'}
    # Constructs the network.
model = ppnlp.models.SimNet(
network_name=args.network_name,
vocab_size=len(vocab),
num_classes=len(label_map))
# Loads model parameters.
state_dict = paddle.load(args.params_path)
model.set_dict(state_dict)
print("Loaded parameters from %s" % args.params_path)
    # First pre-process the prediction data, then run prediction.
data = [
['世界上什么东西最小', '世界上什么东西最小?'],
['光眼睛大就好看吗', '眼睛好看吗?'],
['小蝌蚪找妈妈怎么样', '小蝌蚪找妈妈是谁画的'],
]
examples = preprocess_prediction_data(data, vocab)
results = predict(
model,
examples,
label_map=label_map,
batch_size=args.batch_size,
collate_fn=generate_batch)
for idx, text in enumerate(data):
print('Data: {} \t Label: {}'.format(text, results[idx]))
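# Example invocation (a sketch; assumes this script is saved as predict.py):
# python predict.py --vocab_path ./data/term2id.dict --network_name lstm --params_path ./chekpoints/final.pdparams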
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
import random
import time
import paddle
import paddlenlp as ppnlp
from paddlenlp.datasets import LCQMC
from utils import load_vocab, generate_batch, convert_example
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--epochs", type=int, default=3, help="Number of epoches for training.")
parser.add_argument('--use_gpu', type=eval, default=True, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate used to train.")
parser.add_argument("--save_dir", type=str, default='chekpoints/', help="Directory to save model checkpoint")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./data/term2id.dict", help="The directory to dataset.")
parser.add_argument('--network', type=str, default="cnn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
args = parser.parse_args()
# yapf: enable
def create_dataloader(dataset,
trans_fn=None,
mode='train',
batch_size=1,
use_gpu=False,
pad_token_id=0):
"""
    Creates the dataloader.
Args:
dataset(obj:`paddle.io.Dataset`): Dataset instance.
mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
use_gpu(obj:`bool`, optional, defaults to obj:`False`): Whether to use gpu to run.
pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
Returns:
dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
"""
if trans_fn:
dataset = dataset.apply(trans_fn, lazy=True)
if mode == 'train' and use_gpu:
sampler = paddle.io.DistributedBatchSampler(
dataset=dataset, batch_size=batch_size, shuffle=True)
else:
shuffle = True if mode == 'train' else False
sampler = paddle.io.BatchSampler(
dataset=dataset, batch_size=batch_size, shuffle=shuffle)
dataloader = paddle.io.DataLoader(
dataset,
batch_sampler=sampler,
return_list=True,
collate_fn=lambda batch: generate_batch(batch, pad_token_id=pad_token_id))
return dataloader
if __name__ == "__main__":
paddle.set_device('gpu') if args.use_gpu else paddle.set_device('cpu')
# Loads vocab.
if not os.path.exists(args.vocab_path):
raise RuntimeError('The vocab_path can not be found in the path %s' %
args.vocab_path)
vocab = load_vocab(args.vocab_path)
# Loads dataset.
train_ds, dev_dataset, test_ds = LCQMC.get_datasets(
['train', 'dev', 'test'])
    # Constructs the network.
label_list = train_ds.get_labels()
model = ppnlp.models.SimNet(
network=args.network,
vocab_size=len(vocab),
num_classes=len(label_list))
model = paddle.Model(model)
# Reads data and generates mini-batches.
trans_fn = partial(convert_example, vocab=vocab, is_test=False)
train_loader = create_dataloader(
train_ds, trans_fn=trans_fn, batch_size=args.batch_size, mode='train')
dev_loader = create_dataloader(
dev_dataset,
trans_fn=trans_fn,
batch_size=args.batch_size,
mode='validation')
test_loader = create_dataloader(
test_ds, trans_fn=trans_fn, batch_size=args.batch_size, mode='test')
optimizer = paddle.optimizer.Adam(
parameters=model.parameters(), learning_rate=args.lr)
# Defines loss and metric.
criterion = paddle.nn.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
model.prepare(optimizer, criterion, metric)
# Loads pre-trained parameters.
if args.init_from_ckpt:
model.load(args.init_from_ckpt)
print("Loaded checkpoint from %s" % args.init_from_ckpt)
# Starts training and evaluating.
model.fit(
train_loader,
dev_loader,
epochs=args.epochs,
save_dir=args.save_dir, )
# Finally tests model.
results = model.evaluate(test_loader)
print("Finally test acc: %.5f" % results['acc'])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jieba
import numpy as np
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = {}
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n").split("\t")[0]
vocab[token] = index
return vocab
def convert_ids_to_tokens(wids, inversed_vocab):
""" Converts a token string (or a sequence of tokens) in a single integer id
(or a sequence of ids), using the vocabulary.
"""
tokens = []
for wid in wids:
wstr = inversed_vocab.get(wid, None)
if wstr:
tokens.append(wstr)
return tokens
def convert_tokens_to_ids(tokens, vocab):
""" Converts a token id (or a sequence of id) in a token string
(or a sequence of tokens), using the vocabulary.
"""
ids = []
unk_id = vocab.get('[UNK]', None)
for token in tokens:
wid = vocab.get(token, unk_id)
        if wid is not None:  # keep ids that are 0; skip only tokens with no id at all
ids.append(wid)
return ids
def pad_texts_to_max_seq_len(texts, max_seq_len, pad_token_id=0):
"""
    Pads each text to the max sequence length if it is shorter than that length;
    otherwise, the text is truncated.
    Args:
        texts(obj:`list`): Texts, each of which contains a sequence of word ids.
        max_seq_len(obj:`int`): Max sequence length.
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
"""
for index, text in enumerate(texts):
seq_len = len(text)
if seq_len < max_seq_len:
padded_tokens = [pad_token_id for _ in range(max_seq_len - seq_len)]
new_text = text + padded_tokens
texts[index] = new_text
elif seq_len > max_seq_len:
new_text = text[:max_seq_len]
texts[index] = new_text
def generate_batch(batch, pad_token_id=0, return_label=True):
"""
    Generates a batch whose texts will be padded to the max sequence length in the batch.
    Args:
        batch(obj:`list`): One batch, each element of which contains query ids, title ids and their true sequence lengths (and, optionally, a label).
        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
    Returns:
        batch(:obj:`Tuple[list]`): The batch data, which contains queries, titles, query_seq_lens, title_seq_lens and, if return_label is True, labels.
"""
queries = [entry[0] for entry in batch]
titles = [entry[1] for entry in batch]
query_seq_lens = [entry[2] for entry in batch]
title_seq_lens = [entry[3] for entry in batch]
query_batch_max_seq_len = max(query_seq_lens)
pad_texts_to_max_seq_len(queries, query_batch_max_seq_len, pad_token_id)
title_batch_max_seq_len = max(title_seq_lens)
pad_texts_to_max_seq_len(titles, title_batch_max_seq_len, pad_token_id)
if return_label:
labels = [entry[-1] for entry in batch]
return queries, titles, query_seq_lens, title_seq_lens, labels
else:
return queries, titles, query_seq_lens, title_seq_lens
def convert_example(example, vocab, unk_token_id=1, is_test=False):
"""
Builds model inputs from a sequence for sequence classification tasks.
    It uses `jieba.cut` to tokenize the text.
    Args:
        example(obj:`list[str]`): List of input data, containing the query, the title and, if present, the label.
        vocab(obj:`dict`): The vocabulary.
        unk_token_id(obj:`int`, defaults to 1): The unknown token id.
        is_test(obj:`bool`, defaults to `False`): Whether the example contains a label or not.
    Returns:
        query_ids(obj:`list[int]`): The list of query token ids.
        title_ids(obj:`list[int]`): The list of title token ids.
        query_seq_len(obj:`int`): The length of the query sequence.
        title_seq_len(obj:`int`): The length of the title sequence.
label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.
"""
query, title = example[0], example[1]
query_tokens = jieba.lcut(query)
title_tokens = jieba.lcut(title)
query_ids = convert_tokens_to_ids(query_tokens, vocab)
query_seq_len = len(query_ids)
title_ids = convert_tokens_to_ids(title_tokens, vocab)
title_seq_len = len(title_ids)
if not is_test:
label = np.array(example[-1], dtype="int64")
return query_ids, title_ids, query_seq_len, title_seq_len, label
else:
return query_ids, title_ids, query_seq_len, title_seq_len
def preprocess_prediction_data(data, vocab):
"""
    Processes the prediction data into the same format used for training.
Args:
data (obj:`List[List[str, str]]`):
The prediction data whose each element is a text pair.
Each text will be tokenized by jieba.lcut() function.
Returns:
examples (obj:`list`): The processed data whose each element
is a `list` object, which contains
- query_ids(obj:`list[int]`): The list of query ids.
- title_ids(obj:`list[int]`): The list of title ids.
- query_seq_len(obj:`int`): The input sequence query length.
- title_seq_len(obj:`int`): The input sequence title length.
"""
examples = []
for query, title in data:
query_tokens = jieba.lcut(query)
title_tokens = jieba.lcut(title)
query_ids = convert_tokens_to_ids(query_tokens, vocab)
title_ids = convert_tokens_to_ids(title_tokens, vocab)
examples.append([query_ids, title_ids, len(query_ids), len(title_ids)])
return examples
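# A minimal sketch of the pairwise helpers above (the vocab path is an assumed example):
# vocab = load_vocab('./data/term2id.dict')
# examples = preprocess_prediction_data([['世界上什么东西最小', '世界上什么东西最小?']], vocab)
# queries, titles, query_seq_lens, title_seq_lens = generate_batch(examples, pad_token_id=0, return_label=False)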
......@@ -15,6 +15,7 @@
from .chnsenticorp import *
from .dataset import *
from .glue import *
from .lcqmc import *
from .msra_ner import *
from .ptb import *
from .squad import *
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import collections
import io
import os
import warnings
from paddle.io import Dataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from .dataset import TSVDataset
__all__ = ['LCQMC']
class LCQMC(TSVDataset):
"""
    LCQMC: A Large-scale Chinese Question Matching Corpus.
    For more information, please refer to `https://www.aclweb.org/anthology/C18-1166/`.
"""
URL = "https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz"
MD5 = "62a7ba36f786a82ae59bbde0b0a9af0c"
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {
'train': SEGMENT_INFO(
os.path.join('lcqmc', 'train.tsv'),
'2193c022439b038ac12c0ae918b211a1', (0, 1, 2), 1),
'dev': SEGMENT_INFO(
os.path.join('lcqmc', 'dev.tsv'),
'c5dcba253cb4105d914964fd8b3c0e94', (0, 1, 2), 1),
'test': SEGMENT_INFO(
os.path.join('lcqmc', 'test.tsv'),
'8f4b71e15e67696cc9e112a459ec42bd', (0, 1, 2), 1)
}
def __init__(self,
segment='train',
root=None,
return_all_fields=False,
**kwargs):
if return_all_fields:
segments = copy.deepcopy(self.__class__.SEGMENTS)
segment_info = list(segments[segment])
segment_info[2] = None
segments[segment] = self.SEGMENT_INFO(*segment_info)
self.SEGMENTS = segments
self._get_data(root, segment, **kwargs)
def _get_data(self, root, segment, **kwargs):
default_root = DATA_HOME
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
if not os.path.exists(fullname) or (data_hash and
not md5file(fullname) == data_hash):
            if root is not None:  # A root was specified, so warn that the md5 check failed before re-downloading.
warnings.warn(
'md5 check failed for {}, download {} data to {}'.format(
filename, self.__class__.__name__, default_root))
path = get_path_from_url(self.URL, default_root, self.MD5)
fullname = os.path.join(default_root, filename)
super(LCQMC, self).__init__(
fullname,
field_indices=field_indices,
num_discard_samples=num_discard_samples,
**kwargs)
def get_labels(self):
"""
        Returns the labels of the LCQMC dataset.
"""
return ["0", "1"]
if __name__ == "__main__":
ds = LCQMC('train', return_all_fields=True)
for idx, data in enumerate(ds):
if idx >= 3:
break
print(data)
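    # A minimal sketch of the default three-field view, where each sample is (query, title, label):
    # ds_default = LCQMC('train')
    # print(ds_default.get_labels())  # ['0', '1']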
from .senta import Senta
from .ernie import Ernie
from .simnet import SimNet
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.transformers import *
import paddlenlp as nlp
class Ernie(nn.Layer):
def __init__(self, model_name, num_classes, task=None):
super().__init__()
model_name = model_name.lower()
        self.task = task.lower() if task is not None else None
if self.task == 'seq-cls':
required_names = list(ErnieForSequenceClassification.
pretrained_init_configuration.keys())
            assert model_name in required_names, "model_name must be in %s, unknown %s." % (
                required_names, model_name)
self.model = ErnieForSequenceClassification.from_pretrained(
model_name, num_classes=num_classes)
elif self.task == 'token-cls':
required_names = list(ErnieForTokenClassification.
pretrained_init_configuration.keys())
            assert model_name in required_names, "model_name must be in %s, unknown %s." % (
                required_names, model_name)
self.model = ErnieForTokenClassification.from_pretrained(
model_name, num_classes=num_classes)
elif self.task == 'qa':
required_names = list(
ErnieForQuestionAnswering.pretrained_init_configuration.keys())
            assert model_name in required_names, "model_name must be in %s, unknown %s." % (
                required_names, model_name)
self.model = ErnieForQuestionAnswering.from_pretrained(model_name)
elif self.task is None:
required_names = list(ErnieModel.pretrained_init_configuration.keys(
))
            assert model_name in required_names, "model_name must be in %s, unknown %s." % (
                required_names, model_name)
self.model = ErnieModel.from_pretrained(model_name)
else:
raise RuntimeError(
"Unknown task %s. Please make sure it to be one of seq-cls (it means sequence classifaction), "
"token-cls (it means token classifaction), qa (it means question answering) "
"or set it as None object." % task)
def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
if self.task in ['seq-cls', 'token-cls']:
logits = self.model(input_ids, token_type_ids, position_ids,
attention_mask)
probs = F.softmax(logits, axis=-1)
return probs
elif self.task == 'qa':
start_logits, end_logits = self.model(input_ids, token_type_ids,
position_ids, attention_mask)
            start_probs = F.softmax(start_logits, axis=-1)
            end_probs = F.softmax(end_logits, axis=-1)
return start_probs, end_probs
elif self.task is None:
sequence_output, pooled_output = self.model(
input_ids, token_type_ids, position_ids, attention_mask)
return sequence_output, pooled_output
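# A minimal usage sketch (mirroring the documentation example earlier in this commit; the ids below
# are hypothetical placeholders rather than real tokenizer output):
# model = Ernie('ernie-1.0', task='seq-cls', num_classes=2)
# input_ids = paddle.to_tensor([[1, 647, 533, 2]])
# token_type_ids = paddle.to_tensor([[0, 0, 0, 0]])
# probs = model(input_ids, token_type_ids)  # shape [1, 2], softmax over the two classes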
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -20,28 +21,110 @@ import paddlenlp as nlp
INF = 1. * 1e12
class Senta(nn.Layer):
def __init__(self,
network_name,
vocab_size,
num_classes,
emb_dim=128,
pad_token_id=0):
super().__init__()
network_name = network_name.lower()
if network_name == 'bow':
self.model = BoWModel(
vocab_size, num_classes, emb_dim, padding_idx=pad_token_id)
elif network_name == 'bigru':
self.model = GRUModel(
vocab_size,
num_classes,
emb_dim,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'bilstm':
self.model = LSTMModel(
vocab_size,
num_classes,
emb_dim,
direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'bilstm_attn':
lstm_hidden_size = 196
attention = SelfInteractiveAttention(hidden_size=2 *
lstm_hidden_size)
self.model = BiLSTMAttentionModel(
attention_layer=attention,
vocab_size=vocab_size,
lstm_hidden_size=lstm_hidden_size,
num_classes=num_classes,
padding_idx=pad_token_id)
elif network_name == 'birnn':
self.model = RNNModel(
vocab_size,
num_classes,
emb_dim,
                direction='bidirectional',
padding_idx=pad_token_id)
elif network_name == 'cnn':
self.model = CNNModel(
vocab_size, num_classes, emb_dim, padding_idx=pad_token_id)
elif network_name == 'gru':
self.model = GRUModel(
vocab_size,
num_classes,
emb_dim,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'lstm':
self.model = LSTMModel(
vocab_size,
num_classes,
emb_dim,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'rnn':
self.model = RNNModel(
vocab_size,
num_classes,
emb_dim,
direction='forward',
padding_idx=pad_token_id,
pooling_type='max')
elif network_name == 'textcnn':
self.model = TextCNNModel(
vocab_size, num_classes, emb_dim, padding_idx=pad_token_id)
else:
raise ValueError(
"Unknown network: %s, it must be one of bow, lstm, bilstm, cnn, gru, bigru, rnn, birnn, bilstm_attn and textcnn."
% network_name)
def forward(self, text, seq_len):
logits = self.model(text, seq_len)
probs = F.softmax(logits, axis=-1)
return probs
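# A minimal usage sketch of the Senta wrapper (the vocab_size and token ids are hypothetical):
# model = Senta(network_name='bilstm', vocab_size=10000, num_classes=2)
# text = paddle.to_tensor([[8, 25, 3, 0, 0]])  # one padded sample
# seq_len = paddle.to_tensor([3])              # its true length
# probs = model(text, seq_len)                 # shape [1, 2], softmax over the two classes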
class BoWModel(nn.Layer):
"""
This class implements the Bag of Words Classification Network model to classify texts.
    At a high level, the model starts by embedding the tokens and running them through
    a word embedding layer. Then, we encode these representations with a `BoWEncoder`.
    Lastly, we take the output of the encoder to create a final representation,
    which is passed through some feed-forward layers to output the logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 128): The embedding dimension.
        padding_idx (obj:`int`, optional, defaults to 0): The pad token index.
        hidden_size (obj:`int`, optional, defaults to 128): The hidden size of the first fully-connected layer.
        fc_hidden_size (obj:`int`, optional, defaults to 96): The hidden size of the second fully-connected layer.
        num_labels (obj:`int`): All the labels that the data has.
        num_classes (obj:`int`): The number of classes.
"""
def __init__(self,
vocab_size,
num_labels,
num_classes,
emb_dim=128,
padding_idx=0,
hidden_size=128,
......@@ -52,7 +135,7 @@ class BoWModel(nn.Layer):
self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size)
self.fc2 = nn.Linear(hidden_size, fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_labels)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
......@@ -66,7 +149,7 @@ class BoWModel(nn.Layer):
fc1_out = paddle.tanh(self.fc1(encoded_text))
# Shape: (batch_size, fc_hidden_size)
fc2_out = paddle.tanh(self.fc2(fc1_out))
# Shape: (batch_size, num_labels)
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc2_out)
return logits
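# Illustrative sketch (hypothetical sizes, not part of this diff):
#
#     bow = BoWModel(vocab_size=10000, num_classes=2)
#     logits = bow(token_ids, seq_len)  # shape: (batch_size, 2)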
......@@ -74,7 +157,7 @@ class BoWModel(nn.Layer):
class LSTMModel(nn.Layer):
def __init__(self,
vocab_size,
num_labels,
num_classes,
emb_dim=128,
padding_idx=0,
lstm_hidden_size=198,
......@@ -96,7 +179,7 @@ class LSTMModel(nn.Layer):
dropout=dropout_rate,
pooling_type=pooling_type)
self.fc = nn.Linear(self.lstm_encoder.get_output_dim(), fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_labels)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
......@@ -107,7 +190,7 @@ class LSTMModel(nn.Layer):
text_repr = self.lstm_encoder(embedded_text, sequence_length=seq_len)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(text_repr))
# Shape: (batch_size, num_labels)
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
return logits
......@@ -115,7 +198,7 @@ class LSTMModel(nn.Layer):
class GRUModel(nn.Layer):
def __init__(self,
vocab_size,
num_labels,
num_classes,
emb_dim=128,
padding_idx=0,
gru_hidden_size=198,
......@@ -137,7 +220,7 @@ class GRUModel(nn.Layer):
dropout=dropout_rate,
pooling_type=pooling_type)
self.fc = nn.Linear(self.gru_encoder.get_output_dim(), fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_labels)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
......@@ -148,7 +231,7 @@ class GRUModel(nn.Layer):
text_repr = self.gru_encoder(embedded_text, sequence_length=seq_len)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(text_repr))
# Shape: (batch_size, num_labels)
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
return logits
......@@ -156,7 +239,7 @@ class GRUModel(nn.Layer):
class RNNModel(nn.Layer):
def __init__(self,
vocab_size,
num_labels,
num_classes,
emb_dim=128,
padding_idx=0,
rnn_hidden_size=198,
......@@ -178,7 +261,7 @@ class RNNModel(nn.Layer):
dropout=dropout_rate,
pooling_type=pooling_type)
self.fc = nn.Linear(self.rnn_encoder.get_output_dim(), fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_labels)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
......@@ -189,7 +272,7 @@ class RNNModel(nn.Layer):
text_repr = self.rnn_encoder(embedded_text, sequence_length=seq_len)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(text_repr))
# Shape: (batch_size, num_labels)
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
return logits
......@@ -198,7 +281,7 @@ class BiLSTMAttentionModel(nn.Layer):
def __init__(self,
attention_layer,
vocab_size,
num_labels,
num_classes,
emb_dim=128,
lstm_hidden_size=196,
fc_hidden_size=96,
......@@ -226,7 +309,7 @@ class BiLSTMAttentionModel(nn.Layer):
else:
raise RuntimeError("Unknown attention type %s." %
attention_layer.__class__.__name__)
self.output_layer = nn.Linear(fc_hidden_size, num_labels)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
mask = text != self.padding_idx
......@@ -238,7 +321,7 @@ class BiLSTMAttentionModel(nn.Layer):
hidden, att_weights = self.attention(encoded_text, mask)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(hidden))
# Shape: (batch_size, num_labels)
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
return logits
......@@ -248,7 +331,6 @@ class SelfAttention(nn.Layer):
A close implementation of the attention network from the ACL 2016 paper,
Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification (Zhou et al., 2016).
ref: https://www.aclweb.org/anthology/P16-2034/
Args:
hidden_size (obj:`int`): The number of expected features in the input x.
"""
......@@ -296,7 +378,6 @@ class SelfInteractiveAttention(nn.Layer):
"""
A close implementation of the attention network from the NAACL 2016 paper, Hierarchical Attention Networks for Document Classification (Yang et al., 2016).
ref: https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf
Args:
hidden_size (obj:`int`): The number of expected features in the input x.
"""
......@@ -345,10 +426,58 @@ class SelfInteractiveAttention(nn.Layer):
return reps, att_weight
class CNNModel(nn.Layer):
"""
This class implements the Convolutional Neural Network model.
At a high level, the model embeds the tokens with a word embedding, then encodes
these representations with a `CNNEncoder`.
The CNN has one convolution layer for each ngram filter size. Each convolution operation gives
out a vector of size num_filter. The number of times a convolution layer will be used
is `num_tokens - ngram_size + 1`. The corresponding maxpooling layer aggregates all these
outputs from the convolution layer and outputs the max.
Lastly, we take the output of the encoder to create a final representation,
which is passed through some feed-forward layers to output logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 128): The embedding dimension.
padding_idx (obj:`int`, optional, defaults to 0): The pad token index.
num_classes (obj:`int`): The number of classes.
"""
def __init__(self,
vocab_size,
num_classes,
emb_dim=128,
padding_idx=0,
num_filter=128,
ngram_filter_sizes=(3, ),
fc_hidden_size=96):
super().__init__()
self.embedder = nn.Embedding(
vocab_size, emb_dim, padding_idx=padding_idx)
self.encoder = nlp.seq2vec.CNNEncoder(
emb_dim=emb_dim,
num_filter=num_filter,
ngram_filter_sizes=ngram_filter_sizes)
self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(text)
# Shape: (batch_size, len(ngram_filter_sizes)*num_filter)
encoder_out = self.encoder(embedded_text)
encoder_out = paddle.tanh(encoder_out)
# Shape: (batch_size, fc_hidden_size)
fc_out = self.fc(encoder_out)
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
return logits
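# Illustrative sketch (hypothetical sizes, not part of this diff): with the
# defaults above the encoder output has len(ngram_filter_sizes) * num_filter
# = 1 * 128 features per example before the fully connected layers.
#
#     cnn = CNNModel(vocab_size=10000, num_classes=2)
#     logits = cnn(token_ids, seq_len)  # shape: (batch_size, 2)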
class TextCNNModel(nn.Layer):
"""
This class implements the Text Convolutional Neural Network model.
At a high level, the model embeds the tokens with a word embedding, then encodes
these representations with a `CNNEncoder`.
The CNN has one convolution layer for each ngram filter size. Each convolution operation gives
......@@ -357,21 +486,20 @@ class TextCNNModel(nn.Layer):
outputs from the convolution layer and outputs the max.
Lastly, we take the output of the encoder to create a final representation,
which is passed through some feed-forward layers to output logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 128): The embedding dimension.
padding_idx (obj:`int`, optional, defaults to 0): The pad token index.
num_labels (obj:`int`): The number of classes.
num_classes (obj:`int`): The number of classes.
"""
def __init__(self,
vocab_size,
num_labels,
num_classes,
emb_dim=128,
padding_idx=0,
num_filter=128,
ngram_filter_sizes=[1, 2, 3],
ngram_filter_sizes=(1, 2, 3),
fc_hidden_size=96):
super().__init__()
self.embedder = nn.Embedding(
......@@ -379,17 +507,18 @@ class TextCNNModel(nn.Layer):
self.encoder = nlp.seq2vec.CNNEncoder(
emb_dim=emb_dim,
num_filter=num_filter,
ngram_filter_sizes=(1, 2, 3),
output_dim=fc_hidden_size)
self.output_layer = nn.Linear(self.encoder.get_output_dim(), num_labels)
ngram_filter_sizes=ngram_filter_sizes)
self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, text, seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(text)
# Shape: (batch_size, fc_hidden_size)
# Shape: (batch_size, len(ngram_filter_sizes)*num_filter)
encoder_out = self.encoder(embedded_text)
encoder_out = paddle.tanh(encoder_out)
# Shape: (batch_size, num_labels)
logits = self.output_layer(encoder_out)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(encoder_out))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
return logits
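# Illustrative sketch (hypothetical sizes, not part of this diff): TextCNNModel
# differs from CNNModel mainly in its default ngram_filter_sizes=(1, 2, 3), so
# the encoder output has 3 * num_filter features before the fully connected layers.
#
#     textcnn = TextCNNModel(vocab_size=10000, num_classes=2)
#     logits = textcnn(token_ids, seq_len)  # shape: (batch_size, 2)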
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlenlp as nlp
class SimNet(nn.Layer):
def __init__(self,
network,
vocab_size,
num_classes,
emb_dim=128,
pad_token_id=0):
super().__init__()
network = network.lower()
if network == 'bow':
self.model = BoWModel(
vocab_size, num_classes, emb_dim, padding_idx=pad_token_id)
elif network == 'cnn':
self.model = CNNModel(
vocab_size, num_classes, emb_dim, padding_idx=pad_token_id)
elif network == 'gru':
self.model = GRUModel(
vocab_size,
num_classes,
emb_dim,
direction='forward',
padding_idx=pad_token_id)
elif network == 'lstm':
self.model = LSTMModel(
vocab_size,
num_classes,
emb_dim,
direction='forward',
padding_idx=pad_token_id)
else:
raise ValueError(
"Unknown network: %s, it must be one of bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn."
% network)
def forward(self, query, title, query_seq_len=None, title_seq_len=None):
logits = self.model(query, title, query_seq_len, title_seq_len)
probs = F.softmax(logits, axis=-1)
return probs
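# Usage sketch (illustrative only, not part of this diff): SimNet scores a
# query/title pair; the 'lstm' network additionally asserts that both sequence
# lengths are provided (see LSTMModel.forward below).
#
#     model = SimNet('cnn', vocab_size=vocab_size, num_classes=2)
#     probs = model(query_ids, title_ids)  # shape: (batch_size, 2)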
class BoWModel(nn.Layer):
"""
This class implements the Bag of Words Classification Network model to classify texts.
At a high level, the model embeds the tokens with a word embedding, then encodes
these representations with a `BoWEncoder`.
Lastly, we take the output of the encoder to create a final representation,
which is passed through some feed-forward layers to output logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 128): The embedding dimension.
padding_idx (obj:`int`, optional, defaults to 0): The pad token index.
fc_hidden_size (obj:`int`, optional, defaults to 128): The fully connected layer hidden size.
num_classes (obj:`int`): The number of classes.
"""
def __init__(self,
vocab_size,
num_classes,
emb_dim=128,
padding_idx=0,
fc_hidden_size=128):
super().__init__()
self.embedder = nn.Embedding(
vocab_size, emb_dim, padding_idx=padding_idx)
self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
self.fc = nn.Linear(self.bow_encoder.get_output_dim() * 2,
fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, query, title, query_seq_len=None, title_seq_len=None):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_query = self.embedder(query)
embedded_title = self.embedder(title)
# Shape: (batch_size, embedding_dim)
summed_query = self.bow_encoder(embedded_query)
summed_title = self.bow_encoder(embedded_title)
encoded_query = paddle.tanh(summed_query)
encoded_title = paddle.tanh(summed_title)
# Shape: (batch_size, embedding_dim*2)
concatenated = paddle.concat([encoded_query, encoded_title], axis=-1)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(concatenated))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
# probs = F.softmax(logits, axis=-1)
return logits
class LSTMModel(nn.Layer):
def __init__(self,
vocab_size,
num_classes,
emb_dim=128,
padding_idx=0,
lstm_hidden_size=128,
direction='forward',
lstm_layers=1,
dropout_rate=0.0,
pooling_type=None,
fc_hidden_size=128):
super().__init__()
self.embedder = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=emb_dim,
padding_idx=padding_idx)
self.lstm_encoder = nlp.seq2vec.LSTMEncoder(
emb_dim,
lstm_hidden_size,
num_layers=lstm_layers,
direction=direction,
dropout=dropout_rate)
self.fc = nn.Linear(self.lstm_encoder.get_output_dim() * 2,
fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, query, title, query_seq_len, title_seq_len):
assert query_seq_len is not None and title_seq_len is not None
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_query = self.embedder(query)
embedded_title = self.embedder(title)
# Shape: (batch_size, lstm_hidden_size)
query_repr = self.lstm_encoder(
embedded_query, sequence_length=query_seq_len)
title_repr = self.lstm_encoder(
embedded_title, sequence_length=title_seq_len)
# Shape: (batch_size, 2*lstm_hidden_size)
concatenated = paddle.concat([query_repr, title_repr], axis=-1)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(concatenated))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
# probs = F.softmax(logits, axis=-1)
return logits
class GRUModel(nn.Layer):
def __init__(self,
vocab_size,
num_classes,
emb_dim=128,
padding_idx=0,
gru_hidden_size=128,
direction='forward',
gru_layers=1,
dropout_rate=0.0,
pooling_type=None,
fc_hidden_size=96):
super().__init__()
self.embedder = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=emb_dim,
padding_idx=padding_idx)
self.gru_encoder = nlp.seq2vec.GRUEncoder(
emb_dim,
gru_hidden_size,
num_layers=gru_layers,
direction=direction,
dropout=dropout_rate)
self.fc = nn.Linear(self.gru_encoder.get_output_dim() * 2,
fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, query, title, query_seq_len, title_seq_len):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_query = self.embedder(query)
embedded_title = self.embedder(title)
# Shape: (batch_size, gru_hidden_size)
query_repr = self.gru_encoder(
embedded_query, sequence_length=query_seq_len)
title_repr = self.gru_encoder(
embedded_title, sequence_length=title_seq_len)
# Shape: (batch_size, 2*gru_hidden_size)
concatenated = paddle.concat([query_repr, title_repr], axis=-1)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(concatenated))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
# probs = F.softmax(logits, axis=-1)
return logits
class CNNModel(nn.Layer):
"""
This class implements the Convolutional Neural Network model.
At a high level, the model embeds the tokens with a word embedding, then encodes
these representations with a `CNNEncoder`.
The CNN has one convolution layer for each ngram filter size. Each convolution operation gives
out a vector of size num_filter. The number of times a convolution layer will be used
is `num_tokens - ngram_size + 1`. The corresponding maxpooling layer aggregates all these
outputs from the convolution layer and outputs the max.
Lastly, we take the output of the encoder to create a final representation,
which is passed through some feed-forward layers to output logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 128): The embedding dimension.
padding_idx (obj:`int`, optional, defaults to 0): The pad token index.
num_classes (obj:`int`): The number of classes.
"""
def __init__(self,
vocab_size,
num_classes,
emb_dim=128,
padding_idx=0,
num_filter=256,
ngram_filter_sizes=(3, ),
fc_hidden_size=128):
super().__init__()
self.padding_idx = padding_idx
self.embedder = nn.Embedding(
vocab_size, emb_dim, padding_idx=padding_idx)
self.encoder = nlp.seq2vec.CNNEncoder(
emb_dim=emb_dim,
num_filter=num_filter,
ngram_filter_sizes=ngram_filter_sizes)
self.fc = nn.Linear(self.encoder.get_output_dim() * 2, fc_hidden_size)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
def forward(self, query, title, query_seq_len=None, title_seq_len=None):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_query = self.embedder(query)
embedded_title = self.embedder(title)
# Shape: (batch_size, num_filter)
query_repr = self.encoder(embedded_query)
title_repr = self.encoder(embedded_title)
# Shape: (batch_size, 2*num_filter)
concatenated = paddle.concat([query_repr, title_repr], axis=-1)
# Shape: (batch_size, fc_hidden_size)
fc_out = paddle.tanh(self.fc(concatenated))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc_out)
# probs = F.softmax(logits, axis=-1)
return logits
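# Illustrative note (hypothetical sizes, not part of this diff): each SimNet
# sub-model shares one embedder and one encoder for both inputs and concatenates
# the two text representations, hence the `get_output_dim() * 2` input size of
# the fully connected layer above.
#
#     lstm = LSTMModel(vocab_size=10000, num_classes=2)
#     logits = lstm(query_ids, title_ids, query_seq_len, title_seq_len)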
......@@ -93,7 +93,7 @@ class ErniePretrainedModel(PretrainedModel):
model_config_file = "model_config.json"
pretrained_init_configuration = {
"ernie": {
"ernie-1.0": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "relu",
"hidden_dropout_prob": 0.1,
......@@ -120,7 +120,7 @@ class ErniePretrainedModel(PretrainedModel):
"vocab_size": 50006,
"pad_token_id": 0,
},
"ernie_v2_eng_base": {
"ernie-2.0-en": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
......@@ -133,9 +133,9 @@ class ErniePretrainedModel(PretrainedModel):
"vocab_size": 30522,
"pad_token_id": 0,
},
"ernie_v2_eng_large": {
"ernie-2.0-large-en": {
"attention_probs_dropout_prob": 0.1,
"intermediate_size": 4096, # special for ernie_v2_eng_large
"intermediate_size": 4096, # special for ernie-2.0-large-en
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
......@@ -151,14 +151,14 @@ class ErniePretrainedModel(PretrainedModel):
resource_files_names = {"model_state": "model_state.pdparams"}
pretrained_resource_files_map = {
"model_state": {
"ernie":
"ernie-1.0":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams",
"ernie_tiny":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_tiny/ernie_tiny.pdparams",
"ernie_v2_eng_base":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_base/ernie_v2_eng_base.pdparams",
"ernie_v2_eng_large":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_large/ernie_v2_eng_large.pdparams",
"ernie-2.0-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_base/ernie-2.0-en.pdparams",
"ernie-2.0-large-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_large/ernie-2.0-large-en.pdparams",
}
}
base_model_prefix = "ernie"
......@@ -271,18 +271,23 @@ class ErnieForSequenceClassification(ErniePretrainedModel):
class ErnieForQuestionAnswering(ErniePretrainedModel):
def __init__(self, ernie, dropout=None):
def __init__(self, ernie):
super(ErnieForQuestionAnswering, self).__init__()
self.ernie = ernie # allow ernie to be config
self.classifier = nn.Linear(self.ernie.config["hidden_size"], 2)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None):
def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
sequence_output, _ = self.ernie(
input_ids,
token_type_ids=token_type_ids,
position_ids=None,
attention_mask=None)
position_ids=position_ids,
attention_mask=attention_mask)
logits = self.classifier(sequence_output)
logits = paddle.transpose(logits, perm=[2, 0, 1])
start_logits, end_logits = paddle.unstack(x=logits, axis=0)
......
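# Usage sketch (illustrative): weights can now be loaded by the renamed
# identifiers registered above, e.g. 'ernie-1.0', 'ernie-2.0-en' or
# 'ernie-2.0-large-en'; `input_ids` etc. are assumed to come from the tokenizer.
#
#     qa_model = ErnieForQuestionAnswering.from_pretrained('ernie-1.0')
#     outputs = qa_model(input_ids, token_type_ids,
#                        position_ids=position_ids,
#                        attention_mask=attention_mask)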
......@@ -53,22 +53,22 @@ class ErnieTokenizer(PretrainedTokenizer):
resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained
pretrained_resource_files_map = {
"vocab_file": {
"ernie":
"ernie-1.0":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie/vocab.txt",
"ernie_v2_eng_base":
"ernie-2.0-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_base/vocab.txt",
"ernie_v2_eng_large":
"ernie-2.0-large-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_v2_large/vocab.txt",
}
}
pretrained_init_configuration = {
"ernie": {
"ernie-1.0": {
"do_lower_case": True
},
"ernie_v2_eng_base": {
"ernie-2.0-en": {
"do_lower_case": True
},
"ernie_v2_eng_large": {
"ernie-2.0-large-en": {
"do_lower_case": True
},
}
......
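# Usage sketch (illustrative): the tokenizer follows the same renaming, so the
# new identifiers map to the vocab files listed above.
#
#     tokenizer = ErnieTokenizer.from_pretrained('ernie-2.0-en')
#     input_ids, segment_ids = tokenizer.encode('A short English sentence.')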