diff --git a/utils/utils_rbox.py b/utils/utils_rbox.py index 59ba75b95a978229d4dbaea58bde2f9ee49cd816..711e2e7f165e2e76bbe6eb3aff11787ebc94939b 100644 --- a/utils/utils_rbox.py +++ b/utils/utils_rbox.py @@ -2,7 +2,7 @@ Author: [egrt] Date: 2023-01-30 19:00:28 LastEditors: [egrt] -LastEditTime: 2023-02-07 17:15:56 +LastEditTime: 2023-02-10 14:16:06 Description: Oriented Bounding Boxes utils ''' @@ -185,4 +185,31 @@ def get_best_begin_point(coordinates): """ coordinates = list(map(get_best_begin_point_single, coordinates.tolist())) coordinates = np.array(coordinates) - return coordinates \ No newline at end of file + return coordinates + +def correct_rboxes(rboxes, image_shape): + """将polys按比例进行缩放 + Args: + coordinate (ndarray): shape(n, 8). + Returns: + reorder coordinate (ndarray): shape(n, 8). + """ + polys = rbox2poly(rboxes) + nh, nw = image_shape + polys[:, [0, 2, 4, 6]] *= nw + polys[:, [1, 3, 5, 7]] *= nh + rboxes = poly2rbox(polys) + correct_polys = [] + for rbox in rboxes: + xc, yc, h, w, ag = rbox[:5] + wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag) + hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag) + point_x1, point_y1 = xc - wx - hx, yc - wy - hy + point_x2, point_y2 = xc - wx + hx, yc - wy + hy + point_x3, point_y3 = xc + wx + hx, yc + wy + hy + point_x4, point_y4 = xc + wx - hx, yc + wy - hy + poly = np.array([point_x1, point_y1, point_x2, point_y2, + point_x3, point_y3, point_x4, point_y4]) + correct_polys.append(poly) + correct_polys = np.array(correct_polys) + return correct_polys diff --git a/yolo.py b/yolo.py index d538f5827f410c1c17febb009aa1a130b576aa2a..0ec6d6456dd9638dac39c54524b8f529a4709061 100644 --- a/yolo.py +++ b/yolo.py @@ -11,7 +11,7 @@ from nets.yolo import YoloBody from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input, resize_image, show_config) from utils.utils_bbox import DecodeBox -from utils.utils_rbox import rbox2poly, poly2hbb +from utils.utils_rbox import * ''' 训练自己的数据集必看注释! ''' @@ -157,9 +157,7 @@ class YOLO(object): top_label = np.array(results[0][:, 7], dtype = 'int32') top_conf = results[0][:, 5] * results[0][:, 6] top_rboxes = results[0][:, :5] - top_polys = rbox2poly(top_rboxes) - top_polys[:, 0::2] *= image_shape[1] - top_polys[:, 1::2] *= image_shape[0] + top_polys = correct_rboxes(top_rboxes, image_shape) #---------------------------------------------------------# # 设置字体与边框厚度 #---------------------------------------------------------#