提交 44840726 编写于 作者: W WenmuZhou

后处理添加类型判断

上级 4402e629
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import cv2 import cv2
import paddle
from shapely.geometry import Polygon from shapely.geometry import Polygon
import pyclipper import pyclipper
...@@ -130,7 +131,9 @@ class DBPostProcess(object): ...@@ -130,7 +131,9 @@ class DBPostProcess(object):
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, pred, shape_list): def __call__(self, pred, shape_list):
pred = pred.numpy()[:, 0, :, :] if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh segmentation = pred > self.thresh
boxes_batch = [] boxes_batch = []
......
import cv2 import cv2
import paddle
import numpy as np import numpy as np
import pyclipper import pyclipper
from shapely.geometry import Polygon from shapely.geometry import Polygon
...@@ -23,7 +24,9 @@ class DBPostProcess(): ...@@ -23,7 +24,9 @@ class DBPostProcess():
pred: pred:
binary: text region segmentation map, with shape (N, 1,H, W) binary: text region segmentation map, with shape (N, 1,H, W)
''' '''
pred = pred.numpy()[:, 0, :, :] if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = self.binarize(pred) segmentation = self.binarize(pred)
batch_out = [] batch_out = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
......
...@@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode):
character_type, use_space_char) character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
# out = self.decode_preds(preds) # out = self.decode_preds(preds)
preds = F.softmax(preds, axis=2).numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob)
...@@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character return dict_character
def decode_preds(self, preds): def decode_preds(self, preds):
probs = F.softmax(preds, axis=2).numpy() probs_ind = np.argmax(preds, axis=2)
probs_ind = np.argmax(probs, axis=2)
B, N, _ = preds.shape B, N, _ = preds.shape
l = np.ones(B).astype(np.int64) * N l = np.ones(B).astype(np.int64) * N
length = paddle.to_variable(l) length = paddle.to_tensor(l)
out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length) out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length)
batch_res = [ batch_res = [
x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy()) x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy())
] ]
result_list = [] result_list = []
for sample_idx, ind, prob in zip(batch_res, probs_ind, probs): for sample_idx, ind, prob in zip(batch_res, probs_ind, preds):
char_list = [self.character[idx] for idx in sample_idx] char_list = [self.character[idx] for idx in sample_idx]
valid_ind = np.where(ind != 0)[0] valid_ind = np.where(ind != 0)[0]
if len(valid_ind) == 0: if len(valid_ind) == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册