# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle

from .utils import default_trans_func

__all__ = ['RougeL', 'RougeLForDuReader']


class RougeN():
    def __init__(self, n):
        self.n = n
        # Initialize the counters so that `update` and `accumulate` work even
        # before an explicit call to `reset`.
        self.overlapping_count = 0
        self.reference_count = 0

    def _get_ngrams(self, words):
        """Calculate the word n-grams contained in a single sentence."""
        ngram_set = set()
        max_index_ngram_start = len(words) - self.n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(words[i:i + self.n]))
        return ngram_set

    def score(self, evaluated_sentences_ids, reference_sentences_ids):
        overlapping_count, reference_count = self.compute(
            evaluated_sentences_ids, reference_sentences_ids)
        return overlapping_count / reference_count

    def compute(self, evaluated_sentences_ids, reference_sentences_ids):
        """
        Args:
            evaluated_sentences_ids (list): the sentence ids predicted by the model.
            reference_sentences_ids (list): the reference sentence ids. Its size
                should be the same as `evaluated_sentences_ids`.

        Returns:
            overlapping_count (int): the overlapping n-gram count.
            reference_count (int): the reference sentences' n-gram count.
        """
        if len(evaluated_sentences_ids) <= 0 or len(
                reference_sentences_ids) <= 0:
            raise ValueError("Collections must contain at least 1 sentence.")

        reference_count = 0
        overlapping_count = 0
        for evaluated_sentence_ids, reference_sentence_ids in zip(
                evaluated_sentences_ids, reference_sentences_ids):
            evaluated_ngrams = self._get_ngrams(evaluated_sentence_ids)
            reference_ngrams = self._get_ngrams(reference_sentence_ids)
            reference_count += len(reference_ngrams)

            # Get the n-grams shared by the evaluated and reference sentences.
            overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
            overlapping_count += len(overlapping_ngrams)

        return overlapping_count, reference_count

    def accumulate(self):
        """
        Calculate the ROUGE-N score over all accumulated minibatches, i.e. the
        ratio of overlapping n-grams to reference n-grams (n-gram recall).

        Returns:
            float: the accumulated ROUGE-N score.
        """
        rouge_score = self.overlapping_count / self.reference_count
        return rouge_score

    def reset(self):
        """
        Empty the evaluation memory of previous mini-batches.
        """
        self.overlapping_count = 0
        self.reference_count = 0

    def name(self):
        """
        Return the name of the metric instance.
        """
        return "Rouge-%s" % self.n

    def update(self, overlapping_count, reference_count):
        """
        Args:
            overlapping_count (int): the overlapping n-gram count of a minibatch.
            reference_count (int): the reference n-gram count of a minibatch.
        """
        self.overlapping_count += overlapping_count
        self.reference_count += reference_count


class Rouge1(RougeN):
    def __init__(self):
        super(Rouge1, self).__init__(n=1)


class Rouge2(RougeN):
    def __init__(self):
        super(Rouge2, self).__init__(n=2)
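# A minimal usage sketch for the n-gram metrics above (illustration only; the
# token id lists are made-up example data). A typical loop calls `reset` once,
# `compute` and `update` per minibatch, and `accumulate` at the end:
#
#     rouge2 = Rouge2()
#     rouge2.reset()
#     cand_ids = [[1, 2, 3, 4], [5, 6, 7]]  # hypothetical model predictions
#     ref_ids = [[1, 2, 3, 5], [5, 6, 8]]   # hypothetical references
#     overlapping, reference = rouge2.compute(cand_ids, ref_ids)
#     rouge2.update(overlapping, reference)
#     print(rouge2.name(), rouge2.accumulate())  # Rouge-2 0.6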
class RougeL(paddle.metric.Metric):
    r'''
    Rouge-L is a Recall-Oriented Understudy for Gisting Evaluation metric based
    on the Longest Common Subsequence (LCS). The longest common subsequence
    naturally takes sentence-level structure similarity into account and
    automatically identifies the longest co-occurring in-sequence n-grams.

    .. math::

        R_{LCS} & = \frac{LCS(C,S)}{len(S)}

        P_{LCS} & = \frac{LCS(C,S)}{len(C)}

        F_{LCS} & = \frac{(1 + \gamma^2)R_{LCS}P_{LCS}}{R_{LCS} + \gamma^2 P_{LCS}}

    where `C` is the candidate sentence and `S` is the reference sentence.

    Args:
        trans_func (callable, optional): A function that transforms the network
            output and labels into lists of candidate and reference tokens.
            Default: None.
        vocab (dict, optional): The vocabulary used by `default_trans_func` when
            `trans_func` is not provided. Default: None.
        gamma (float): A hyperparameter to decide the weight of recall. Default: 1.2.
        name (str): Name of the metric instance. Default: "rouge-l".

    Examples: (TODO: liujiaqi)
        1. Using as a general evaluation object.
        2. Using as an instance of `paddle.metric.Metric`.
    '''

    def __init__(self,
                 trans_func=None,
                 vocab=None,
                 gamma=1.2,
                 name="rouge-l",
                 *args,
                 **kwargs):
        super(RougeL, self).__init__(*args, **kwargs)
        self.gamma = gamma
        self.inst_scores = []
        self._name = name
        self.vocab = vocab
        self.trans_func = trans_func

    def lcs(self, string, sub):
        """
        Calculate the length of the longest common subsequence of `string` and `sub`.
        """
        if len(string) < len(sub):
            sub, string = string, sub
        # Dynamic programming table of LCS lengths over prefixes.
        lengths = np.zeros((len(string) + 1, len(sub) + 1))
        for j in range(1, len(sub) + 1):
            for i in range(1, len(string) + 1):
                if string[i - 1] == sub[j - 1]:
                    lengths[i][j] = lengths[i - 1][j - 1] + 1
                else:
                    lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
        return lengths[len(string)][len(sub)]

    def add_inst(self, cand, ref_list):
        '''
        Update the states based on a pair of candidate and references.

        Args:
            cand (str): The candidate sentence generated by the model.
            ref_list (list): The list of ground-truth sentences.
        '''
        precs, recalls = [], []
        for ref in ref_list:
            basic_lcs = self.lcs(cand, ref)
            prec = basic_lcs / len(cand) if len(cand) > 0. else 0.
            rec = basic_lcs / len(ref) if len(ref) > 0. else 0.
            precs.append(prec)
            recalls.append(rec)

        prec_max = max(precs)
        rec_max = max(recalls)

        if prec_max != 0 and rec_max != 0:
            score = ((1 + self.gamma**2) * prec_max * rec_max) / \
                float(rec_max + self.gamma**2 * prec_max)
        else:
            score = 0.0
        self.inst_scores.append(score)

    def update(self, output, label, seq_mask=None):
        if self.trans_func is None:
            if self.vocab is None:
                raise AttributeError(
                    "The `update` method requires users to provide `trans_func` or `vocab` when initializing RougeL."
                )
            cand_list, ref_list = default_trans_func(output, label, seq_mask,
                                                     self.vocab)
        else:
            cand_list, ref_list = self.trans_func(output, label, seq_mask)
        if len(cand_list) != len(ref_list):
            raise ValueError(
                "Length error! Please check the output of the network.")
        for i in range(len(cand_list)):
            self.add_inst(cand_list[i], ref_list[i])

    def accumulate(self):
        '''
        Calculate the final Rouge-L metric as the mean of the per-instance scores.
        '''
        return 1. * sum(self.inst_scores) / len(self.inst_scores)

    def score(self):
        return self.accumulate()

    def reset(self):
        self.inst_scores = []

    def name(self):
        return self._name
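# A minimal usage sketch for `RougeL` as a general evaluation object
# (illustration only; the tokenized sentences are made-up example data).
# `add_inst` takes one candidate and a list of references, and `accumulate`
# averages the per-instance F scores:
#
#     rouge_l = RougeL()
#     cand = ["The", "cat", "The", "cat", "on", "the", "mat"]
#     ref_list = [["The", "cat", "is", "on", "the", "mat"],
#                 ["There", "is", "a", "cat", "on", "the", "mat"]]
#     rouge_l.add_inst(cand, ref_list)
#     print(rouge_l.score())  # approximately 0.78
#
# When used as a `paddle.metric.Metric` in a training loop, pass `trans_func`
# or `vocab` at construction time so that `update` can convert network outputs
# and labels into such token lists before calling `add_inst`.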
class RougeLForDuReader(RougeL):
    '''
    Rouge-L metric with bonus for the DuReader contest.

    Please refer to `DuReader Homepage`_ for more details.

    Args:
        alpha (float): Weight of the yes/no bonus. Default: 1.0.
        beta (float): Weight of the entity bonus. Default: 1.0.
        gamma (float): A hyperparameter to decide the weight of recall. Default: 1.2.
    '''

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.2):
        # Pass `gamma` by keyword so it is not mistaken for `trans_func`, the
        # first positional argument of RougeL.__init__.
        super(RougeLForDuReader, self).__init__(gamma=gamma)
        self.alpha = alpha
        self.beta = beta

    def add_inst(self,
                 cand,
                 ref_list,
                 yn_label=None,
                 yn_ref=None,
                 entity_ref=None):
        precs, recalls = [], []
        for i, ref in enumerate(ref_list):
            basic_lcs = self.lcs(cand, ref)
            yn_bonus, entity_bonus = 0.0, 0.0
            if yn_ref is not None and yn_label is not None:
                yn_bonus = self.add_yn_bonus(cand, ref, yn_label, yn_ref[i])
            elif entity_ref is not None:
                entity_bonus = self.add_entity_bonus(cand, entity_ref)
            p_denom = len(
                cand) + self.alpha * yn_bonus + self.beta * entity_bonus
            r_denom = len(
                ref) + self.alpha * yn_bonus + self.beta * entity_bonus
            prec = (basic_lcs + self.alpha * yn_bonus + self.beta * entity_bonus) \
                / p_denom if p_denom > 0. else 0.
            rec = (basic_lcs + self.alpha * yn_bonus + self.beta * entity_bonus) \
                / r_denom if r_denom > 0. else 0.
            precs.append(prec)
            recalls.append(rec)

        prec_max = max(precs)
        rec_max = max(recalls)

        if prec_max != 0 and rec_max != 0:
            score = ((1 + self.gamma**2) * prec_max * rec_max) / \
                float(rec_max + self.gamma**2 * prec_max)
        else:
            score = 0.0
        self.inst_scores.append(score)

    def add_yn_bonus(self, cand, ref, yn_label, yn_ref):
        # The bonus only applies when the predicted yes/no label matches the reference.
        if yn_label != yn_ref:
            return 0.0
        lcs_ = self.lcs(cand, ref)
        return lcs_

    def add_entity_bonus(self, cand, entity_ref):
        # Reward the candidate for every reference entity it contains verbatim.
        lcs_ = 0.0
        for ent in entity_ref:
            if ent in cand:
                lcs_ += len(ent)
        return lcs_
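# A minimal usage sketch for `RougeLForDuReader` (illustration only; the
# sentences, yes/no labels and entity annotations are made-up stand-ins for
# what the DuReader evaluation data would provide). Matching yes/no labels and
# covered entities add a bonus to both the numerator and denominator of the
# precision/recall terms:
#
#     metric = RougeLForDuReader(alpha=1.0, beta=1.0)
#     metric.add_inst(
#         cand=["Yes", "it", "ships", "to", "Beijing"],
#         ref_list=[["Yes", "the", "shop", "ships", "to", "Beijing"]],
#         yn_label=1, yn_ref=[1])           # hypothetical yes/no labels
#     metric.add_inst(
#         cand=["The", "capital", "is", "Beijing"],
#         ref_list=[["Beijing", "is", "the", "capital", "of", "China"]],
#         entity_ref=["Beijing"])           # hypothetical entity strings
#     print(metric.accumulate())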