diff --git a/utils/callbacks.py b/utils/callbacks.py
index b866781442081876c9127500b38e4f279e3d6f68..20e460c5b9d1cc3b9e7fbed9cae077e1b8c30542 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -149,8 +149,6 @@ class EvalCallback():
                 top_conf = results[0][:, 5] * results[0][:, 6]
                 top_rboxes = results[0][:, :5]
                 top_polys = rbox2poly(top_rboxes)
-                top_polys[:, 0::2] *= image_shape[1]
-                top_polys[:, 1::2] *= image_shape[0]
                 top_hbbs = poly2hbb(top_polys)
                 top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
                 top_hbbs = top_hbbs[top_100]
diff --git a/utils/utils_bbox.py b/utils/utils_bbox.py
index 6c203dd285d0f91f936aea9962418df20fad81e4..ad9dd34de121557703bc36a5946c8db26aecbaf5 100644
--- a/utils/utils_bbox.py
+++ b/utils/utils_bbox.py
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import math
-from torchvision.ops import nms
+from utils.utils_rbox import *
 from utils.nms_rotated import obb_nms
 
 class DecodeBox():
@@ -183,8 +183,43 @@ class DecodeBox():
 
             if output[i] is not None:
                 output[i] = output[i].cpu().numpy()
+                output[i][:, :5] = self.yolo_correct_boxes(output[i], input_shape, image_shape, letterbox_image)
         return output
 
+    def yolo_correct_boxes(self, output, input_shape, image_shape, letterbox_image):
+        """Map rotated boxes from network-input space to source-image pixels.
+
+        Columns 0-4 of ``output`` are read as (x, y, w, h, angle); x, y, w
+        and h are treated as fractions of the network input and come back
+        in pixels of the source image.  The angle column is copied through.
+        ``input_shape`` and ``image_shape`` are (height, width) pairs.
+        With ``letterbox_image`` set, the letterbox padding is removed first.
+        Returns a new array with five columns; ``output`` is left untouched.
+        """
+        box_xy = output[..., 0:2]
+        box_wh = output[..., 2:4]
+        angle = output[..., 4:5]
+        input_shape = np.array(input_shape)
+        image_shape = np.array(image_shape)
+
+        if letterbox_image:
+            # new_shape is the (h, w) of the resized image inside the letterbox;
+            # offset is the padding before it, as a fraction of input_shape.
+            new_shape = np.round(image_shape * np.min(input_shape/image_shape))
+            offset = (input_shape - new_shape)/2./input_shape
+            scale = input_shape/new_shape
+
+            # Shapes are (h, w) but boxes are (x, y, w, h): flip the per-axis
+            # factors instead of the box columns, and build new arrays rather
+            # than scaling views of ``output`` in place.
+            box_xy = (box_xy - offset[::-1]) * scale[::-1]
+            box_wh = box_wh * scale[::-1]
+
+        rboxes = np.concatenate([box_xy, box_wh, angle], axis=-1)
+        rboxes[..., [0, 2]] *= image_shape[1]
+        rboxes[..., [1, 3]] *= image_shape[0]
+        return rboxes
+
 if __name__ == "__main__":
     import matplotlib.pyplot as plt
     import numpy as np
diff --git a/utils/utils_rbox.py b/utils/utils_rbox.py
index 711e2e7f165e2e76bbe6eb3aff11787ebc94939b..f4bd86e522e7707d08c5edb562bca14ce94ca715 100644
--- a/utils/utils_rbox.py
+++ b/utils/utils_rbox.py
@@ -2,7 +2,7 @@
 Author: [egrt]
 Date: 2023-01-30 19:00:28
 LastEditors: [egrt]
-LastEditTime: 2023-02-10 14:16:06
+LastEditTime: 2023-02-10 22:44:35
 Description: Oriented Bounding Boxes utils
 '''
 
@@ -198,18 +198,5 @@ def correct_rboxes(rboxes, image_shape):
     nh, nw = image_shape
     polys[:, [0, 2, 4, 6]] *= nw
     polys[:, [1, 3, 5, 7]] *= nh
-    rboxes = poly2rbox(polys)
-    correct_polys = []
-    for rbox in rboxes:
-        xc, yc, h, w, ag = rbox[:5]
-        wx, wy = -w / 2 * math.sin(ag), w / 2 * math.cos(ag)
-        hx, hy = h / 2 * math.cos(ag), h / 2 * math.sin(ag)
-        point_x1, point_y1 = xc - wx - hx, yc - wy - hy
-        point_x2, point_y2 = xc - wx + hx, yc - wy + hy
-        point_x3, point_y3 = xc + wx + hx, yc + wy + hy
-        point_x4, point_y4 = xc + wx - hx, yc + wy - hy
-        poly = np.array([point_x1, point_y1, point_x2, point_y2,
-                         point_x3, point_y3, point_x4, point_y4])
-        correct_polys.append(poly)
-    correct_polys = np.array(correct_polys)
-    return correct_polys
+
+    return polys
diff --git a/yolo.py b/yolo.py
index 0ec6d6456dd9638dac39c54524b8f529a4709061..f5a5da46a31fe4885c641ee75acfaaae4479ef63 100644
--- a/yolo.py
+++ b/yolo.py
@@ -55,7 +55,7 @@ class YOLO(object):
         # 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
         # 在多次测试后,发现关闭letterbox_image直接resize的效果更好
         #---------------------------------------------------------------------#
-        "letterbox_image" : False,
+        "letterbox_image" : True,
         #-------------------------------#
         # 是否使用Cuda
         # 没有GPU可以设置成False
@@ -157,7 +157,7 @@ class YOLO(object):
             top_label = np.array(results[0][:, 7], dtype = 'int32')
             top_conf = results[0][:, 5] * results[0][:, 6]
             top_rboxes = results[0][:, :5]
-            top_polys = correct_rboxes(top_rboxes, image_shape)
+            top_polys = rbox2poly(top_rboxes)
         #---------------------------------------------------------#
         # 设置字体与边框厚度
         #---------------------------------------------------------#
@@ -377,8 +377,6 @@ class YOLO(object):
             top_conf = results[0][:, 5] * results[0][:, 6]
             top_rboxes = results[0][:, :5]
             top_polys = rbox2poly(top_rboxes)
-            top_polys[:, 0::2] *= image_shape[1]
-            top_polys[:, 1::2] *= image_shape[0]
             top_hbbs = poly2hbb(top_polys)
         for i, c in list(enumerate(top_label)):
             predicted_class = self.class_names[int(c)]