From bcecbd11599eaf4f49ee01ac25e9731b2da7f6fb Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Wed, 26 Oct 2022 11:50:37 +0000 Subject: [PATCH] fix celoss bug --- ppcls/loss/celoss.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ppcls/loss/celoss.py b/ppcls/loss/celoss.py index 81d12a61..2715dee1 100644 --- a/ppcls/loss/celoss.py +++ b/ppcls/loss/celoss.py @@ -51,6 +51,10 @@ class CELoss(nn.Layer): label = self._labelsmoothing(label, class_num) x = -F.log_softmax(x, axis=-1) loss = paddle.sum(x * label, axis=-1) + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() else: if label.shape[-1] == x.shape[-1]: label = F.softmax(label, axis=-1) -- GitLab