# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics for generated sequences: distinct-1/2 and BLEU-1/2.
"""

from collections import Counter

from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import numpy as np


def distinct(seqs):
    """
    Calculate intra/inter distinct 1/2.

    "Intra" values are computed per sequence (unique n-grams divided by the
    number of n-grams in that sequence) and then averaged over all sequences.
    "Inter" values are computed once over the n-grams of all sequences pooled
    together.

    Args:
        seqs: Iterable of token sequences. Each sequence must support
            ``len()`` and slicing, and its tokens must be hashable.

    Returns:
        Tuple ``(intra_dist1, intra_dist2, inter_dist1, inter_dist2)``.
        When ``seqs`` is empty the intra values are ``0.0`` (instead of the
        NaN an average over nothing would give) and the inter values are
        the near-zero result of the smoothed ratio.
    """
    intra_dist1, intra_dist2 = [], []
    unigrams_all, bigrams_all = Counter(), Counter()
    for seq in seqs:
        unigrams = Counter(seq)
        bigrams = Counter(zip(seq, seq[1:]))
        # The small constants keep the ratios finite for empty or
        # single-token sequences, where a denominator would otherwise be 0.
        intra_dist1.append((len(unigrams) + 1e-12) / (len(seq) + 1e-5))
        intra_dist2.append(
            (len(bigrams) + 1e-12) / (max(0, len(seq) - 1) + 1e-5))

        unigrams_all.update(unigrams)
        bigrams_all.update(bigrams)

    inter_dist1 = (len(unigrams_all) + 1e-12) / \
        (sum(unigrams_all.values()) + 1e-5)
    inter_dist2 = (len(bigrams_all) + 1e-12) / \
        (sum(bigrams_all.values()) + 1e-5)

    if intra_dist1:
        intra_dist1 = np.average(intra_dist1)
        intra_dist2 = np.average(intra_dist2)
    else:
        # No sequences at all: averaging an empty list would yield NaN plus
        # a RuntimeWarning, so report zero diversity instead.
        intra_dist1, intra_dist2 = 0.0, 0.0
    return intra_dist1, intra_dist2, inter_dist1, inter_dist2


def _sentence_bleu_or_zero(ref, hyp, weights, smoothing):
    """
    Score one hypothesis against one reference, returning 0.0 on failure.
    """
    try:
        return bleu_score.sentence_bleu(
            [ref], hyp,
            smoothing_function=smoothing,
            weights=weights)
    except Exception:
        # Best effort: a pair that cannot be scored (for example a
        # degenerate hypothesis) counts as 0 rather than aborting the whole
        # evaluation. Unlike a bare ``except``, this still lets
        # KeyboardInterrupt and SystemExit propagate.
        return 0.0


def bleu(hyps, refs):
    """
    Calculate bleu 1/2.

    Each hypothesis is scored against its single paired reference with
    NLTK's sentence-level BLEU using smoothing method 7; the per-pair scores
    are then averaged. Pairs are formed with ``zip``, so surplus items in
    the longer of ``hyps``/``refs`` are ignored.

    Args:
        hyps: Iterable of hypothesis token sequences.
        refs: Iterable of reference token sequences, aligned with ``hyps``.

    Returns:
        Tuple ``(bleu_1, bleu_2)`` of averaged scores: ``bleu_1`` uses
        unigram weights only, ``bleu_2`` weights unigrams and bigrams
        equally. Both are ``0.0`` when there are no pairs to score.
    """
    # The smoothing function does not depend on the pair, so build it once.
    smoothing = SmoothingFunction().method7
    weights_1 = (1, 0, 0, 0)
    weights_2 = (0.5, 0.5, 0, 0)

    bleu_1 = []
    bleu_2 = []
    for hyp, ref in zip(hyps, refs):
        bleu_1.append(_sentence_bleu_or_zero(ref, hyp, weights_1, smoothing))
        bleu_2.append(_sentence_bleu_or_zero(ref, hyp, weights_2, smoothing))

    if not bleu_1:
        # Avoid NaN (and a RuntimeWarning) from averaging an empty list.
        return 0.0, 0.0
    return np.average(bleu_1), np.average(bleu_2)