提交 44840726,作者: WenmuZhou

后处理添加输入类型判断:仅当输入为 paddle.Tensor 时才调用 .numpy() 转换为 numpy 数组

上级 4402e629
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import cv2 import cv2
import paddle
from shapely.geometry import Polygon from shapely.geometry import Polygon
import pyclipper import pyclipper
...@@ -130,7 +131,9 @@ class DBPostProcess(object): ...@@ -130,7 +131,9 @@ class DBPostProcess(object):
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, pred, shape_list): def __call__(self, pred, shape_list):
pred = pred.numpy()[:, 0, :, :] if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh segmentation = pred > self.thresh
boxes_batch = [] boxes_batch = []
...@@ -140,4 +143,4 @@ class DBPostProcess(object): ...@@ -140,4 +143,4 @@ class DBPostProcess(object):
pred[batch_index], segmentation[batch_index], width, height) pred[batch_index], segmentation[batch_index], width, height)
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
\ No newline at end of file
import cv2 import cv2
import paddle
import numpy as np import numpy as np
import pyclipper import pyclipper
from shapely.geometry import Polygon from shapely.geometry import Polygon
...@@ -23,7 +24,9 @@ class DBPostProcess(): ...@@ -23,7 +24,9 @@ class DBPostProcess():
pred: pred:
binary: text region segmentation map, with shape (N, 1,H, W) binary: text region segmentation map, with shape (N, 1,H, W)
''' '''
pred = pred.numpy()[:, 0, :, :] if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = self.binarize(pred) segmentation = self.binarize(pred)
batch_out = [] batch_out = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
...@@ -130,4 +133,4 @@ class DBPostProcess(): ...@@ -130,4 +133,4 @@ class DBPostProcess():
box[:, 0] = box[:, 0] - xmin box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
\ No newline at end of file
...@@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode):
character_type, use_space_char) character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
# out = self.decode_preds(preds) # out = self.decode_preds(preds)
preds = F.softmax(preds, axis=2).numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob)
...@@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character return dict_character
def decode_preds(self, preds): def decode_preds(self, preds):
probs = F.softmax(preds, axis=2).numpy() probs_ind = np.argmax(preds, axis=2)
probs_ind = np.argmax(probs, axis=2)
B, N, _ = preds.shape B, N, _ = preds.shape
l = np.ones(B).astype(np.int64) * N l = np.ones(B).astype(np.int64) * N
length = paddle.to_variable(l) length = paddle.to_tensor(l)
out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length) out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length)
batch_res = [ batch_res = [
x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy()) x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy())
] ]
result_list = [] result_list = []
for sample_idx, ind, prob in zip(batch_res, probs_ind, probs): for sample_idx, ind, prob in zip(batch_res, probs_ind, preds):
char_list = [self.character[idx] for idx in sample_idx] char_list = [self.character[idx] for idx in sample_idx]
valid_ind = np.where(ind != 0)[0] valid_ind = np.where(ind != 0)[0]
if len(valid_ind) == 0: if len(valid_ind) == 0:
...@@ -172,4 +172,4 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -172,4 +172,4 @@ class AttnLabelDecode(BaseRecLabelDecode):
else: else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \ assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册