未验证 提交 4ca78a07 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #241 from ZhangXinNan/zxdev

优化predict_system.py代码
......@@ -13,9 +13,9 @@
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(__file__)
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '../..'))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
......@@ -39,6 +39,7 @@ class TextSystem(object):
self.text_recognizer = predict_rec.TextRecognizer(args)
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
......@@ -47,15 +48,19 @@ class TextSystem(object):
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0],\
[img_crop_width, img_crop_height], [0, img_crop_height]])
'''
img_crop_width = int(max(np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(max(np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0],
[img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
......@@ -106,8 +111,7 @@ def sorted_boxes(dt_boxes):
return _boxes
if __name__ == "__main__":
args = utility.parse_args()
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args)
is_visualize = True
......@@ -145,3 +149,7 @@ if __name__ == "__main__":
draw_img[:, :, ::-1])
print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
if __name__ == "__main__":
main(utility.parse_args())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册