From b3de457d8081c7550f822c0d5be8620de9c944b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E9=B9=AD=E5=85=88=E7=94=9F?= <766529835@qq.com> Date: Fri, 10 Feb 2023 14:38:01 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/utils_rbox.py | 31 +++++++++++++++++++++++++++++-- yolo.py | 6 ++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/utils/utils_rbox.py b/utils/utils_rbox.py index 59ba75b..711e2e7 100644 --- a/utils/utils_rbox.py +++ b/utils/utils_rbox.py @@ -2,7 +2,7 @@ Author: [egrt] Date: 2023-01-30 19:00:28 LastEditors: [egrt] -LastEditTime: 2023-02-07 17:15:56 +LastEditTime: 2023-02-10 14:16:06 Description: Oriented Bounding Boxes utils ''' @@ -185,4 +185,31 @@ def get_best_begin_point(coordinates): """ coordinates = list(map(get_best_begin_point_single, coordinates.tolist())) coordinates = np.array(coordinates) - return coordinates \ No newline at end of file + return coordinates + +def correct_rboxes(rboxes, image_shape): + """将polys按比例进行缩放 + Args: + coordinate (ndarray): shape(n, 8). + Returns: + reorder coordinate (ndarray): shape(n, 8). + """ + polys = rbox2poly(rboxes) + nh, nw = image_shape + polys[:, [0, 2, 4, 6]] *= nw + polys[:, [1, 3, 5, 7]] *= nh + rboxes = poly2rbox(polys) + correct_polys = [] + for rbox in rboxes: + xc, yc, h, w, ag = rbox[:5] + wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag) + hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag) + point_x1, point_y1 = xc - wx - hx, yc - wy - hy + point_x2, point_y2 = xc - wx + hx, yc - wy + hy + point_x3, point_y3 = xc + wx + hx, yc + wy + hy + point_x4, point_y4 = xc + wx - hx, yc + wy - hy + poly = np.array([point_x1, point_y1, point_x2, point_y2, + point_x3, point_y3, point_x4, point_y4]) + correct_polys.append(poly) + correct_polys = np.array(correct_polys) + return correct_polys diff --git a/yolo.py b/yolo.py index d538f58..0ec6d64 100644 --- a/yolo.py +++ b/yolo.py @@ -11,7 +11,7 @@ from nets.yolo import YoloBody from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input, resize_image, show_config) from utils.utils_bbox import DecodeBox -from utils.utils_rbox import rbox2poly, poly2hbb +from utils.utils_rbox import * ''' 训练自己的数据集必看注释! ''' @@ -157,9 +157,7 @@ class YOLO(object): top_label = np.array(results[0][:, 7], dtype = 'int32') top_conf = results[0][:, 5] * results[0][:, 6] top_rboxes = results[0][:, :5] - top_polys = rbox2poly(top_rboxes) - top_polys[:, 0::2] *= image_shape[1] - top_polys[:, 1::2] *= image_shape[0] + top_polys = correct_rboxes(top_rboxes, image_shape) #---------------------------------------------------------# # 设置字体与边框厚度 #---------------------------------------------------------# -- GitLab