From 483e50382627e807d0e1c6adad9438f00945b9cc Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Mon, 10 Oct 2022 12:12:47 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=9A=E8=BF=87=E5=8F=98=E9=87=8F=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E5=88=A4=E6=96=AD=E6=98=AF=E5=90=A6=E6=98=AFvisual?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/losses/rec_rfl_loss.py | 18 +++++++++++------- ppocr/postprocess/rec_postprocess.py | 21 +++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py index 0921406c..be0f06d9 100644 --- a/ppocr/losses/rec_rfl_loss.py +++ b/ppocr/losses/rec_rfl_loss.py @@ -36,22 +36,26 @@ class RFLLoss(nn.Layer): self.total_loss = {} total_loss = 0.0 + if isinstance(predicts, tuple) or isinstance(predicts, list): + cnt_outputs, seq_outputs = predicts + else: + cnt_outputs, seq_outputs = predicts, None # batch [image, label, length, cnt_label] - if predicts[0] is not None: - cnt_loss = self.cnt_loss(predicts[0], + if cnt_outputs is not None: + cnt_loss = self.cnt_loss(cnt_outputs, paddle.cast(batch[3], paddle.float32)) self.total_loss['cnt_loss'] = cnt_loss total_loss += cnt_loss - if predicts[1] is not None: + if seq_outputs is not None: targets = batch[1].astype("int64") label_lengths = batch[2].astype('int64') - batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[ - 1].shape[1], predicts[1].shape[2] - assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \ + batch_size, num_steps, num_classes = seq_outputs.shape[ + 0], seq_outputs.shape[1], seq_outputs.shape[2] + assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \ "The target's shape and inputs's shape is [N, d] and [N, num_steps]" - inputs = predicts[1][:, :-1, :] + inputs = seq_outputs[:, :-1, :] targets = targets[:, 1:] inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]]) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 74f4e880..59b5254e 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode): return result_list def __call__(self, preds, label=None, *args, **kwargs): - if len(preds) == 2: - cnt_pred, preds = preds - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) + # if seq_outputs is not None: + if isinstance(preds, tuple) or isinstance(preds, list): + cnt_outputs, seq_outputs = preds + if isinstance(seq_outputs, paddle.Tensor): + seq_outputs = seq_outputs.numpy() + preds_idx = seq_outputs.argmax(axis=2) + preds_prob = seq_outputs.max(axis=2) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) if label is None: @@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode): return text, label else: - cnt_pred = preds - if isinstance(cnt_pred, paddle.Tensor): - cnt_pred = cnt_pred.numpy() + cnt_outputs = preds + if isinstance(cnt_outputs, paddle.Tensor): + cnt_outputs = cnt_outputs.numpy() cnt_length = [] - for lens in cnt_pred: + for lens in cnt_outputs: length = round(np.sum(lens)) cnt_length.append(length) if label is None: -- GitLab