未验证 提交 a2721553 编写于 作者: L LiuChiachi 提交者: GitHub

Update metric bleu and rouge-l, inherited from Metric (#4982)

* update rouge-l, move default_trans_func to utils

* fix error info bugs

* add score method for metric bleu and rouge-l

* update bleu doc
上级 d3029c01
......@@ -14,5 +14,6 @@
from .perplexity import Perplexity
from .chunk import ChunkEvaluator
from .bleu import BLEU
from .bleu import BLEU, BLEUForDuReader
from .rouge import RougeL, RougeLForDuReader
from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman
......@@ -16,6 +16,12 @@ import math
import sys
from collections import defaultdict
import paddle
from .utils import default_trans_func
__all__ = ["BLEU", "BLEUForDuReader"]
def get_match_size(cand_ngram, refs_ngram):
ref_set = defaultdict(int)
......@@ -48,12 +54,22 @@ def get_ngram(sent, n_size, label=None):
return ngram_list
class BLEU(object):
class BLEU(paddle.metric.Metric):
r'''
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of
text which has been machine-translated from one natural language to another. This metric
uses a modified form of precision to compare a candidate translation against multiple
reference translations.
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the
quality of text which has been machine-translated from one natural language
to another. This metric uses a modified form of precision to compare a
candidate translation against multiple reference translations.
BLEU could be used as `paddle.metric.Metric` class, or an ordinary
class. When BLEU is used as a `paddle.metric.Metric` class, a function is
needed that transforms the network output to reference string list, and
transforms the label to candidate string. By default, a default function
`default_trans_func` is provided, which gets target sequence id by
calculating the maximum probability of each step. In this case, user must
provide `vocab`. It should be noted that the BLEU here is different from
the BLEU calculated in prediction, and it is only for observation during
training and evaluation.
.. math::
......@@ -68,24 +84,72 @@ class BLEU(object):
where `c` is the length of the candidate sentence, and `r` is the length of the reference sentence.
Args:
n_size (int): Number of gram for BLEU metric. Default: 4.
weights (list, optional): The weights of precision of each gram. Default: None.
trans_func (callable, optional): `trans_func` transforms the network
output to string to calculate.
vocab (dict|paddlenlp.data.vocab, optional): Vocab for target language.
If `trans_func` is None and BLEU is used as `paddle.metric.Metric`
instance, `default_trans_func` will be performed and `vocab` must
be provided.
n_size (int, optional): Number of gram for BLEU metric. Default: 4.
weights (list, optional): The weights of precision of each gram.
Default: None.
name (str, optional): Name of `paddle.metric.Metric` instance.
Default: "bleu".
Examples:
1. Using as a general evaluation object.
.. code-block:: python
from paddlenlp.metrics import BLEU
bleu = BLEU()
cand = "Welcome to PaddleNLP."
ref_list = ["Welcome PaddleNLP"]
bleu.add_inst(cand, ref_list)
print(bleu.score()) # 0.7510186074254295
2. Using as an instance of `paddle.metric.Metric`.
.. code-block:: python
# TODO(liujiaqi)
'''
def __init__(self, n_size=4, weights=None):
def __init__(self,
             trans_func=None,
             vocab=None,
             n_size=4,
             weights=None,
             name="bleu"):
    '''
    Set up an empty BLEU accumulator.

    Args:
        trans_func (callable, optional): Stored as-is; `update` calls it
            with `(output, label, seq_mask)` when it is not None.
        vocab (optional): Stored as-is; `update` forwards it to
            `default_trans_func` when `trans_func` is None.
        n_size (int, optional): Highest n-gram order. Default: 4.
        weights (list, optional): One weight per n-gram order. A falsy
            value (None or an empty list) selects uniform weights
            `1 / n_size`. Default: None.
        name (str, optional): Value returned by `name()`. Default: "bleu".

    Raises:
        AssertionError: If `len(weights) != n_size`.
    '''
    super(BLEU, self).__init__()
    if not weights:
        # Uniform weighting over the n-gram orders.
        weights = [1 / n_size for _ in range(n_size)]
    assert len(weights) == n_size, (
        "Number of weights and n-gram should be the same, got Number of weights: '%d' and n-gram: '%d'"
        % (len(weights), n_size))
    self._name = name
    # Per-order n-gram statistics, filled elsewhere (see `add_inst`).
    self.match_ngram = {}
    self.candi_ngram = {}
    self.weights = weights
    # Length totals used by the brevity penalty in `accumulate`.
    self.bp_r = 0
    self.bp_c = 0
    self.n_size = n_size
    self.vocab = vocab
    self.trans_func = trans_func
def update(self, output, label, seq_mask=None):
    '''
    Convert one batch to strings and add every instance to the metric.

    If `self.trans_func` is set it is called as
    `self.trans_func(output, label, seq_mask)`; otherwise
    `default_trans_func(output, label, seq_mask, self.vocab)` is used.
    Either way the call must return `(cand_list, ref_list)` of equal
    length; each pair is then passed to `self.add_inst`.

    Args:
        output: Network output for the batch, forwarded unchanged.
        label: Ground-truth labels for the batch, forwarded unchanged.
        seq_mask (optional): Sequence mask, forwarded unchanged.
            Default: None.

    Raises:
        AttributeError: If neither `trans_func` nor `vocab` was provided.
        ValueError: If the two returned lists differ in length.
    '''
    if self.trans_func is None:
        if self.vocab is None:
            raise AttributeError(
                "The `update` method requires users to provide `trans_func` or `vocab` when initializing BLEU."
            )
        cand_list, ref_list = default_trans_func(output, label, seq_mask,
                                                 self.vocab)
    else:
        cand_list, ref_list = self.trans_func(output, label, seq_mask)
    # Checked explicitly because zip() below would silently truncate.
    if len(cand_list) != len(ref_list):
        raise ValueError(
            "Length error! Please check the output of network.")
    for cand, refs in zip(cand_list, ref_list):
        self.add_inst(cand, refs)
def add_inst(self, cand, ref_list):
'''
......@@ -117,7 +181,13 @@ class BLEU(object):
self.bp_r += min([(abs(len(cand) - len(ref)), len(ref))
for ref in ref_list])[1]
def score(self):
def reset(self):
    '''
    Discard everything accumulated so far so a new evaluation can start.
    '''
    # Fresh dicts rather than .clear(), so any outside reference to the
    # old statistics is left untouched.
    self.match_ngram = {}
    self.candi_ngram = {}
    # Length totals consumed by the brevity penalty in `accumulate`.
    self.bp_r = 0
    self.bp_c = 0
def accumulate(self):
'''
Calculate the final bleu metric.
'''
......@@ -132,16 +202,21 @@ class BLEU(object):
except:
_score = 0
if _score == 0:
_score = w_i * math.log(sys.float_info.min)
_score = sys.float_info.min
prob_list.append(_score)
logs = math.fsum(w_i * math.log(p_i)
for w_i, p_i in zip(self.weights, prob_list))
bp = math.exp(min(1 - self.bp_r / float(self.bp_c), 0))
bleu = bp * math.exp(logs)
return bleu
def score(self):
    '''
    Return the current metric value; an alias for `accumulate()`.
    '''
    return self.accumulate()
def name(self):
    '''
    Return the metric name given at construction time.
    '''
    return self._name
class BLEUForDuReader(BLEU):
'''
......@@ -161,7 +236,6 @@ class BLEUForDuReader(BLEU):
yn_label=None,
yn_ref=None,
entity_ref=None):
#super(BLEUWithBonus, self).add_inst(cand, ref_list)
BLEU.add_inst(self, cand, ref_list)
if yn_label is not None and yn_ref is not None:
self.add_yn_bonus(cand, ref_list, yn_label, yn_ref)
......
......@@ -14,8 +14,13 @@
import numpy as np
import paddle
from .utils import default_trans_func
class RougeL(object):
__all__ = ['RougeL', 'RougeLForDuReader']
class RougeL(paddle.metric.Metric):
r'''
Rouge-L is Recall-Oriented Understudy for Gisting Evaluation based on Longest Common Subsequence (LCS).
Longest common subsequence problem takes into account sentence level structure
......@@ -34,13 +39,31 @@ class RougeL(object):
Args:
gamma (float): A hyperparameter to decide the weight of recall. Default: 1.2.
Examples:(TODO: liujiaqi)
1. Using as a general evaluation object.
2. Using as an instance of `paddle.metric.Metric`.
'''
def __init__(self, gamma=1.2):
def __init__(self,
             trans_func=None,
             vocab=None,
             gamma=1.2,
             name="rouge-l",
             *args,
             **kwargs):
    '''
    Set up an empty Rouge-L accumulator.

    Args:
        trans_func (callable, optional): Stored as-is; `update` calls it
            with `(output, label, seq_mask)` when it is not None.
        vocab (optional): Stored as-is; `update` forwards it to
            `default_trans_func` when `trans_func` is None.
        gamma (float, optional): Hyperparameter deciding the weight of
            recall. Default: 1.2.
        name (str, optional): Value returned by `name()`.
            Default: "rouge-l".
        *args, **kwargs: Forwarded unchanged to the base-class
            constructor.
    '''
    super(RougeL, self).__init__(*args, **kwargs)
    self.gamma = gamma
    # One score per instance added so far; averaged by `accumulate`.
    self.inst_scores = []
    self._name = name
    self.vocab = vocab
    self.trans_func = trans_func
def lcs(self, string, sub):
"""
Calculate the length of longest common subsequence of string and sub.
"""
if len(string) < len(sub):
sub, string = string, sub
lengths = np.zeros((len(string) + 1, len(sub) + 1))
......@@ -78,12 +101,37 @@ class RougeL(object):
score = 0.0
self.inst_scores.append(score)
def score(self):
def update(self, output, label, seq_mask=None):
    '''
    Convert one batch to strings and add every instance to the metric.

    If `self.trans_func` is set it is called as
    `self.trans_func(output, label, seq_mask)`; otherwise
    `default_trans_func(output, label, seq_mask, self.vocab)` is used.
    Either way the call must return `(cand_list, ref_list)` of equal
    length; each pair is then passed to `self.add_inst`.

    Args:
        output: Network output for the batch, forwarded unchanged.
        label: Ground-truth labels for the batch, forwarded unchanged.
        seq_mask (optional): Sequence mask, forwarded unchanged.
            Default: None.

    Raises:
        AttributeError: If neither `trans_func` nor `vocab` was provided.
        ValueError: If the two returned lists differ in length.
    '''
    if self.trans_func is None:
        if self.vocab is None:
            raise AttributeError(
                "The `update` method requires users to provide `trans_func` or `vocab` when initializing RougeL."
            )
        cand_list, ref_list = default_trans_func(output, label, seq_mask,
                                                 self.vocab)
    else:
        cand_list, ref_list = self.trans_func(output, label, seq_mask)
    # Checked explicitly because zip() below would silently truncate.
    if len(cand_list) != len(ref_list):
        raise ValueError(
            "Length error! Please check the output of network.")
    for cand, refs in zip(cand_list, ref_list):
        self.add_inst(cand, refs)
def accumulate(self):
    '''
    Calculate the final rouge-l metric.

    Returns:
        float: The arithmetic mean of the per-instance scores collected
        so far, or 0.0 when no instance has been added yet (for example
        directly after `reset`), instead of dividing by zero.
    '''
    if not self.inst_scores:
        return 0.0
    return 1. * sum(self.inst_scores) / len(self.inst_scores)
def score(self):
    '''
    Return the current metric value; an alias for `accumulate()`.
    '''
    return self.accumulate()
def reset(self):
    '''
    Discard all per-instance scores collected so far.
    '''
    self.inst_scores = []
def name(self):
    '''
    Return the metric name given at construction time.
    '''
    return self._name
class RougeLForDuReader(RougeL):
'''
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def default_trans_func(output, label, seq_mask, vocab):
    """Decode a batch of network outputs and labels into token strings.

    The candidate for each sample is decoded from `output` by taking the
    arg-max over the last axis at every step; the reference is decoded
    from `label`. Decoding of a sample stops at the first step whose mask
    value is 0. Tokens are joined with a single space.

    Args:
        output: Array-like indexed as (batch, step, class); converted
            with `np.asarray`.
        label: Array-like of token ids, either (batch, step, 1) or
            (batch, step).
        seq_mask: Array-like (batch, step) where 0 marks padding, or None
            to treat every step as valid.
        vocab: Object indexable by token id that yields the token string
            (presumably an id-to-token dict or list; verify for other
            vocabulary types).

    Returns:
        tuple: `(cand, ref_list)` where `cand[i]` is the decoded model
        output of sample `i` and `ref_list[i]` is a one-element list
        holding the decoded label of sample `i`.
    """
    output = np.asarray(output)
    label = np.asarray(label)
    if label.ndim == 3:
        # Drop the trailing singleton axis; 2-D labels are used as they are.
        label = np.squeeze(label, axis=2)

    if seq_mask is None:
        mask = None
        idx = np.argmax(output, axis=2)
    else:
        mask = np.asarray(seq_mask)
        # Broadcasting the mask over the class axis is equivalent to
        # repeating it output.shape[2] times.
        idx = np.argmax(output * np.expand_dims(mask, axis=2), axis=2)

    # Candidates come from the model output, references from the label.
    cand = _decode_rows(idx, mask, vocab)
    ref_list = [[ref] for ref in _decode_rows(label, mask, vocab)]
    return cand, ref_list


def _decode_rows(ids, mask, vocab):
    """Map each row of token ids to a space-joined string, stopping at mask 0."""
    decoded = []
    for i, row in enumerate(ids):
        tokens = []
        for j, token_id in enumerate(row):
            if mask is not None and mask[i][j] == 0:
                break
            tokens.append(vocab[token_id])
        decoded.append(" ".join(tokens))
    return decoded
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册