From 44840726ffb035a359adf1be73b4943d00e192cd Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 9 Nov 2020 18:19:42 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8E=E5=A4=84=E7=90=86=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/postprocess/db_postprocess.py | 7 +++++-- ppocr/postprocess/db_postprocess_torch.py | 7 +++++-- ppocr/postprocess/rec_postprocess.py | 12 ++++++------ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index f09acb2a..316f7fc2 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -18,6 +18,7 @@ from __future__ import print_function import numpy as np import cv2 +import paddle from shapely.geometry import Polygon import pyclipper @@ -130,7 +131,9 @@ class DBPostProcess(object): return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] def __call__(self, pred, shape_list): - pred = pred.numpy()[:, 0, :, :] + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = pred[:, 0, :, :] segmentation = pred > self.thresh boxes_batch = [] @@ -140,4 +143,4 @@ class DBPostProcess(object): pred[batch_index], segmentation[batch_index], width, height) boxes_batch.append({'points': boxes}) - return boxes_batch + return boxes_batch \ No newline at end of file diff --git a/ppocr/postprocess/db_postprocess_torch.py b/ppocr/postprocess/db_postprocess_torch.py index 83770df0..d1466327 100644 --- a/ppocr/postprocess/db_postprocess_torch.py +++ b/ppocr/postprocess/db_postprocess_torch.py @@ -1,4 +1,5 @@ import cv2 +import paddle import numpy as np import pyclipper from shapely.geometry import Polygon @@ -23,7 +24,9 @@ class DBPostProcess(): pred: binary: text region segmentation map, with shape (N, 1,H, W) ''' - pred = pred.numpy()[:, 0, :, :] + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = pred[:, 0, :, :] segmentation = self.binarize(pred) batch_out = [] for batch_index in range(pred.shape[0]): @@ -130,4 +133,4 @@ class DBPostProcess(): box[:, 0] = box[:, 0] - xmin box[:, 1] = box[:, 1] - ymin cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) - return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] \ No newline at end of file diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 73dcdaae..03208227 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode): character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() # out = self.decode_preds(preds) - preds = F.softmax(preds, axis=2).numpy() preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) text = self.decode(preds_idx, preds_prob) @@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode): return dict_character def decode_preds(self, preds): - probs = F.softmax(preds, axis=2).numpy() - probs_ind = np.argmax(probs, axis=2) + probs_ind = np.argmax(preds, axis=2) B, N, _ = preds.shape l = np.ones(B).astype(np.int64) * N - length = paddle.to_variable(l) + length = paddle.to_tensor(l) out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length) batch_res = [ x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy()) ] result_list = [] - for sample_idx, ind, prob in zip(batch_res, probs_ind, probs): + for sample_idx, ind, prob in zip(batch_res, probs_ind, preds): char_list = [self.character[idx] for idx in sample_idx] valid_ind = np.where(ind != 0)[0] if len(valid_ind) == 0: @@ -172,4 +172,4 @@ class AttnLabelDecode(BaseRecLabelDecode): else: assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end - return idx + return idx \ No newline at end of file -- GitLab