未验证 提交 87d72ec6 编写于 作者: S smallv0221 提交者: GitHub

fix bleu bug (#5073)

* fix bleu bug
上级 e02bc11f
......@@ -28,15 +28,15 @@ def get_match_size(cand_ngram, refs_ngram):
for ref_ngram in refs_ngram:
tmp_ref_set = defaultdict(int)
for ngram in ref_ngram:
tmp_ref_set[ngram] += tmp_ref_set.get(ngram, 0) + 1
tmp_ref_set[tuple(ngram)] += 1
for ngram, count in tmp_ref_set.items():
ref_set[ngram] = max(ref_set[ngram], count)
ref_set[tuple(ngram)] = max(ref_set[tuple(ngram)], count)
cand_set = defaultdict(int)
for ngram in cand_ngram:
cand_set[ngram] += 1
cand_set[tuple(ngram)] += 1
match_size = 0
for ngram, count in cand_set.items():
match_size += min(count, ref_set.get(ngram, 0))
match_size += min(count, ref_set.get(tuple(ngram), 0))
cand_size = len(cand_ngram)
return match_size, cand_size
......@@ -101,10 +101,10 @@ class BLEU(paddle.metric.Metric):
.. code-block:: python
from paddlenlp.metrics import BLEU
bleu = BLEU()
cand = "Welcome to PaddleNLP."
ref_list = ["Welcome PaddleNLP"]
cand = ["The","cat","The","cat","on","the","mat"]
ref_list = [["The","cat","is","on","the","mat"],["There","is","a","cat","on","the","mat"]]
bleu.add_inst(cand, ref_list)
print(bleu.score()) # 0.7510186074254295
print(bleu.score()) # 0.4671379777282001
2. Using as an instance of `paddle.metric.Metric`.
......@@ -156,7 +156,7 @@ class BLEU(paddle.metric.Metric):
Update the states based on the a pair of candidate and references.
Args:
cand (str): The candidate sentence generated by model.
cand (list): Tokenized candidate sentence generated by model.
ref_list (list): List of ground truth sentences.
'''
for n_size in range(self.n_size):
......
......@@ -26,8 +26,8 @@ def default_trans_func(output, label, seq_mask, vocab):
if seq_mask[i][j][0] == 0:
break
token_list.append(vocab[idx[i][j]])
token_str = " ".join(token_list)
ref_list.append([token_str])
ref_list.append([token_list])
label = np.squeeze(label, axis=2)
for i in range(label.shape[0]):
......@@ -36,6 +36,6 @@ def default_trans_func(output, label, seq_mask, vocab):
if seq_mask[i][j][0] == 0:
break
token_list.append(vocab[label[i][j]])
token_str = " ".join(token_list)
cand.append(token_str)
cand.append(token_list)
return cand, ref_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册