提交 dec76eb7 编写于 作者: W WenmuZhou

Add padding for small images in det (text detection)

上级 48eba028
...@@ -81,7 +81,7 @@ class NormalizeImage(object): ...@@ -81,7 +81,7 @@ class NormalizeImage(object):
assert isinstance(img, assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage" np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = ( data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std img.astype('float32') * self.scale - self.mean) / self.std
return data return data
...@@ -122,6 +122,8 @@ class DetResizeForTest(object): ...@@ -122,6 +122,8 @@ class DetResizeForTest(object):
elif 'limit_side_len' in kwargs: elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len'] self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min') self.limit_type = kwargs.get('limit_type', 'min')
self.pad = kwargs.get('pad', False)
self.pad_size = kwargs.get('pad_size', 480)
elif 'resize_long' in kwargs: elif 'resize_long' in kwargs:
self.resize_type = 2 self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960) self.resize_long = kwargs.get('resize_long', 960)
...@@ -163,7 +165,7 @@ class DetResizeForTest(object): ...@@ -163,7 +165,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w) img, (ratio_h, ratio_w)
""" """
limit_side_len = self.limit_side_len limit_side_len = self.limit_side_len
h, w, _ = img.shape h, w, c = img.shape
# limit the max side # limit the max side
if self.limit_type == 'max': if self.limit_type == 'max':
...@@ -172,6 +174,8 @@ class DetResizeForTest(object): ...@@ -172,6 +174,8 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / h ratio = float(limit_side_len) / h
else: else:
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
elif self.pad:
ratio = float(self.pad_size) / max(h, w)
else: else:
ratio = 1. ratio = 1.
else: else:
...@@ -197,6 +201,10 @@ class DetResizeForTest(object): ...@@ -197,6 +201,10 @@ class DetResizeForTest(object):
sys.exit(0) sys.exit(0)
ratio_h = resize_h / float(h) ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w) ratio_w = resize_w / float(w)
if self.limit_type == 'max' and self.pad:
padding_im = np.zeros((self.pad_size, self.pad_size, c), dtype=np.float32)
padding_im[:resize_h, :resize_w, :] = img
img = padding_im
return img, [ratio_h, ratio_w] return img, [ratio_h, ratio_w]
def resize_image_type2(self, img): def resize_image_type2(self, img):
......
...@@ -49,12 +49,12 @@ class DBPostProcess(object): ...@@ -49,12 +49,12 @@ class DBPostProcess(object):
self.dilation_kernel = None if not use_dilation else np.array( self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]]) [[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, shape):
''' '''
_bitmap: single map with shape (1, H, W), _bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1} whose values are binarized as {0, 1}
''' '''
dest_height, dest_width, ratio_h, ratio_w = shape
bitmap = _bitmap bitmap = _bitmap
height, width = bitmap.shape height, width = bitmap.shape
...@@ -89,9 +89,9 @@ class DBPostProcess(object): ...@@ -89,9 +89,9 @@ class DBPostProcess(object):
box = np.array(box) box = np.array(box)
box[:, 0] = np.clip( box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width) np.round(box[:, 0] / ratio_w), 0, dest_width)
box[:, 1] = np.clip( box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height) np.round(box[:, 1] / ratio_h), 0, dest_height)
boxes.append(box.astype(np.int16)) boxes.append(box.astype(np.int16))
scores.append(score) scores.append(score)
return np.array(boxes, dtype=np.int16), scores return np.array(boxes, dtype=np.int16), scores
...@@ -175,7 +175,6 @@ class DBPostProcess(object): ...@@ -175,7 +175,6 @@ class DBPostProcess(object):
boxes_batch = [] boxes_batch = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None: if self.dilation_kernel is not None:
mask = cv2.dilate( mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8), np.array(segmentation[batch_index]).astype(np.uint8),
...@@ -183,7 +182,7 @@ class DBPostProcess(object): ...@@ -183,7 +182,7 @@ class DBPostProcess(object):
else: else:
mask = segmentation[batch_index] mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h) shape_list[batch_index])
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
...@@ -38,11 +38,13 @@ logger = get_logger() ...@@ -38,11 +38,13 @@ logger = get_logger()
class OCRSystem(object): class OCRSystem(object):
def __init__(self, args): def __init__(self, args):
args.det_pad = True
args.det_pad_size = 640
self.text_system = TextSystem(args) self.text_system = TextSystem(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer) self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn, threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu,thread_num=args.cpu_threads) enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score self.drop_score = args.drop_score
...@@ -67,7 +69,6 @@ class OCRSystem(object): ...@@ -67,7 +69,6 @@ class OCRSystem(object):
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res}) res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list return res_list
def save_res(res, save_folder, img_name): def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name) excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True) os.makedirs(excel_save_folder, exist_ok=True)
......
...@@ -41,7 +41,9 @@ class TextDetector(object): ...@@ -41,7 +41,9 @@ class TextDetector(object):
pre_process_list = [{ pre_process_list = [{
'DetResizeForTest': { 'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len, 'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type 'limit_type': args.det_limit_type,
'pad':args.det_pad,
'pad_size':args.det_pad_size
} }
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
......
...@@ -46,6 +46,8 @@ def init_args(): ...@@ -46,6 +46,8 @@ def init_args():
parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max') parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_pad", type=str2bool, default=False)
parser.add_argument("--det_pad_size", type=int, default=640)
# DB params # DB params
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册