Unverified commit 99d39e52 authored by Noel, committed by GitHub

Add express task (#5024)

* Add Express Example

* Add Express Data

* Add Ernie for Express Example

* Add the express example for PaddleNLP
Co-authored-by: wanghuijuan03 <wanghuijuan03@baidu.com>
Parent e8139e93
0 P-B
1 P-I
2 T-B
3 T-I
4 A1-B
5 A1-I
6 A2-B
7 A2-I
8 A3-B
9 A3-I
10 A4-B
11 A4-I
12 O
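The listing above is the label dictionary shipped as conf/tag.dic: thirteen BIO-style tags, one per line, with an integer id and the tag name, presumably separated by a tab, since the load_dict helper in each script below splits on '\t'. A minimal sketch of what loading it yields, assuming the listing is stored at ./conf/tag.dic:

# Minimal sketch: read the tag dictionary the same way load_dict below does.
# Assumes the listing above is stored as ./conf/tag.dic with "<id>\t<tag>" per line.
label_vocab = {}
with open('./conf/tag.dic', 'r', encoding='utf-8') as f:
    for line in f:
        value, key = line.rstrip('\n').split('\t')
        label_vocab[key] = int(value)
# label_vocab -> {'P-B': 0, 'P-I': 1, 'T-B': 2, ..., 'A4-I': 11, 'O': 12}
# label_num   -> max(label_vocab.values()) + 1 == 13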
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.layers import LinearChainCrf, ViterbiDecoder, LinearChainCrfLoss
from paddlenlp.metrics import ChunkEvaluator
def parse_decodes(ds, decodes, lens):
decodes = [x for batch in decodes for x in batch]
lens = [x for batch in lens for x in batch]
id_word = dict(zip(ds.word_vocab.values(), ds.word_vocab.keys()))
id_label = dict(zip(ds.label_vocab.values(), ds.label_vocab.keys()))
outputs = []
for idx, end in enumerate(lens):
sent = [id_word[x] for x in ds.word_ids[idx][:end]]
tags = [id_label[x] for x in decodes[idx][:end]]
sent_out = []
tags_out = []
words = ""
for s, t in zip(sent, tags):
if t.endswith('-B') or t == 'O':
if len(words):
sent_out.append(words)
tags_out.append(t.split('-')[0])
words = s
else:
words += s
if len(sent_out) < len(tags_out):
sent_out.append(words)
outputs.append(''.join(
[str((s, t)) for s, t in zip(sent_out, tags_out)]))
return outputs
def convert_tokens_to_ids(tokens, vocab, oov_token=None):
token_ids = []
oov_id = vocab.get(oov_token) if oov_token else None
for token in tokens:
token_id = vocab.get(token, oov_id)
token_ids.append(token_id)
return token_ids
def load_dict(dict_path):
vocab = {}
for line in open(dict_path, 'r', encoding='utf-8'):
value, key = line.strip('\n').split('\t')
vocab[key] = int(value)
return vocab
class ExpressDataset(paddle.io.Dataset):
def __init__(self, data_path):
self.word_vocab = load_dict('./conf/word.dic')
self.label_vocab = load_dict('./conf/tag.dic')
self.word_ids = []
self.label_ids = []
with open(data_path, 'r', encoding='utf-8') as fp:
next(fp)
for line in fp.readlines():
words, labels = line.strip('\n').split('\t')
words = words.split('\002')
labels = labels.split('\002')
sub_word_ids = convert_tokens_to_ids(words, self.word_vocab,
'OOV')
sub_label_ids = convert_tokens_to_ids(labels, self.label_vocab,
'O')
self.word_ids.append(sub_word_ids)
self.label_ids.append(sub_label_ids)
self.word_num = max(self.word_vocab.values()) + 1
self.label_num = max(self.label_vocab.values()) + 1
def __len__(self):
return len(self.word_ids)
def __getitem__(self, index):
return self.word_ids[index], len(self.word_ids[index]), self.label_ids[
index]
class BiGRUWithCRF(nn.Layer):
def __init__(self, emb_size, hidden_size, word_num, label_num):
super(BiGRUWithCRF, self).__init__()
self.word_emb = nn.Embedding(word_num, emb_size)
self.gru = nn.GRU(emb_size,
hidden_size,
num_layers=2,
direction='bidirectional')
self.fc = nn.Linear(hidden_size * 2, label_num + 2) # BOS EOS
self.crf = LinearChainCrf(label_num)
self.decoder = ViterbiDecoder(self.crf.transitions)
def forward(self, x, lens):
embs = self.word_emb(x)
output, _ = self.gru(embs)
output = self.fc(output)
_, pred = self.decoder(output, lens)
return output, lens, pred
if __name__ == '__main__':
paddle.set_device('gpu')
train_ds = ExpressDataset('./data/train.txt')
dev_ds = ExpressDataset('./data/dev.txt')
test_ds = ExpressDataset('./data/test.txt')
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=train_ds.word_vocab.get('OOV')),
Stack(),
Pad(axis=0, pad_val=train_ds.label_vocab.get('O'))
): fn(samples)
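# batchify_fn pads the variable-length token-id sequences with the id of 'OOV',
# stacks the per-sample lengths, and pads the label sequences with the id of 'O',
# so every field of a mini-batch becomes a rectangular tensor.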
train_loader = paddle.io.DataLoader(
dataset=train_ds,
batch_size=200,
shuffle=True,
drop_last=True,
return_list=True,
collate_fn=batchify_fn)
dev_loader = paddle.io.DataLoader(
dataset=dev_ds,
batch_size=200,
drop_last=True,
return_list=True,
collate_fn=batchify_fn)
test_loader = paddle.io.DataLoader(
dataset=test_ds,
batch_size=200,
drop_last=True,
return_list=True,
collate_fn=batchify_fn)
network = BiGRUWithCRF(300, 300, train_ds.word_num, train_ds.label_num)
model = paddle.Model(network)
optimizer = paddle.optimizer.Adam(
learning_rate=0.002, parameters=model.parameters())
crf_loss = LinearChainCrfLoss(network.crf.transitions)
chunk_evaluator = ChunkEvaluator((train_ds.label_num + 2) // 2, 'IOB')
model.prepare(optimizer, crf_loss, chunk_evaluator)
model.fit(train_data=train_loader,
eval_data=dev_loader,
epochs=10,
save_dir='./results',
log_freq=1)
model.evaluate(eval_data=test_loader)
outputs, lens, decodes = model.predict(test_data=test_loader)
preds = parse_decodes(test_ds, decodes, lens)
print('\n'.join(preds[:10]))
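For reference, ExpressDataset above skips one header line (via next(fp)) and then expects every line of ./data/train.txt, dev.txt, and test.txt to carry the character sequence and the tag sequence, each joined with the '\002' separator and split from one another by a tab. A minimal sketch that writes one such line with made-up content (the real header text and samples are not part of this diff):

# Hypothetical sample illustrating the layout ExpressDataset parses; the header text,
# the characters, and the tags below are invented for illustration only.
chars = ['1', '3', '5', '0', '0', '0']             # hypothetical character sequence
tags = ['T-B', 'T-I', 'T-I', 'T-I', 'T-I', 'T-I']  # matching tags from conf/tag.dic
line = '\002'.join(chars) + '\t' + '\002'.join(tags)
with open('./data/example.txt', 'w', encoding='utf-8') as f:
    f.write('text_a\tlabel\n')  # header line; __init__ skips it via next(fp)
    f.write(line + '\n')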
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import paddle
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import ErnieTokenizer, ErniePretrainedModel, ErnieForTokenClassification
from paddlenlp.metrics import ChunkEvaluator
def parse_decodes(ds, decodes, lens):
decodes = [x for batch in decodes for x in batch]
lens = [x for batch in lens for x in batch]
id_label = dict(zip(ds.label_vocab.values(), ds.label_vocab.keys()))
outputs = []
for idx, end in enumerate(lens):
sent = ds.word_ids[idx][:end]
tags = [id_label[x] for x in decodes[idx][1:end]]
sent_out = []
tags_out = []
words = ""
for s, t in zip(sent, tags):
if t.endswith('-B') or t == 'O':
if len(words):
sent_out.append(words)
tags_out.append(t.split('-')[0])
words = s
else:
words += s
if len(sent_out) < len(tags_out):
sent_out.append(words)
outputs.append(''.join(
[str((s, t)) for s, t in zip(sent_out, tags_out)]))
return outputs
@paddle.no_grad()
def evaluate(model, metric, data_loader):
model.eval()
metric.reset()
for input_ids, seg_ids, lens, labels in data_loader:
logits = model(input_ids, seg_ids)
preds = paddle.argmax(logits, axis=-1)
n_infer, n_label, n_correct = metric.compute(None, lens, preds, labels)
metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy())
precision, recall, f1_score = metric.accumulate()
print("eval precision: %f - recall: %f - f1: %f" %
(precision, recall, f1_score))
def predict(model, data_loader, ds):
pred_list = []
len_list = []
for input_ids, seg_ids, lens, labels in data_loader:
logits = model(input_ids, seg_ids)
pred = paddle.argmax(logits, axis=-1)
pred_list.append(pred.numpy())
len_list.append(lens.numpy())
preds = parse_decodes(ds, pred_list, len_list)
print('\n'.join(preds[:10]))
def convert_example(example, tokenizer, label_vocab):
tokens, labels = example
tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
input_ids = tokenizer.convert_tokens_to_ids(tokens)
segment_ids = [0] * len(tokens)
lens = len(input_ids)
labels = ['O'] + labels + ['O']
labels = [label_vocab[x] for x in labels]
return input_ids, segment_ids, lens, labels
def load_dict(dict_path):
vocab = {}
for line in open(dict_path, 'r', encoding='utf-8'):
value, key = line.strip('\n').split('\t')
vocab[key] = int(value)
return vocab
class ExpressDataset(paddle.io.Dataset):
def __init__(self, data_path):
self.word_vocab = load_dict('./conf/word.dic')
self.label_vocab = load_dict('./conf/tag.dic')
self.word_ids = []
self.label_ids = []
with open(data_path, 'r', encoding='utf-8') as fp:
next(fp)
for line in fp.readlines():
words, labels = line.strip('\n').split('\t')
words = words.split('\002')
labels = labels.split('\002')
self.word_ids.append(words)
self.label_ids.append(labels)
self.word_num = max(self.word_vocab.values()) + 1
self.label_num = max(self.label_vocab.values()) + 1
def __len__(self):
return len(self.word_ids)
def __getitem__(self, index):
return self.word_ids[index], self.label_ids[index]
if __name__ == '__main__':
paddle.set_device('gpu')
train_ds = ExpressDataset('./data/train.txt')
dev_ds = ExpressDataset('./data/dev.txt')
test_ds = ExpressDataset('./data/test.txt')
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
trans_func = partial(
convert_example, tokenizer=tokenizer, label_vocab=train_ds.label_vocab)
ignore_label = -1
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]),
Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]),
Stack(),
Pad(axis=0, pad_val=ignore_label)
): fn(list(map(trans_func, samples)))
train_loader = paddle.io.DataLoader(
dataset=train_ds,
batch_size=200,
shuffle=True,
return_list=True,
collate_fn=batchify_fn)
dev_loader = paddle.io.DataLoader(
dataset=dev_ds,
batch_size=200,
return_list=True,
collate_fn=batchify_fn)
test_loader = paddle.io.DataLoader(
dataset=test_ds,
batch_size=200,
return_list=True,
collate_fn=batchify_fn)
model = ErnieForTokenClassification.from_pretrained(
"ernie-1.0", num_classes=train_ds.label_num)
metric = ChunkEvaluator((train_ds.label_num + 2) // 2, "IOB")
loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
optimizer = paddle.optimizer.AdamW(
learning_rate=2e-5, parameters=model.parameters())
step = 0
for epoch in range(10):
model.train()
for idx, (input_ids, segment_ids, length,
labels) in enumerate(train_loader):
logits = model(input_ids, segment_ids).reshape(
[-1, train_ds.label_num])
loss = paddle.mean(loss_fn(logits, labels.reshape([-1])))
loss.backward()
optimizer.step()
optimizer.clear_gradients()
step += 1
print("epoch:%d - step:%d - loss: %f" % (epoch, step, loss))
evaluate(model, metric, dev_loader)
paddle.save(model.state_dict(),
'./ernie_result/model_%d.pdparams' % step)
pred = predict(model, test_loader, test_ds)
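As a quick illustration of the preprocessing above: convert_example wraps the tokens in [CLS]/[SEP], maps them to ids with the ERNIE tokenizer, builds all-zero segment ids, records the resulting sequence length, and labels the two special tokens as 'O'. A minimal sketch on a hypothetical sample (the actual input ids depend on the tokenizer's vocabulary):

# Hypothetical input: three single-character tokens tagged as one phone-number chunk.
tokens = ['1', '3', '5']
labels = ['T-B', 'T-I', 'T-I']
input_ids, segment_ids, lens, label_ids = convert_example(
    (tokens, labels), tokenizer=tokenizer, label_vocab=train_ds.label_vocab)
# input_ids   -> ids of ['[CLS]', '1', '3', '5', '[SEP]'] in the ERNIE vocabulary
# segment_ids -> [0, 0, 0, 0, 0]
# lens        -> 5
# label_ids   -> ids of ['O', 'T-B', 'T-I', 'T-I', 'O'], i.e. [12, 2, 3, 3, 12]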
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import ErnieTokenizer, ErniePretrainedModel
from paddlenlp.layers.crf import LinearChainCrf, LinearChainCrfLoss, ViterbiDecoder
from paddlenlp.metrics import ChunkEvaluator
from paddle.static import InputSpec
from functools import partial
def parse_decodes(ds, decodes, lens):
decodes = [x for batch in decodes for x in batch]
lens = [x for batch in lens for x in batch]
id_label = dict(zip(ds.label_vocab.values(), ds.label_vocab.keys()))
outputs = []
for idx, end in enumerate(lens):
sent = ds.word_ids[idx][:end]
tags = [id_label[x] for x in decodes[idx][1:end]]
sent_out = []
tags_out = []
words = ""
for s, t in zip(sent, tags):
if t.endswith('-B') or t == 'O':
if len(words):
sent_out.append(words)
tags_out.append(t.split('-')[0])
words = s
else:
words += s
if len(sent_out) < len(tags_out):
sent_out.append(words)
outputs.append(''.join([str((s, t)) for s, t in zip(sent_out, tags_out)]))
return outputs
def convert_example(example, tokenizer, label_vocab):
tokens, labels = example
tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
input_ids = tokenizer.convert_tokens_to_ids(tokens)
segment_ids = [0] * len(tokens)
lens = len(input_ids)
labels = ['O'] + labels + ['O']
labels = [label_vocab[x] for x in labels]
return input_ids, segment_ids, lens, labels
def load_dict(dict_path):
vocab = {}
for line in open(dict_path, 'r', encoding='utf-8'):
value, key = line.strip('\n').split('\t')
vocab[key] = int(value)
return vocab
class ExpressDataset(paddle.io.Dataset):
def __init__(self, data_path):
self.word_vocab = load_dict('./conf/word.dic')
self.label_vocab = load_dict('./conf/tag.dic')
self.word_ids = []
self.label_ids = []
with open(data_path, 'r', encoding='utf-8') as fp:
next(fp)
for line in fp.readlines():
words, labels = line.strip('\n').split('\t')
words = words.split('\002')
labels = labels.split('\002')
self.word_ids.append(words)
self.label_ids.append(labels)
self.word_num = max(self.word_vocab.values()) + 1
self.label_num = max(self.label_vocab.values()) + 1
def __len__(self):
return len(self.word_ids)
def __getitem__(self, index):
return self.word_ids[index], self.label_ids[index]
class ErnieForTokenClassification(ErniePretrainedModel):
def __init__(self, ernie, num_classes=2, dropout=None):
super(ErnieForTokenClassification, self).__init__()
self.num_classes = num_classes
self.ernie = ernie
self.dropout = nn.Dropout(self.ernie.config["hidden_dropout_prob"])
self.classifier = nn.Linear(self.ernie.config["hidden_size"], num_classes)
self.apply(self.init_weights)
def forward(self,
input_ids,
token_type_ids=None,
lens=None,
position_ids=None,
attention_mask=None):
sequence_output, _ = self.ernie(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
return logits, lens, paddle.argmax(logits, axis=-1)
class ErnieCRF(ErnieForTokenClassification):
def __init__(self, ernie, num_classes=2, crf_lr=1.0, dropout=None):
super(ErnieCRF, self).__init__(ernie, num_classes, dropout)
self.crf = LinearChainCrf(num_classes, crf_lr, False)
if __name__ == '__main__':
paddle.set_device('gpu')
train_ds = ExpressDataset('./data/train.txt')
dev_ds = ExpressDataset('./data/dev.txt')
test_ds = ExpressDataset('./data/test.txt')
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_vocab=train_ds.label_vocab)
ignore_label = -100
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]),
Pad(axis=0, pad_val=tokenizer.vocab[tokenizer.pad_token]),
Stack(),
Pad(axis=0, pad_val=ignore_label)
): fn(list(map(trans_func, samples)))
train_loader = paddle.io.DataLoader(
dataset=train_ds,
batch_size=200,
shuffle=True,
return_list=True,
collate_fn=batchify_fn)
dev_loader = paddle.io.DataLoader(
dataset=dev_ds,
batch_size=200,
return_list=True,
collate_fn=batchify_fn)
test_loader = paddle.io.DataLoader(
dataset=test_ds,
batch_size=200,
return_list=True,
collate_fn=batchify_fn)
model = ErnieCRF.from_pretrained(
'ernie-1.0', num_classes=train_ds.label_num)
loss = LinearChainCrfLoss(transitions=model.crf.transitions)
decoder = ViterbiDecoder(transitions=model.crf.transitions)
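# Both the loss and the Viterbi decoder are built from model.crf.transitions, so
# decoding at prediction time uses the transition scores learned during training.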
metric = ChunkEvaluator((train_ds.label_num + 2) // 2, "IOB")
inputs = [InputSpec([None, None], dtype='int64', name='input_ids'),
InputSpec([None, None], dtype='int64', name='token_type_ids'),
InputSpec([None, None], dtype='int64', name='lens')]
model = paddle.Model(model, inputs)
optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=model.parameters())
model.prepare(optimizer, loss, metric)
model.fit(train_data=train_loader,
eval_data=dev_loader,
epochs=10,
save_dir='./results',
log_freq=1,
save_freq=10000)
model.evaluate(eval_data=test_loader)
outputs, lens, decodes = model.predict(test_data=test_loader)
pred = parse_decodes(test_ds, decodes, lens)
print('\n'.join(pred[:10]))
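Finally, parse_decodes (shared, with small variations, by all three scripts) folds the per-character BIO predictions back into (text, label) chunks: a tag ending in '-B', or the tag 'O', closes the current chunk and opens a new one, while any other tag extends the current chunk. A minimal sketch of that grouping on hypothetical characters:

# Hypothetical decode of one sentence; tag names follow conf/tag.dic above.
sent = ['广', '东', '省', '深', '圳', '市']
tags = ['A1-B', 'A1-I', 'A1-I', 'A2-B', 'A2-I', 'A2-I']
sent_out, tags_out, words = [], [], ''
for s, t in zip(sent, tags):
    if t.endswith('-B') or t == 'O':
        if len(words):
            sent_out.append(words)
        tags_out.append(t.split('-')[0])
        words = s
    else:
        words += s
if len(sent_out) < len(tags_out):
    sent_out.append(words)
print(''.join(str(p) for p in zip(sent_out, tags_out)))
# -> ('广东省', 'A1')('深圳市', 'A2')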