Unverified commit fa0ef7fb authored by kinghuin, committed by GitHub

new chunk_eval (#5138) (#5156)

* new chunk_eval

* np first
Parent 08c08300
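The core of the commit: `ChunkEvaluator` drops its `paddle.metric.chunk_eval`-style arguments (`num_chunk_types`, `chunk_scheme`) in favor of a seqeval-backed, label-list-driven constructor. A minimal sketch of the new call style follows; the label names are invented and the import path is assumed from the PaddleNLP layout rather than shown in this diff:

```python
# Hedged sketch of the constructor change; labels are illustrative and the
# import path is an assumption, not confirmed by this diff.
from paddlenlp.metrics import ChunkEvaluator

label_list = ["B-PER", "I-PER", "B-LOC", "I-LOC", "O"]

# Before this commit (scheme-based, backed by paddle.metric.chunk_eval):
#     evaluator = ChunkEvaluator(num_chunk_types=2, chunk_scheme="IOB")
# After: chunk types are derived from the label names themselves;
# suffix=False means tags look like "B-PER", suffix=True means "PER-B".
evaluator = ChunkEvaluator(label_list=label_list, suffix=False)
```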
@@ -95,8 +95,7 @@ def train(args):
         learning_rate=args.base_lr, parameters=model.parameters())
     crf_loss = LinearChainCrfLoss(network.crf.transitions)
     chunk_evaluator = ChunkEvaluator(
-        int(math.ceil((train_dataset.num_labels + 1) / 2.0)),
-        "IOB")  # + 1 for START and STOP
+        label_list=train_dataset.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
     if args.init_checkpoint:
         model.load(args.init_checkpoint)
......
@@ -165,7 +165,8 @@ if __name__ == '__main__':
     optimizer = paddle.optimizer.Adam(
         learning_rate=0.001, parameters=model.parameters())
     crf_loss = LinearChainCrfLoss(network.crf.transitions)
-    chunk_evaluator = ChunkEvaluator((train_ds.label_num + 2) // 2, 'IOB')
+    chunk_evaluator = ChunkEvaluator(
+        label_list=train_ds.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
     model.fit(train_data=train_loader,
......
@@ -154,7 +154,7 @@ if __name__ == '__main__':
     model = ErnieForTokenClassification.from_pretrained(
         "ernie-1.0", num_classes=train_ds.label_num)
-    metric = ChunkEvaluator((train_ds.label_num + 2) // 2, "IOB")
+    metric = ChunkEvaluator(label_list=train_ds.label_vocab.keys(), suffix=True)
     loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
     optimizer = paddle.optimizer.AdamW(
         learning_rate=2e-5, parameters=model.parameters())
......
@@ -52,20 +52,20 @@ python -u ./run_msra_ner.py \
 During training, logs like the following are printed according to the `logging_steps` and `save_steps` settings:
 ```
-global step 996, epoch: 1, batch: 344, loss: 0.038471, speed: 4.72 step/s
-global step 997, epoch: 1, batch: 345, loss: 0.032820, speed: 4.82 step/s
-global step 998, epoch: 1, batch: 346, loss: 0.008144, speed: 4.69 step/s
-global step 999, epoch: 1, batch: 347, loss: 0.031425, speed: 4.36 step/s
-global step 1000, epoch: 1, batch: 348, loss: 0.073151, speed: 4.59 step/s
-eval loss: 0.019874, precision: 0.991670, recall: 0.991930, f1: 0.991800
+global step 1496, epoch: 2, batch: 192, loss: 0.010747, speed: 4.77 step/s
+global step 1497, epoch: 2, batch: 193, loss: 0.004837, speed: 4.46 step/s
+global step 1498, epoch: 2, batch: 194, loss: 0.011281, speed: 4.24 step/s
+global step 1499, epoch: 2, batch: 195, loss: 0.005711, speed: 4.73 step/s
+global step 1500, epoch: 2, batch: 196, loss: 0.003150, speed: 4.52 step/s
+eval loss: 0.010307, precision: 0.884222, recall: 0.903190, f1: 0.893605
 ```
 Fine-tuning on a single GPU with the command above gives the following results on the validation set:
 Metric    | Result   |
 ----------|----------|
-precision | 0.992903 |
-recall    | 0.991823 |
-f1        | 0.992363 |
+precision | 0.884222 |
+recall    | 0.903190 |
+f1        | 0.893605 |
 ## References
......
@@ -313,7 +313,8 @@ def do_train(args):
     ])
     loss_fct = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
-    metric = ChunkEvaluator(int(math.ceil((label_num + 1) / 2.0)), "IOB")
+    metric = ChunkEvaluator(label_list=train_dataset.get_labels())
     global_step = 0
     tic_train = time.time()
......
+from collections import defaultdict
+
+import numpy as np
+import paddle
+from seqeval.metrics.sequence_labeling import get_entities
+
+
+def extract_tp_actual_correct(y_true, y_pred, suffix, *args):
+    entities_true = defaultdict(set)
+    entities_pred = defaultdict(set)
+    for type_name, start, end in get_entities(y_true, suffix):
+        entities_true[type_name].add((start, end))
+    for type_name, start, end in get_entities(y_pred, suffix):
+        entities_pred[type_name].add((start, end))
+    target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
+    tp_sum = np.array([], dtype=np.int32)
+    pred_sum = np.array([], dtype=np.int32)
+    true_sum = np.array([], dtype=np.int32)
+    for type_name in target_names:
+        entities_true_type = entities_true.get(type_name, set())
+        entities_pred_type = entities_pred.get(type_name, set())
+        tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
+        pred_sum = np.append(pred_sum, len(entities_pred_type))
+        true_sum = np.append(true_sum, len(entities_true_type))
+    return pred_sum, tp_sum, true_sum
+
+
 class ChunkEvaluator(paddle.metric.Metric):
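All of the metric's counting now flows through the helper added above, so a toy check makes its contract concrete. A sketch with invented IOB tags (requires seqeval installed):

```python
# Toy check of extract_tp_actual_correct; the tags are invented.
y_true = [["B-PER", "I-PER", "O", "B-LOC"]]  # gold: PER(0,1), LOC(3,3)
y_pred = [["B-PER", "I-PER", "O", "B-PER"]]  # pred: PER(0,1), PER(3,3)

pred_sum, tp_sum, true_sum = extract_tp_actual_correct(y_true, y_pred, False)
# One entry per sorted type name, here ["LOC", "PER"]:
print(pred_sum)  # [0 2] -> no LOC predicted, two PER chunks predicted
print(tp_sum)    # [0 1] -> only the PER span (0, 1) matches exactly
print(true_sum)  # [1 1] -> gold holds one LOC chunk and one PER chunk
```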
@@ -7,32 +33,35 @@ class ChunkEvaluator(paddle.metric.Metric):
     It is often used in sequence tagging tasks, such as Named Entity Recognition (NER).

     Args:
-        num_chunk_types (int): The number of chunk types.
-        chunk_scheme (str): Indicate the tagging schemes used here. The value must
-            be IOB, IOE, IOBES or plain.
-        excluded_chunk_types (list, optional): Indicate the chunk types that shouldn't
-            be taken into account. It should be a list of chunk type ids (integers).
-            Default None.
+        label_list (list): The label list.
+        suffix (bool): If True, labels end with '-B', '-I', '-E' or '-S';
+            otherwise they start with those prefixes. Defaults to False.
     """

-    def __init__(self, num_chunk_types, chunk_scheme,
-                 excluded_chunk_types=None):
+    def __init__(self, label_list, suffix=False):
         super(ChunkEvaluator, self).__init__()
-        self.num_chunk_types = num_chunk_types
-        self.chunk_scheme = chunk_scheme
-        self.excluded_chunk_types = excluded_chunk_types
+        self.id2label_dict = dict(enumerate(label_list))
+        self.suffix = suffix
         self.num_infer_chunks = 0
         self.num_label_chunks = 0
         self.num_correct_chunks = 0

     def compute(self, inputs, lengths, predictions, labels):
-        precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = paddle.metric.chunk_eval(
-            predictions,
-            labels,
-            chunk_scheme=self.chunk_scheme,
-            num_chunk_types=self.num_chunk_types,
-            excluded_chunk_types=self.excluded_chunk_types,
-            seq_length=lengths)
+        labels = labels.numpy()
+        predictions = predictions.numpy()
+        unpad_labels = [[
+            self.id2label_dict[index]
+            for index in labels[sent_index][:lengths[sent_index]]
+        ] for sent_index in range(len(lengths))]
+        unpad_predictions = [[
+            self.id2label_dict.get(index, "O")
+            for index in predictions[sent_index][:lengths[sent_index]]
+        ] for sent_index in range(len(lengths))]
+        pred_sum, tp_sum, true_sum = extract_tp_actual_correct(
+            unpad_labels, unpad_predictions, self.suffix)
+        num_correct_chunks = paddle.to_tensor([tp_sum.sum()])
+        num_infer_chunks = paddle.to_tensor([pred_sum.sum()])
+        num_label_chunks = paddle.to_tensor([true_sum.sum()])
         return num_infer_chunks, num_label_chunks, num_correct_chunks
......
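Putting the pieces together, a hedged end-to-end sketch of the rewritten `compute` on a padded toy batch; the label ids, lengths, and tensors below are invented, and it assumes the `ChunkEvaluator` defined above is in scope:

```python
import paddle

# Toy vocabulary: id 0 -> "B-ORG", 1 -> "I-ORG", 2 -> "O" (invented).
evaluator = ChunkEvaluator(label_list=["B-ORG", "I-ORG", "O"], suffix=False)

# Two sequences padded to length 4; the real lengths are 3 and 2.
lengths = paddle.to_tensor([3, 2])
predictions = paddle.to_tensor([[0, 1, 2, 2], [2, 0, 2, 2]])
labels = paddle.to_tensor([[0, 1, 2, 2], [0, 1, 2, 2]])

# `inputs` is unused by the new compute, so None suffices here.
num_infer, num_label, num_correct = evaluator.compute(
    None, lengths, predictions, labels)
print(num_infer.numpy(), num_label.numpy(), num_correct.numpy())
# -> [2] [2] [1]: two chunks predicted, two in the gold labels, one match.
```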
@@ -3,3 +3,4 @@ jieba
 h5py
 colorlog
 colorama
+seqeval
\ No newline at end of file