Commit fdaf1f15 authored by W WenmuZhou

Extract get_rotate_crop_image into tools/infer/utility.py

Parent 4a7f7c7d
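This change moves get_rotate_crop_image into tools/infer/utility.py so the copy-paste augmentation and TextSystem import one shared implementation instead of keeping duplicate local copies. Below is a minimal usage sketch of the shared helper; the image path and box coordinates are made-up example values, not part of this commit.

import cv2
import numpy as np
from tools.infer.utility import get_rotate_crop_image

# Hypothetical inputs: a BGR image and one detected text box given as four
# corner points (clockwise from top-left), float32 as cv2.getPerspectiveTransform expects.
img = cv2.imread("doc_example.jpg")
box = np.float32([[60, 40], [320, 48], [318, 92], [58, 84]])
crop = get_rotate_crop_image(img, box)  # perspective-rectified crop of the text region
print(crop.shape)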
@@ -20,40 +20,7 @@ from shapely.geometry import Polygon
from ppocr.data.imaug.iaa_augment import IaaAugment
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
from tools.infer.utility import get_rotate_crop_image
def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
class CopyPaste(object):
@@ -164,6 +131,17 @@ def get_intersection(pD, pG):
def rotate_bbox(img, text_polys, angle, scale=1):
    """
    from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
    Args:
        img: np.ndarray
        text_polys: np.ndarray N*4*2
        angle: int
        scale: int
    Returns:
    """
    w = img.shape[1]
    h = img.shape[0]
...
@@ -31,7 +31,7 @@ import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
+from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb, get_rotate_crop_image
import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger()
@@ -45,39 +45,6 @@ class TextSystem(object):
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)
    def get_rotate_crop_image(self, img, points):
        '''
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        '''
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img
    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
@@ -89,7 +56,6 @@ class TextSystem(object):
        dt_boxes, elapse = self.text_detector(img)
        logger.info("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
        if dt_boxes is None:
            return None, None
@@ -99,7 +65,7 @@ class TextSystem(object):
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
-            img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
...
@@ -109,11 +109,10 @@ def init_args():
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)
    parser.add_argument("--benchmark", type=bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")
    return parser
@@ -615,5 +614,39 @@ def get_current_memory_mb(gpu_id=None):
    return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
if __name__ == '__main__':
    pass
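The moved helper sizes the crop from the longer of each pair of opposite box edges and, when the result is at least 1.5 times taller than it is wide, rotates it 90 degrees with np.rot90, which turns a tall crop into a wide one. A quick sanity-check sketch of that geometry; the corner values below are illustrative only, not from this commit.

import numpy as np

# Illustrative box: 50 px wide, 200 px tall.
pts = np.float32([[0, 0], [50, 0], [50, 200], [0, 200]])
w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3])))  # 50
h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2])))  # 200
print(w, h, h / w >= 1.5)  # 50 200 True -> get_rotate_crop_image applies np.rot90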