Unverified commit 84a6ea83, authored by LiuChiachi, committed by GitHub

Update perplexity, inherited from Metric (#4995)

* Update ppl: inherit from Metric, suitable for common MT and LM tasks.

* Fix ppl bugs: delete useless argument.

* Fix ppl bugs: return a number, not a numpy array.
Parent e59f15a1
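The substantive change, visible in the diff below: the old metric exponentiated the mean of per-batch average cross entropies, while the updated metric accumulates token-level cross entropy and normalizes by the total token count. In other words, `accumulate` now returns the standard corpus-level perplexity (this reading is derived from the diff, not stated in the commit message):

\mathrm{PPL} = \exp\Big(\frac{1}{N}\sum_{i=1}^{N}\mathrm{CE}_i\Big)

where CE_i is the softmax cross entropy of the i-th real (unmasked) token and N is `total_word_num`.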
@@ -15,33 +15,62 @@
 import paddle
 import numpy as np
+import paddle.nn.functional as F


 class Perplexity(paddle.metric.Metric):
+    """
+    Perplexity is calculated using cross entropy. It supports both padded
+    and unpadded data.
+
+    If data is not padded, users should provide `seq_len` for `Metric`
+    initialization. If data is padded, your label should contain `seq_mask`,
+    which indicates the actual length of samples.
+
+    This Perplexity requires that the output of your network is prediction,
+    label and sequence length (optional). If the Perplexity here doesn't meet
+    your needs, you could override the `compute` or `update` method for
+    calculating Perplexity.
+
+    Args:
+        seq_len(int): Sequence length of each sample, it must be provided while
+            data is not padded. Default: 20.
+        name(str): Name of `Metric` instance. Default: 'Perplexity'.
+    """
+
     def __init__(self, name='Perplexity', *args, **kwargs):
         super(Perplexity, self).__init__(*args, **kwargs)
         self._name = name
         self.total_ce = 0
-        self.num_batch = 0
+        self.total_word_num = 0

-    def update(self, y, label, *args):
-        # Perplexity is calculated using cross entropy
-        label = paddle.to_tensor(label)
-        y = paddle.to_tensor(y)
+    def compute(self, pred, label, seq_mask=None):
         label = paddle.unsqueeze(label, axis=2)
-        ce = paddle.nn.functional.softmax_with_cross_entropy(
-            logits=y, label=label, soft_label=False)
-        ce = paddle.mean(ce)
-        self.total_ce += ce.numpy()[0]
-        self.num_batch += 1
+        ce = F.softmax_with_cross_entropy(
+            logits=pred, label=label, soft_label=False)
+        ce = paddle.squeeze(ce, axis=[2])
+        if seq_mask is not None:
+            ce = ce * seq_mask
+            word_num = paddle.sum(seq_mask)
+            return ce, word_num
+        return ce
+
+    def update(self, ce, word_num=None):
+        batch_ce = np.sum(ce)
+        if word_num is None:
+            word_num = ce.shape[0] * ce.shape[1]
+        else:
+            word_num = word_num[0]
+        self.total_ce += batch_ce
+        self.total_word_num += word_num

     def reset(self):
         self.total_ce = 0
-        self.num_batch = 0
+        self.total_word_num = 0

     def accumulate(self):
-        return np.exp(self.total_ce / self.num_batch)
+        return np.exp(self.total_ce / self.total_word_num)

     def name(self):
         return self._name
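For concreteness, a minimal standalone usage sketch of the updated API. The shapes, the random data, and the assumption that the `Perplexity` class above is already in scope are illustrative, not part of the commit:

import numpy as np
import paddle

batch_size, seq_len, vocab_size = 4, 20, 100                    # hypothetical sizes
pred = paddle.randn([batch_size, seq_len, vocab_size])          # network logits
label = paddle.randint(0, vocab_size, [batch_size, seq_len])    # gold token ids
seq_mask = paddle.ones([batch_size, seq_len], dtype='float32')  # 1 = real token, 0 = pad

metric = Perplexity()
# compute returns per-token cross entropy (masked) and the real-token count
ce, word_num = metric.compute(pred, label, seq_mask)
# update works on plain numpy values (np.sum, word_num[0]), so convert explicitly here
metric.update(ce.numpy(), word_num.numpy())
print(metric.accumulate())  # exp(total_ce / total_word_num)

Because `update` divides by the accumulated token count rather than the batch count, padded positions masked out by `seq_mask` no longer drag the perplexity toward exp(0) = 1, which was the behavior of the old per-batch mean.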