From 87d72ec62f549d235de4d36f48f9983a453b8ce9 Mon Sep 17 00:00:00 2001 From: smallv0221 <33639025+smallv0221@users.noreply.github.com> Date: Wed, 16 Dec 2020 19:13:40 +0800 Subject: [PATCH] fix bleu bug (#5073) * fix bleu bug --- PaddleNLP/paddlenlp/metrics/bleu.py | 16 ++++++++-------- PaddleNLP/paddlenlp/metrics/utils.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/PaddleNLP/paddlenlp/metrics/bleu.py b/PaddleNLP/paddlenlp/metrics/bleu.py index 0cf261ce..bc2aae19 100644 --- a/PaddleNLP/paddlenlp/metrics/bleu.py +++ b/PaddleNLP/paddlenlp/metrics/bleu.py @@ -28,15 +28,15 @@ def get_match_size(cand_ngram, refs_ngram): for ref_ngram in refs_ngram: tmp_ref_set = defaultdict(int) for ngram in ref_ngram: - tmp_ref_set[ngram] += tmp_ref_set.get(ngram, 0) + 1 + tmp_ref_set[tuple(ngram)] += 1 for ngram, count in tmp_ref_set.items(): - ref_set[ngram] = max(ref_set[ngram], count) + ref_set[tuple(ngram)] = max(ref_set[tuple(ngram)], count) cand_set = defaultdict(int) for ngram in cand_ngram: - cand_set[ngram] += 1 + cand_set[tuple(ngram)] += 1 match_size = 0 for ngram, count in cand_set.items(): - match_size += min(count, ref_set.get(ngram, 0)) + match_size += min(count, ref_set.get(tuple(ngram), 0)) cand_size = len(cand_ngram) return match_size, cand_size @@ -101,10 +101,10 @@ class BLEU(paddle.metric.Metric): .. code-block:: python from paddlenlp.metrics import BLEU bleu = BLEU() - cand = "Welcome to PaddleNLP." - ref_list = ["Welcome PaddleNLP"] + cand = ["The","cat","The","cat","on","the","mat"] + ref_list = [["The","cat","is","on","the","mat"],["There","is","a","cat","on","the","mat"]] bleu.add_inst(cand, ref_list) - print(bleu.score()) # 0.7510186074254295 + print(bleu.score()) # 0.4671379777282001 2. Using as an instance of `paddle.metric.Metric`. @@ -156,7 +156,7 @@ class BLEU(paddle.metric.Metric): Update the states based on the a pair of candidate and references. Args: - cand (str): The candidate sentence generated by model. + cand (list): Tokenized candidate sentence generated by model. ref_list (list): List of ground truth sentences. ''' for n_size in range(self.n_size): diff --git a/PaddleNLP/paddlenlp/metrics/utils.py b/PaddleNLP/paddlenlp/metrics/utils.py index d3177e1f..d32cabfc 100644 --- a/PaddleNLP/paddlenlp/metrics/utils.py +++ b/PaddleNLP/paddlenlp/metrics/utils.py @@ -26,8 +26,8 @@ def default_trans_func(output, label, seq_mask, vocab): if seq_mask[i][j][0] == 0: break token_list.append(vocab[idx[i][j]]) - token_str = " ".join(token_list) - ref_list.append([token_str]) + + ref_list.append([token_list]) label = np.squeeze(label, axis=2) for i in range(label.shape[0]): @@ -36,6 +36,6 @@ def default_trans_func(output, label, seq_mask, vocab): if seq_mask[i][j][0] == 0: break token_list.append(vocab[label[i][j]]) - token_str = " ".join(token_list) - cand.append(token_str) + + cand.append(token_list) return cand, ref_list -- GitLab