提交 ef156b19 编写于 作者: 张欣-男's avatar 张欣-男

优化tools/infer/predict_system.py代码

上级 a28ef7f0
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '../..')) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
...@@ -39,6 +39,7 @@ class TextSystem(object): ...@@ -39,6 +39,7 @@ class TextSystem(object):
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
def get_rotate_crop_image(self, img, points): def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2] img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0])) left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0])) right = int(np.max(points[:, 0]))
...@@ -47,15 +48,19 @@ class TextSystem(object): ...@@ -47,15 +48,19 @@ class TextSystem(object):
img_crop = img[top:bottom, left:right, :].copy() img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1])) '''
img_crop_height = int(np.linalg.norm(points[0] - points[3])) img_crop_width = int(max(np.linalg.norm(points[0] - points[1]),
pts_std = np.float32([[0, 0], [img_crop_width, 0],\ np.linalg.norm(points[2] - points[3])))
[img_crop_width, img_crop_height], [0, img_crop_height]]) img_crop_height = int(max(np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std) M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective( dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height),
img_crop, borderMode=cv2.BORDER_REPLICATE,
M, (img_crop_width, img_crop_height), flags=cv2.INTER_CUBIC)
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2] dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5: if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img) dst_img = np.rot90(dst_img)
...@@ -106,8 +111,7 @@ def sorted_boxes(dt_boxes): ...@@ -106,8 +111,7 @@ def sorted_boxes(dt_boxes):
return _boxes return _boxes
if __name__ == "__main__": def main(args):
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
...@@ -145,3 +149,7 @@ if __name__ == "__main__": ...@@ -145,3 +149,7 @@ if __name__ == "__main__":
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
print("The visualized image saved in {}".format( print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(draw_img_save, os.path.basename(image_file))))
if __name__ == "__main__":
main(utility.parse_args())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册