diff --git a/ppcls/loss/celoss.py b/ppcls/loss/celoss.py index 81d12a611d760b04d6f556145a00ca3f136b92d1..2715dee194a7ffecb327483ecf84eb887c9d80b3 100644 --- a/ppcls/loss/celoss.py +++ b/ppcls/loss/celoss.py @@ -51,6 +51,10 @@ class CELoss(nn.Layer): label = self._labelsmoothing(label, class_num) x = -F.log_softmax(x, axis=-1) loss = paddle.sum(x * label, axis=-1) + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() else: if label.shape[-1] == x.shape[-1]: label = F.softmax(label, axis=-1)