未验证 提交 a2721553 编写于 作者: L LiuChiachi 提交者: GitHub

Update metric bleu and rouge-l, inherited from Metric (#4982)

* update rouge-l, move default_trans_func to utils

* fix error info bugs

* add score method for metric bleu and rouge-l

* update bleu doc
上级 d3029c01
...@@ -14,5 +14,6 @@ ...@@ -14,5 +14,6 @@
from .perplexity import Perplexity from .perplexity import Perplexity
from .chunk import ChunkEvaluator from .chunk import ChunkEvaluator
from .bleu import BLEU from .bleu import BLEU, BLEUForDuReader
from .rouge import RougeL, RougeLForDuReader
from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman from .glue import AccuracyAndF1, Mcc, PearsonAndSpearman
...@@ -16,6 +16,12 @@ import math ...@@ -16,6 +16,12 @@ import math
import sys import sys
from collections import defaultdict from collections import defaultdict
import paddle
from .utils import default_trans_func
__all__ = ["BLEU", "BLEUForDuReader"]
def get_match_size(cand_ngram, refs_ngram): def get_match_size(cand_ngram, refs_ngram):
ref_set = defaultdict(int) ref_set = defaultdict(int)
...@@ -48,12 +54,22 @@ def get_ngram(sent, n_size, label=None): ...@@ -48,12 +54,22 @@ def get_ngram(sent, n_size, label=None):
return ngram_list return ngram_list
class BLEU(object): class BLEU(paddle.metric.Metric):
r''' r'''
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of BLEU (bilingual evaluation understudy) is an algorithm for evaluating the
text which has been machine-translated from one natural language to another. This metric quality of text which has been machine-translated from one natural language
uses a modified form of precision to compare a candidate translation against multiple to another. This metric uses a modified form of precision to compare a
reference translations. candidate translation against multiple reference translations.
BLEU can be used either as a `paddle.metric.Metric` class or as an ordinary
class. When BLEU is used as a `paddle.metric.Metric` class, a function is
needed that transforms the network output to reference string list, and
transforms the label to candidate string. By default, a default function
`default_trans_func` is provided, which gets target sequence id by
calculating the maximum probability of each step. In this case, user must
provide `vocab`. It should be noted that the BLEU here is different from
the BLEU calculated in prediction, and it is only for observation during
training and evaluation.
.. math:: .. math::
...@@ -68,24 +84,72 @@ class BLEU(object): ...@@ -68,24 +84,72 @@ class BLEU(object):
where `c` is the length of candidate sentence, and 'r' is the length of reference sentence. where `c` is the length of candidate sentence, and 'r' is the length of reference sentence.
Args: Args:
n_size (int): Number of gram for BLEU metric. Default: 4. trans_func (callable, optional): `trans_func` transforms the network
weights (list, optional): The weights of precision of each gram. Default: None. output to string to calculate.
vocab (dict|paddlenlp.data.vocab, optional): Vocab for target language.
If `trans_func` is None and BLEU is used as `paddle.metric.Metric`
instance, `default_trans_func` will be performed and `vocab` must
be provided.
n_size (int, optional): Number of gram for BLEU metric. Default: 4.
weights (list, optional): The weights of precision of each gram.
Default: None.
name (str, optional): Name of `paddle.metric.Metric` instance.
Default: "bleu".
Examples:
1. Using as a general evaluation object.
.. code-block:: python
from paddlenlp.metrics import BLEU
bleu = BLEU()
cand = "Welcome to PaddleNLP."
ref_list = ["Welcome PaddleNLP"]
bleu.add_inst(cand, ref_list)
print(bleu.score()) # 0.7510186074254295
2. Using as an instance of `paddle.metric.Metric`.
.. code-block:: python
# TODO(liujiaqi)
''' '''
def __init__(self, n_size=4, weights=None): def __init__(self,
trans_func=None,
vocab=None,
n_size=4,
weights=None,
name="bleu"):
super(BLEU, self).__init__()
if not weights: if not weights:
weights = [1 / n_size for _ in range(n_size)] weights = [1 / n_size for _ in range(n_size)]
assert len(weights) == n_size, ( assert len(weights) == n_size, (
"Number of weights and n-gram should be the same, got Number of weights: '%d' and n-gram: '%d'" "Number of weights and n-gram should be the same, got Number of weights: '%d' and n-gram: '%d'"
% (len(weights), n_size)) % (len(weights), n_size))
self._name = name
self.match_ngram = {} self.match_ngram = {}
self.candi_ngram = {} self.candi_ngram = {}
self.weights = weights self.weights = weights
self.bp_r = 0 self.bp_r = 0
self.bp_c = 0 self.bp_c = 0
self.n_size = n_size self.n_size = n_size
self.vocab = vocab
self.trans_func = trans_func
def update(self, output, label, seq_mask=None):
    """
    Accumulate BLEU statistics for one batch.

    The batch is converted to a list of candidates and a list of reference
    lists, either by the user supplied `trans_func` or, when none was
    given, by `default_trans_func` together with `vocab`. Every
    (candidate, references) pair is then passed to `add_inst`.

    Args:
        output: Network output, forwarded unchanged to the transform.
        label: Ground-truth labels, forwarded unchanged to the transform.
        seq_mask (optional): Sequence mask, forwarded unchanged to the
            transform. Default: None.

    Raises:
        AttributeError: If neither `trans_func` nor `vocab` was provided.
        ValueError: If the transform returns lists of different lengths.
    """
    transform = self.trans_func
    if transform is not None:
        cand_list, ref_list = transform(output, label, seq_mask)
    elif self.vocab is not None:
        cand_list, ref_list = default_trans_func(output, label, seq_mask,
                                                 self.vocab)
    else:
        raise AttributeError(
            "The `update` method requires users to provide `trans_func` or `vocab` when initializing BLEU."
        )

    # Both lists must describe the same instances, one entry per sample.
    if len(cand_list) != len(ref_list):
        raise ValueError(
            "Length error! Please check the output of network.")

    for cand, refs in zip(cand_list, ref_list):
        self.add_inst(cand, refs)
def add_inst(self, cand, ref_list): def add_inst(self, cand, ref_list):
''' '''
...@@ -117,7 +181,13 @@ class BLEU(object): ...@@ -117,7 +181,13 @@ class BLEU(object):
self.bp_r += min([(abs(len(cand) - len(ref)), len(ref)) self.bp_r += min([(abs(len(cand) - len(ref)), len(ref))
for ref in ref_list])[1] for ref in ref_list])[1]
def score(self): def reset(self):
self.match_ngram = {}
self.candi_ngram = {}
self.bp_r = 0
self.bp_c = 0
def accumulate(self):
''' '''
Calculate the final bleu metric. Calculate the final bleu metric.
''' '''
...@@ -132,16 +202,21 @@ class BLEU(object): ...@@ -132,16 +202,21 @@ class BLEU(object):
except: except:
_score = 0 _score = 0
if _score == 0: if _score == 0:
_score = w_i * math.log(sys.float_info.min) _score = sys.float_info.min
prob_list.append(_score) prob_list.append(_score)
logs = math.fsum(w_i * math.log(p_i) logs = math.fsum(w_i * math.log(p_i)
for w_i, p_i in zip(self.weights, prob_list)) for w_i, p_i in zip(self.weights, prob_list))
bp = math.exp(min(1 - self.bp_r / float(self.bp_c), 0)) bp = math.exp(min(1 - self.bp_r / float(self.bp_c), 0))
bleu = bp * math.exp(logs) bleu = bp * math.exp(logs)
return bleu return bleu
def score(self):
    """
    Return the BLEU value of everything added so far.

    This is an alias of `accumulate`.
    """
    bleu = self.accumulate()
    return bleu
def name(self):
    """
    Return the name given to this metric instance at construction time.
    """
    metric_name = self._name
    return metric_name
class BLEUForDuReader(BLEU): class BLEUForDuReader(BLEU):
''' '''
...@@ -161,7 +236,6 @@ class BLEUForDuReader(BLEU): ...@@ -161,7 +236,6 @@ class BLEUForDuReader(BLEU):
yn_label=None, yn_label=None,
yn_ref=None, yn_ref=None,
entity_ref=None): entity_ref=None):
#super(BLEUWithBonus, self).add_inst(cand, ref_list)
BLEU.add_inst(self, cand, ref_list) BLEU.add_inst(self, cand, ref_list)
if yn_label is not None and yn_ref is not None: if yn_label is not None and yn_ref is not None:
self.add_yn_bonus(cand, ref_list, yn_label, yn_ref) self.add_yn_bonus(cand, ref_list, yn_label, yn_ref)
......
...@@ -14,8 +14,13 @@ ...@@ -14,8 +14,13 @@
import numpy as np import numpy as np
import paddle
from .utils import default_trans_func
class RougeL(object): __all__ = ['RougeL', 'RougeLForDuReader']
class RougeL(paddle.metric.Metric):
r''' r'''
Rouge-L is Recall-Oriented Understudy for Gisting Evaluation based on Longest Common Subsequence (LCS). Rouge-L is Recall-Oriented Understudy for Gisting Evaluation based on Longest Common Subsequence (LCS).
Longest common subsequence problem takes into account sentence level structure Longest common subsequence problem takes into account sentence level structure
...@@ -34,13 +39,31 @@ class RougeL(object): ...@@ -34,13 +39,31 @@ class RougeL(object):
Args: Args:
gamma (float): A hyperparameter to decide the weight of recall. Default: 1.2. gamma (float): A hyperparameter to decide the weight of recall. Default: 1.2.
Examples:(TODO: liujiaqi)
1. Using as a general evaluation object.
2. Using as an instance of `paddle.metric.Metric`.
''' '''
def __init__(self, gamma=1.2): def __init__(self,
trans_func=None,
vocab=None,
gamma=1.2,
name="rouge-l",
*args,
**kwargs):
super(RougeL, self).__init__(*args, **kwargs)
self.gamma = gamma self.gamma = gamma
self.inst_scores = [] self.inst_scores = []
self._name = name
self.vocab = vocab
self.trans_func = trans_func
def lcs(self, string, sub): def lcs(self, string, sub):
"""
Calculate the length of longest common subsequence of string and sub.
"""
if len(string) < len(sub): if len(string) < len(sub):
sub, string = string, sub sub, string = string, sub
lengths = np.zeros((len(string) + 1, len(sub) + 1)) lengths = np.zeros((len(string) + 1, len(sub) + 1))
...@@ -78,12 +101,37 @@ class RougeL(object): ...@@ -78,12 +101,37 @@ class RougeL(object):
score = 0.0 score = 0.0
self.inst_scores.append(score) self.inst_scores.append(score)
def score(self): def update(self, output, label, seq_mask=None):
if self.trans_func is None:
if self.vocab is None:
raise AttributeError(
"The `update` method requires users to provide `trans_func` or `vocab` when initializing RougeL."
)
cand_list, ref_list = default_trans_func(output, label, seq_mask,
self.vocab)
else:
cand_list, ref_list = self.trans_func(output, label, seq_mask)
if len(cand_list) != len(ref_list):
raise ValueError(
"Length error! Please check the output of network.")
for i in range(len(cand_list)):
self.add_inst(cand_list[i], ref_list[i])
def accumulate(self):
''' '''
Calculate the final rouge-l metric. Calculate the final rouge-l metric.
''' '''
return 1. * sum(self.inst_scores) / len(self.inst_scores) return 1. * sum(self.inst_scores) / len(self.inst_scores)
def score(self):
    """
    Return the Rouge-L value of everything added so far.

    This is an alias of `accumulate`.
    """
    rouge_l = self.accumulate()
    return rouge_l
def reset(self):
    """
    Discard every per-instance score collected so far.
    """
    # Rebind to a new list rather than clearing in place, so a list object
    # that was handed out earlier is left untouched.
    self.inst_scores = list()
def name(self):
    """
    Return the name given to this metric instance at construction time.
    """
    metric_name = self._name
    return metric_name
class RougeLForDuReader(RougeL): class RougeLForDuReader(RougeL):
''' '''
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def _decode_rows(ids, mask, vocab):
    """Join the vocab tokens of each row of `ids` up to the first masked step."""
    texts = []
    for row_ids, row_mask in zip(ids, mask):
        # A row ends at the first position whose mask value is 0; positions
        # after it are ignored even if their mask is non-zero again.
        stops = np.flatnonzero(row_mask == 0)
        length = int(stops[0]) if stops.size else len(row_mask)
        texts.append(" ".join(vocab[token_id] for token_id in row_ids[:length]))
    return texts


def default_trans_func(output, label, seq_mask, vocab):
    """
    Decode a batch of network outputs and labels into metric inputs.

    The predicted token id at each step is the argmax over the last axis of
    `output`. Ids are mapped to tokens with `vocab[id]` and joined with a
    single space. Each sequence is cut at the first step whose mask is 0.

    Args:
        output: Array-like of rank 3; the argmax is taken over axis 2.
        label: Array-like of target ids, either rank 3 with a trailing axis
            of size 1 or already rank 2.
        seq_mask: Array-like of rank 2 aligned with the first two axes of
            `output`, or None to treat every step as valid.
        vocab: Object indexable by token id that returns the token string.

    Returns:
        tuple: `(cand, ref_list)`. `cand[i]` is the decoded prediction of
        sample `i` (a string). `ref_list[i]` is a one-element list holding
        the decoded label of sample `i`.
    """
    output = np.asarray(output)
    label = np.asarray(label)
    if seq_mask is None:
        # No mask supplied: every step of every sequence is valid.
        mask = np.ones(output.shape[:2], dtype=bool)
    else:
        mask = np.asarray(seq_mask)

    # Broadcast the mask over the vocab axis instead of materializing a
    # (batch, seq, vocab) copy of it.
    pred_ids = np.argmax(output * mask[:, :, np.newaxis], axis=2)
    if label.ndim == 3:
        label = np.squeeze(label, axis=2)

    # The prediction is the candidate and the label is the reference.
    cand = _decode_rows(pred_ids, mask, vocab)
    ref_list = [[text] for text in _decode_rows(label, mask, vocab)]
    return cand, ref_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册