diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 9c48b09647527cf718113ea1b5df152ff7befa04..ed81d41a59c131014bb98569c5cef2b1512cfb41 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -81,7 +81,7 @@ class NormalizeImage(object): assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" data['image'] = ( - img.astype('float32') * self.scale - self.mean) / self.std + img.astype('float32') * self.scale - self.mean) / self.std return data @@ -122,6 +122,8 @@ class DetResizeForTest(object): elif 'limit_side_len' in kwargs: self.limit_side_len = kwargs['limit_side_len'] self.limit_type = kwargs.get('limit_type', 'min') + self.pad = kwargs.get('pad', False) + self.pad_size = kwargs.get('pad_size', 480) elif 'resize_long' in kwargs: self.resize_type = 2 self.resize_long = kwargs.get('resize_long', 960) @@ -163,7 +165,7 @@ class DetResizeForTest(object): img, (ratio_h, ratio_w) """ limit_side_len = self.limit_side_len - h, w, _ = img.shape + h, w, c = img.shape # limit the max side if self.limit_type == 'max': @@ -172,6 +174,8 @@ class DetResizeForTest(object): ratio = float(limit_side_len) / h else: ratio = float(limit_side_len) / w + elif self.pad: + ratio = float(self.pad_size) / max(h, w) else: ratio = 1. else: @@ -197,6 +201,10 @@ class DetResizeForTest(object): sys.exit(0) ratio_h = resize_h / float(h) ratio_w = resize_w / float(w) + if self.limit_type == 'max' and self.pad: + padding_im = np.zeros((self.pad_size, self.pad_size, c), dtype=np.float32) + padding_im[:resize_h, :resize_w, :] = img + img = padding_im return img, [ratio_h, ratio_w] def resize_image_type2(self, img): diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 769ddbe23253ce58e2bccd46ef5074cc2a7d27da..0c149610e1770f0f9a295a973cba4f361610c013 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -49,12 +49,12 @@ class DBPostProcess(object): self.dilation_kernel = None if not use_dilation else np.array( [[1, 1], [1, 1]]) - def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + def boxes_from_bitmap(self, pred, _bitmap, shape): ''' _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} ''' - + dest_height, dest_width, ratio_h, ratio_w = shape bitmap = _bitmap height, width = bitmap.shape @@ -89,9 +89,9 @@ class DBPostProcess(object): box = np.array(box) box[:, 0] = np.clip( - np.round(box[:, 0] / width * dest_width), 0, dest_width) + np.round(box[:, 0] / ratio_w), 0, dest_width) box[:, 1] = np.clip( - np.round(box[:, 1] / height * dest_height), 0, dest_height) + np.round(box[:, 1] / ratio_h), 0, dest_height) boxes.append(box.astype(np.int16)) scores.append(score) return np.array(boxes, dtype=np.int16), scores @@ -175,7 +175,6 @@ class DBPostProcess(object): boxes_batch = [] for batch_index in range(pred.shape[0]): - src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] if self.dilation_kernel is not None: mask = cv2.dilate( np.array(segmentation[batch_index]).astype(np.uint8), @@ -183,7 +182,7 @@ class DBPostProcess(object): else: mask = segmentation[batch_index] boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, - src_w, src_h) + shape_list[batch_index]) boxes_batch.append({'points': boxes}) return boxes_batch diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index ede8501873207867262c733378465c727776b74a..87306eaeb339b6f9b8cc9b11f5508865efcbfd00 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -38,11 +38,13 @@ logger = get_logger() class OCRSystem(object): def __init__(self, args): + args.det_pad = True + args.det_pad_size = 640 self.text_system = TextSystem(args) self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer) self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", threshold=0.5, enable_mkldnn=args.enable_mkldnn, - enforce_cpu=not args.use_gpu,thread_num=args.cpu_threads) + enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads) self.use_angle_cls = args.use_angle_cls self.drop_score = args.drop_score @@ -67,7 +69,6 @@ class OCRSystem(object): res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res}) return res_list - def save_res(res, save_folder, img_name): excel_save_folder = os.path.join(save_folder, img_name) os.makedirs(excel_save_folder, exist_ok=True) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 59bb49f90abb198933b91f222febad7a416018e8..b21db4c7f7c25c7327fd5fc374c3a9bd91c2db3d 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -41,7 +41,9 @@ class TextDetector(object): pre_process_list = [{ 'DetResizeForTest': { 'limit_side_len': args.det_limit_side_len, - 'limit_type': args.det_limit_type + 'limit_type': args.det_limit_type, + 'pad':args.det_pad, + 'pad_size':args.det_pad_size } }, { 'NormalizeImage': { diff --git a/tools/infer/utility.py b/tools/infer/utility.py index a558f490f941ab0dd940329ff7c82c49b6eb31e7..9fb2e8e5f9ac608ea12fa902aca30b2ea5f03b2f 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -46,6 +46,8 @@ def init_args(): parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_type", type=str, default='max') + parser.add_argument("--det_pad", type=str2bool, default=False) + parser.add_argument("--det_pad_size", type=int, default=640) # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3)