diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 57f9a7e0a9c27436968219834ac14c815257596d..8d075502ea227369bfe8f804ffd9e1fd64888a74 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -13,9 +13,9 @@ # limitations under the License. import os import sys -__dir__ = os.path.dirname(__file__) +__dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.join(__dir__, '../..')) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) import tools.infer.utility as utility from ppocr.utils.utility import initial_logger @@ -39,6 +39,7 @@ class TextSystem(object): self.text_recognizer = predict_rec.TextRecognizer(args) def get_rotate_crop_image(self, img, points): + ''' img_height, img_width = img.shape[0:2] left = int(np.min(points[:, 0])) right = int(np.max(points[:, 0])) @@ -47,15 +48,19 @@ class TextSystem(object): img_crop = img[top:bottom, left:right, :].copy() points[:, 0] = points[:, 0] - left points[:, 1] = points[:, 1] - top - img_crop_width = int(np.linalg.norm(points[0] - points[1])) - img_crop_height = int(np.linalg.norm(points[0] - points[3])) - pts_std = np.float32([[0, 0], [img_crop_width, 0],\ - [img_crop_width, img_crop_height], [0, img_crop_height]]) + ''' + img_crop_width = int(max(np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int(max(np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img_crop, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE) + dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) dst_img_height, dst_img_width = dst_img.shape[0:2] if dst_img_height * 1.0 / dst_img_width >= 1.5: dst_img = np.rot90(dst_img) @@ -106,8 +111,7 @@ def sorted_boxes(dt_boxes): return _boxes -if __name__ == "__main__": - args = utility.parse_args() +def main(args): image_file_list = get_image_file_list(args.image_dir) text_sys = TextSystem(args) is_visualize = True @@ -145,3 +149,7 @@ if __name__ == "__main__": draw_img[:, :, ::-1]) print("The visualized image saved in {}".format( os.path.join(draw_img_save, os.path.basename(image_file)))) + + +if __name__ == "__main__": + main(utility.parse_args())