未验证 提交 87d72ec6 编写于 作者: S smallv0221 提交者: GitHub

fix bleu bug (#5073)

* fix bleu bug
上级 e02bc11f
...@@ -28,15 +28,15 @@ def get_match_size(cand_ngram, refs_ngram): ...@@ -28,15 +28,15 @@ def get_match_size(cand_ngram, refs_ngram):
for ref_ngram in refs_ngram: for ref_ngram in refs_ngram:
tmp_ref_set = defaultdict(int) tmp_ref_set = defaultdict(int)
for ngram in ref_ngram: for ngram in ref_ngram:
tmp_ref_set[ngram] += tmp_ref_set.get(ngram, 0) + 1 tmp_ref_set[tuple(ngram)] += 1
for ngram, count in tmp_ref_set.items(): for ngram, count in tmp_ref_set.items():
ref_set[ngram] = max(ref_set[ngram], count) ref_set[tuple(ngram)] = max(ref_set[tuple(ngram)], count)
cand_set = defaultdict(int) cand_set = defaultdict(int)
for ngram in cand_ngram: for ngram in cand_ngram:
cand_set[ngram] += 1 cand_set[tuple(ngram)] += 1
match_size = 0 match_size = 0
for ngram, count in cand_set.items(): for ngram, count in cand_set.items():
match_size += min(count, ref_set.get(ngram, 0)) match_size += min(count, ref_set.get(tuple(ngram), 0))
cand_size = len(cand_ngram) cand_size = len(cand_ngram)
return match_size, cand_size return match_size, cand_size
...@@ -101,10 +101,10 @@ class BLEU(paddle.metric.Metric): ...@@ -101,10 +101,10 @@ class BLEU(paddle.metric.Metric):
.. code-block:: python .. code-block:: python
from paddlenlp.metrics import BLEU from paddlenlp.metrics import BLEU
bleu = BLEU() bleu = BLEU()
cand = "Welcome to PaddleNLP." cand = ["The","cat","The","cat","on","the","mat"]
ref_list = ["Welcome PaddleNLP"] ref_list = [["The","cat","is","on","the","mat"],["There","is","a","cat","on","the","mat"]]
bleu.add_inst(cand, ref_list) bleu.add_inst(cand, ref_list)
print(bleu.score()) # 0.7510186074254295 print(bleu.score()) # 0.4671379777282001
2. Using as an instance of `paddle.metric.Metric`. 2. Using as an instance of `paddle.metric.Metric`.
...@@ -156,7 +156,7 @@ class BLEU(paddle.metric.Metric): ...@@ -156,7 +156,7 @@ class BLEU(paddle.metric.Metric):
Update the states based on a pair of candidate and references. Update the states based on a pair of candidate and references.
Args: Args:
cand (str): The candidate sentence generated by model. cand (list): Tokenized candidate sentence generated by model.
ref_list (list): List of ground truth sentences. ref_list (list): List of ground truth sentences.
''' '''
for n_size in range(self.n_size): for n_size in range(self.n_size):
......
...@@ -26,8 +26,8 @@ def default_trans_func(output, label, seq_mask, vocab): ...@@ -26,8 +26,8 @@ def default_trans_func(output, label, seq_mask, vocab):
if seq_mask[i][j][0] == 0: if seq_mask[i][j][0] == 0:
break break
token_list.append(vocab[idx[i][j]]) token_list.append(vocab[idx[i][j]])
token_str = " ".join(token_list)
ref_list.append([token_str]) ref_list.append([token_list])
label = np.squeeze(label, axis=2) label = np.squeeze(label, axis=2)
for i in range(label.shape[0]): for i in range(label.shape[0]):
...@@ -36,6 +36,6 @@ def default_trans_func(output, label, seq_mask, vocab): ...@@ -36,6 +36,6 @@ def default_trans_func(output, label, seq_mask, vocab):
if seq_mask[i][j][0] == 0: if seq_mask[i][j][0] == 0:
break break
token_list.append(vocab[label[i][j]]) token_list.append(vocab[label[i][j]])
token_str = " ".join(token_list)
cand.append(token_str) cand.append(token_list)
return cand, ref_list return cand, ref_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册