From fa0ef7fb6f5713e7589a5db717d2c34cb3a957aa Mon Sep 17 00:00:00 2001
From: kinghuin
Date: Tue, 29 Dec 2020 19:26:56 +0800
Subject: [PATCH] new chunk_eval (#5138) (#5156)

* new chunk_eval

* np first
---
 PaddleNLP/examples/lexical_analysis/train.py | 3 +-
 .../express_ner/run_bigru_crf.py | 3 +-
 .../express_ner/run_ernie.py | 2 +-
 .../msra_ner/README.md | 18 ++---
 .../msra_ner/run_msra_ner.py | 3 +-
 PaddleNLP/paddlenlp/metrics/chunk.py | 65 ++++++++++++++-----
 PaddleNLP/requirements.txt | 1 +
 7 files changed, 63 insertions(+), 32 deletions(-)

diff --git a/PaddleNLP/examples/lexical_analysis/train.py b/PaddleNLP/examples/lexical_analysis/train.py
index 1adecda6..ffda35b5 100644
--- a/PaddleNLP/examples/lexical_analysis/train.py
+++ b/PaddleNLP/examples/lexical_analysis/train.py
@@ -95,8 +95,7 @@ def train(args):
         learning_rate=args.base_lr, parameters=model.parameters())
     crf_loss = LinearChainCrfLoss(network.crf.transitions)
     chunk_evaluator = ChunkEvaluator(
-        int(math.ceil((train_dataset.num_labels + 1) / 2.0)),
-        "IOB")  # + 1 for START and STOP
+        label_list=train_dataset.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
     if args.init_checkpoint:
         model.load(args.init_checkpoint)
diff --git a/PaddleNLP/examples/named_entity_recognition/express_ner/run_bigru_crf.py b/PaddleNLP/examples/named_entity_recognition/express_ner/run_bigru_crf.py
index 8cc4105e..f8969230 100644
--- a/PaddleNLP/examples/named_entity_recognition/express_ner/run_bigru_crf.py
+++ b/PaddleNLP/examples/named_entity_recognition/express_ner/run_bigru_crf.py
@@ -165,7 +165,8 @@ if __name__ == '__main__':
     optimizer = paddle.optimizer.Adam(
         learning_rate=0.001, parameters=model.parameters())
     crf_loss = LinearChainCrfLoss(network.crf.transitions)
-    chunk_evaluator = ChunkEvaluator((train_ds.label_num + 2) // 2, 'IOB')
+    chunk_evaluator = ChunkEvaluator(
+        label_list=train_ds.label_vocab.keys(), suffix=True)
     model.prepare(optimizer, crf_loss, chunk_evaluator)
 
     model.fit(train_data=train_loader,
diff --git a/PaddleNLP/examples/named_entity_recognition/express_ner/run_ernie.py b/PaddleNLP/examples/named_entity_recognition/express_ner/run_ernie.py
index e2dd2a1c..33389731 100644
--- a/PaddleNLP/examples/named_entity_recognition/express_ner/run_ernie.py
+++ b/PaddleNLP/examples/named_entity_recognition/express_ner/run_ernie.py
@@ -154,7 +154,7 @@ if __name__ == '__main__':
     model = ErnieForTokenClassification.from_pretrained(
         "ernie-1.0", num_classes=train_ds.label_num)
 
-    metric = ChunkEvaluator((train_ds.label_num + 2) // 2, "IOB")
+    metric = ChunkEvaluator(label_list=train_ds.label_vocab.keys(), suffix=True)
     loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
     optimizer = paddle.optimizer.AdamW(
         learning_rate=2e-5, parameters=model.parameters())
diff --git a/PaddleNLP/examples/named_entity_recognition/msra_ner/README.md b/PaddleNLP/examples/named_entity_recognition/msra_ner/README.md
index 8a3e3ee6..3f570683 100644
--- a/PaddleNLP/examples/named_entity_recognition/msra_ner/README.md
+++ b/PaddleNLP/examples/named_entity_recognition/msra_ner/README.md
@@ -52,20 +52,20 @@ python -u ./run_msra_ner.py \
 During training, logs like the following are printed according to the `logging_steps` and `save_steps` settings:
 
 ```
-global step 996, epoch: 1, batch: 344, loss: 0.038471, speed: 4.72 step/s
-global step 997, epoch: 1, batch: 345, loss: 0.032820, speed: 4.82 step/s
-global step 998, epoch: 1, batch: 346, loss: 0.008144, speed: 4.69 step/s
-global step 999, epoch: 1, batch: 347, loss: 0.031425, speed: 4.36 step/s
-global step 1000, epoch: 1, batch: 348, loss: 0.073151, speed: 4.59 step/s
-eval loss: 0.019874, precision: 0.991670, recall: 0.991930, f1: 0.991800
+global step 1496, epoch: 2, batch: 192, loss: 0.010747, speed: 4.77 step/s
+global step 1497, epoch: 2, batch: 193, loss: 0.004837, speed: 4.46 step/s
+global step 1498, epoch: 2, batch: 194, loss: 0.011281, speed: 4.24 step/s
+global step 1499, epoch: 2, batch: 195, loss: 0.005711, speed: 4.73 step/s
+global step 1500, epoch: 2, batch: 196, loss: 0.003150, speed: 4.52 step/s
+eval loss: 0.010307, precision: 0.884222, recall: 0.903190, f1: 0.893605
 ```
 
 Fine-tuning on a single card with the above command gives the following results on the validation set:
 
 Metric | Result |
 ------------------------------|-------------|
-precision | 0.992903 |
-recall | 0.991823 |
-f1 | 0.992363 |
+precision | 0.884222 |
+recall | 0.903190 |
+f1 | 0.893605 |
 
 ## References
diff --git a/PaddleNLP/examples/named_entity_recognition/msra_ner/run_msra_ner.py b/PaddleNLP/examples/named_entity_recognition/msra_ner/run_msra_ner.py
index 3c68f524..943f6aee 100644
--- a/PaddleNLP/examples/named_entity_recognition/msra_ner/run_msra_ner.py
+++ b/PaddleNLP/examples/named_entity_recognition/msra_ner/run_msra_ner.py
@@ -313,7 +313,8 @@ def do_train(args):
     ])
 
     loss_fct = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
-    metric = ChunkEvaluator(int(math.ceil((label_num + 1) / 2.0)), "IOB")
+
+    metric = ChunkEvaluator(label_list=train_dataset.get_labels())
 
     global_step = 0
     tic_train = time.time()
diff --git a/PaddleNLP/paddlenlp/metrics/chunk.py b/PaddleNLP/paddlenlp/metrics/chunk.py
index e592b485..a3ae5dbe 100644
--- a/PaddleNLP/paddlenlp/metrics/chunk.py
+++ b/PaddleNLP/paddlenlp/metrics/chunk.py
@@ -1,5 +1,31 @@
+from collections import defaultdict
+
 import numpy as np
 import paddle
+from seqeval.metrics.sequence_labeling import get_entities
+
+
+def extract_tp_actual_correct(y_true, y_pred, suffix, *args):
+    entities_true = defaultdict(set)
+    entities_pred = defaultdict(set)
+    for type_name, start, end in get_entities(y_true, suffix):
+        entities_true[type_name].add((start, end))
+    for type_name, start, end in get_entities(y_pred, suffix):
+        entities_pred[type_name].add((start, end))
+
+    target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
+
+    tp_sum = np.array([], dtype=np.int32)
+    pred_sum = np.array([], dtype=np.int32)
+    true_sum = np.array([], dtype=np.int32)
+    for type_name in target_names:
+        entities_true_type = entities_true.get(type_name, set())
+        entities_pred_type = entities_pred.get(type_name, set())
+        tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
+        pred_sum = np.append(pred_sum, len(entities_pred_type))
+        true_sum = np.append(true_sum, len(entities_true_type))
+
+    return pred_sum, tp_sum, true_sum
 
 
 class ChunkEvaluator(paddle.metric.Metric):
@@ -7,32 +33,35 @@
     It is often used in sequence tagging tasks, such as Named Entity Recognition(NER).
 
     Args:
-        num_chunk_types (int): The number of chunk types.
-        chunk_scheme (str): Indicate the tagging schemes used here. The value must
-            be IOB, IOE, IOBES or plain.
-        excluded_chunk_types (list, optional): Indicate the chunk types shouldn't
-            be taken into account. It should be a list of chunk type ids(integer).
-            Default None.
+        label_list (list): The label list; the index of each label is used as its id.
+        suffix (bool): If True, labels end with '-B', '-I', '-E' or '-S' (e.g. 'PER-B'); otherwise they start with 'B-', 'I-', 'E-' or 'S-' (e.g. 'B-PER'). Default False.
""" - def __init__(self, num_chunk_types, chunk_scheme, - excluded_chunk_types=None): + def __init__(self, label_list, suffix=False): super(ChunkEvaluator, self).__init__() - self.num_chunk_types = num_chunk_types - self.chunk_scheme = chunk_scheme - self.excluded_chunk_types = excluded_chunk_types + self.id2label_dict = dict(enumerate(label_list)) + self.suffix = suffix self.num_infer_chunks = 0 self.num_label_chunks = 0 self.num_correct_chunks = 0 def compute(self, inputs, lengths, predictions, labels): - precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = paddle.metric.chunk_eval( - predictions, - labels, - chunk_scheme=self.chunk_scheme, - num_chunk_types=self.num_chunk_types, - excluded_chunk_types=self.excluded_chunk_types, - seq_length=lengths) + labels = labels.numpy() + predictions = predictions.numpy() + unpad_labels = [[ + self.id2label_dict[index] + for index in labels[sent_index][:lengths[sent_index]] + ] for sent_index in range(len(lengths))] + unpad_predictions = [[ + self.id2label_dict.get(index, "O") + for index in predictions[sent_index][:lengths[sent_index]] + ] for sent_index in range(len(lengths))] + + pred_sum, tp_sum, true_sum = extract_tp_actual_correct( + unpad_labels, unpad_predictions, self.suffix) + num_correct_chunks = paddle.to_tensor([tp_sum.sum()]) + num_infer_chunks = paddle.to_tensor([pred_sum.sum()]) + num_label_chunks = paddle.to_tensor([true_sum.sum()]) return num_infer_chunks, num_label_chunks, num_correct_chunks diff --git a/PaddleNLP/requirements.txt b/PaddleNLP/requirements.txt index a086afa5..67002515 100644 --- a/PaddleNLP/requirements.txt +++ b/PaddleNLP/requirements.txt @@ -3,3 +3,4 @@ jieba h5py colorlog colorama +seqeval \ No newline at end of file -- GitLab