diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py index 41714dd2a3ae15eeedc62521d97935f68271c598..76cd36b67b28063130b23ae362e93cdb168e9248 100644 --- a/ppocr/losses/rec_nrtr_loss.py +++ b/ppocr/losses/rec_nrtr_loss.py @@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer): log_prb = F.log_softmax(pred, axis=1) non_pad_mask = paddle.not_equal( tgt, paddle.zeros( - tgt.shape, dtype='int64')) + tgt.shape, dtype='int32')) loss = -(one_hot * log_prb).sum(axis=1) loss = loss.masked_select(non_pad_mask).mean() else: