From 345b1510ab41a6d4fe343d65a50193bc47433e1b Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Wed, 27 Oct 2021 14:24:29 +0000 Subject: [PATCH] charry pick nrtr_postprocess and modify data type to adaptive --- ppocr/losses/rec_nrtr_loss.py | 2 +- ppocr/postprocess/rec_postprocess.py | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py index 76cd36b6..200a6d04 100644 --- a/ppocr/losses/rec_nrtr_loss.py +++ b/ppocr/losses/rec_nrtr_loss.py @@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer): log_prb = F.log_softmax(pred, axis=1) non_pad_mask = paddle.not_equal( tgt, paddle.zeros( - tgt.shape, dtype='int32')) + tgt.shape, dtype=tgt.dtype)) loss = -(one_hot * log_prb).sum(axis=1) loss = loss.masked_select(non_pad_mask).mean() else: diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 07efd972..c0d8bab5 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode): character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): - if preds.dtype == paddle.int64: - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - if preds[0][0]==2: - preds_idx = preds[:,1:] - else: - preds_idx = preds - if len(preds) == 2: preds_id = preds[0] preds_prob = preds[1] -- GitLab