From 81f0a83e11cfc107a565c302258afc016bb24623 Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Wed, 27 Oct 2021 22:28:58 +0800 Subject: [PATCH] Update rec_nrtr_loss.py --- ppocr/losses/rec_nrtr_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py index 76cd36b6..200a6d04 100644 --- a/ppocr/losses/rec_nrtr_loss.py +++ b/ppocr/losses/rec_nrtr_loss.py @@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer): log_prb = F.log_softmax(pred, axis=1) non_pad_mask = paddle.not_equal( tgt, paddle.zeros( - tgt.shape, dtype='int32')) + tgt.shape, dtype=tgt.dtype)) loss = -(one_hot * log_prb).sum(axis=1) loss = loss.masked_select(non_pad_mask).mean() else: -- GitLab