提交 bcecbd11 编写于 作者: D dongshuilong 提交者: Tingquan Gao

fix celoss bug

上级 a6238512
...@@ -51,6 +51,10 @@ class CELoss(nn.Layer): ...@@ -51,6 +51,10 @@ class CELoss(nn.Layer):
label = self._labelsmoothing(label, class_num) label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1) x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1) loss = paddle.sum(x * label, axis=-1)
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
else: else:
if label.shape[-1] == x.shape[-1]: if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1) label = F.softmax(label, axis=-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册