From a6b827ed3384f3213809254916b16f49879e18c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E9=B9=AD=E5=85=88=E7=94=9F?= <766529835@qq.com> Date: Mon, 6 Feb 2023 20:59:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AF=B9=E6=9E=81=E5=B0=8F=E6=A1=86=E8=BF=87?= =?UTF-8?q?=E6=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 4 +- utils/dataloader.py | 62 +++++++++++++++++------------ utils/utils_rbox.py | 96 +++------------------------------------------ 3 files changed, 45 insertions(+), 117 deletions(-) diff --git a/train.py b/train.py index 2c8ea5e..fc315fe 100644 --- a/train.py +++ b/train.py @@ -180,7 +180,7 @@ if __name__ == "__main__": # Adam可以使用相对较小的UnFreeze_Epoch # Unfreeze_batch_size 模型在解冻后的batch_size #------------------------------------------------------------------# - UnFreeze_Epoch = 300 + UnFreeze_Epoch = 100 Unfreeze_batch_size = 4 #------------------------------------------------------------------# # Freeze_Train 是否进行冻结训练 @@ -211,7 +211,7 @@ if __name__ == "__main__": #------------------------------------------------------------------# # lr_decay_type 使用到的学习率下降方式,可选的有step、cos #------------------------------------------------------------------# - lr_decay_type = "cos" + lr_decay_type = "step" #------------------------------------------------------------------# # save_period 多少个epoch保存一次权值 #------------------------------------------------------------------# diff --git a/utils/dataloader.py b/utils/dataloader.py index 5dcdf88..024e91a 100644 --- a/utils/dataloader.py +++ b/utils/dataloader.py @@ -7,7 +7,7 @@ from PIL import Image, ImageDraw from torch.utils.data.dataset import Dataset from utils.utils import cvtColor, preprocess_input -from utils.utils_rbox import poly_filter, poly2rbox, rbox2poly +from utils.utils_rbox import poly2rbox, rbox2poly class YoloDataset(Dataset): def __init__(self, annotation_lines, input_shape, num_classes, anchors, anchors_mask, epoch_length, \ @@ -63,6 +63,12 @@ class YoloDataset(Dataset): nL = len(rbox) labels_out = np.zeros((nL, 7)) if nL: + #---------------------------------------------------# + # 对真实框进行归一化,调整到0-1之间 + #---------------------------------------------------# + rbox[:, [0, 2]] = rbox[:, [0, 2]] / self.input_shape[1] + rbox[:, [1, 3]] = rbox[:, [1, 3]] / self.input_shape[0] + #---------------------------------------------------# #---------------------------------------------------# # 调整顺序,符合训练的格式 # labels_out中序号为0的部分在collect时处理 @@ -95,7 +101,7 @@ class YoloDataset(Dataset): # 将polygon转换为rbox #------------------------------# rbox = np.zeros((box.shape[0], 6)) - rbox[..., :5] = poly2rbox(box[..., :8], (ih, iw), use_pi=True) + rbox[..., :5] = poly2rbox(box[..., :8], use_pi=True) rbox[..., 5] = box[..., 8] if not random: @@ -118,10 +124,14 @@ class YoloDataset(Dataset): #---------------------------------# if len(rbox)>0: np.random.shuffle(rbox) - rbox[:, 0] = rbox[:, 0]*nw/w + dx/w - rbox[:, 1] = rbox[:, 1]*nh/h + dy/h - rbox[:, 2] = rbox[:, 2]*nw/w - rbox[:, 3] = rbox[:, 3]*nh/h + rbox[:, 0] = rbox[:, 0]*nw/iw + dx + rbox[:, 1] = rbox[:, 1]*nh/ih + dy + rbox[:, 2] = rbox[:, 2]*nw/iw + rbox[:, 3] = rbox[:, 3]*nh/ih + keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \ + & (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \ + & (rbox[:, 2] > 5) | (rbox[:, 3] > 5) + rbox = rbox[keep] return image_data, rbox @@ -178,16 +188,20 @@ class YoloDataset(Dataset): #---------------------------------# if len(rbox)>0: np.random.shuffle(rbox) - rbox[:, 0] = rbox[:, 0]*nw/w + dx/w - rbox[:, 1] = rbox[:, 1]*nh/h + dy/h - rbox[:, 2] = rbox[:, 2]*nw/w - rbox[:, 3] = rbox[:, 3]*nh/h + rbox[:, 0] = rbox[:, 0]*nw/iw + dx + rbox[:, 1] = rbox[:, 1]*nh/ih + dy + rbox[:, 2] = rbox[:, 2]*nw/iw + rbox[:, 3] = rbox[:, 3]*nh/ih if flip: - rbox[:, 0] = 1 - rbox[:, 0] + rbox[:, 0] = w - rbox[:, 0] rbox[:, 4] *= -1 + keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \ + & (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \ + & (rbox[:, 2] > 5) | (rbox[:, 3] > 5) + rbox = rbox[keep] # 查看旋转框是否正确 # draw = ImageDraw.Draw(image) - # polys = rbox2poly(rbox[..., :5])*w + # polys = rbox2poly(rbox[..., :5]) # for poly in polys: # draw.polygon(xy=list(poly)) # image.show() @@ -239,7 +253,7 @@ class YoloDataset(Dataset): # 将polygon转换为rbox #------------------------------# rbox = np.zeros((box.shape[0], 6)) - rbox[..., :5] = poly2rbox(box[..., :8], (ih, iw), use_pi=True) + rbox[..., :5] = poly2rbox(box[..., :8], use_pi=True) rbox[..., 5] = box[..., 8] #---------------------------------# # 是否翻转图片 @@ -247,7 +261,7 @@ class YoloDataset(Dataset): flip = self.rand()<.5 if flip and len(rbox)>0: image = image.transpose(Image.FLIP_LEFT_RIGHT) - rbox[:, 0] = 1 - rbox[:, 0] + rbox[:, 0] = iw - rbox[:, 0] rbox[:, 4] *= -1 #------------------------------------------# # 对图像进行缩放并且进行长和宽的扭曲 @@ -289,10 +303,14 @@ class YoloDataset(Dataset): #---------------------------------# if len(rbox)>0: np.random.shuffle(rbox) - rbox[:, 0] = rbox[:, 0]*nw/w + dx/w - rbox[:, 1] = rbox[:, 1]*nh/h + dy/h - rbox[:, 2] = rbox[:, 2]*nw/w - rbox[:, 3] = rbox[:, 3]*nh/h + rbox[:, 0] = rbox[:, 0]*nw/iw + dx + rbox[:, 1] = rbox[:, 1]*nh/ih + dy + rbox[:, 2] = rbox[:, 2]*nw/iw + rbox[:, 3] = rbox[:, 3]*nh/ih + keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \ + & (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \ + & (rbox[:, 2] > 5) | (rbox[:, 3] > 5) + rbox = rbox[keep] rbox_data = np.zeros((len(rbox),6)) rbox_data[:len(rbox)] = rbox @@ -340,7 +358,7 @@ class YoloDataset(Dataset): # 查看旋转框是否正确 # newImage = Image.fromarray(new_image) # draw = ImageDraw.Draw(newImage) - # polys = rbox2poly(new_rboxes[..., :5])*w + # polys = rbox2poly(new_rboxes[..., :5]) # for poly in polys: # draw.polygon(xy=list(poly)) # newImage.show() @@ -354,12 +372,6 @@ class YoloDataset(Dataset): new_rboxes = rbox_1 else: new_rboxes = np.concatenate([rbox_1, rbox_2], axis=0) - # 查看旋转框是否正确 - draw = ImageDraw.Draw(new_image) - polys = rbox2poly(new_rboxes[..., :5])*640 - for poly in polys: - draw.polygon(xy=list(poly)) - new_image.show() return new_image, new_rboxes diff --git a/utils/utils_rbox.py b/utils/utils_rbox.py index 4613410..7daf91d 100644 --- a/utils/utils_rbox.py +++ b/utils/utils_rbox.py @@ -2,7 +2,7 @@ Author: [egrt] Date: 2023-01-30 19:00:28 LastEditors: [egrt] -LastEditTime: 2023-01-30 19:34:35 +LastEditTime: 2023-02-06 20:34:05 Description: Oriented Bounding Boxes utils ''' @@ -41,7 +41,7 @@ def regular_theta(theta, mode='180', start=-pi/2): theta = theta % cycle return theta + start -def poly2rbox(polys, img_size=(), num_cls_thata=180, radius=6.0, use_pi=False, use_gaussian=False): +def poly2rbox(polys, num_cls_thata=180, radius=6.0, use_pi=False, use_gaussian=False): """ Trans poly format to rbox format. Args: @@ -49,7 +49,6 @@ def poly2rbox(polys, img_size=(), num_cls_thata=180, radius=6.0, use_pi=False, u num_cls_thata (int): [1], theta class num radius (float32): [1], window radius for Circular Smooth Label use_pi (bool): True θ∈[-pi/2, pi/2) , False θ∈[0, 180) - Returns: use_gaussian True: rboxes (array): @@ -58,39 +57,19 @@ def poly2rbox(polys, img_size=(), num_cls_thata=180, radius=6.0, use_pi=False, u rboxes (array): (num_gts, [cx cy l s θ]) """ assert polys.shape[-1] == 8 - img_h, img_w = img_size[0], img_size[1] if use_gaussian: csl_labels = [] rboxes = [] for poly in polys: poly = np.float32(poly.reshape(4, 2)) - (x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90] # opencv>=4.5.1 若是< -90到0 - angle = -angle # θ ∈ [-90, 0] # 故 rbbox2poly 中 角度再 负 了一次 定义是 ccw 逆时针 - # # 两者的闭集位置进行了调换,所以在边界角度处的转换和非边界角度处的转换越有所不同。 - # if angle >= 90: - # angle = angle - 180 - # else: - # w, h = h, w - # angle = angle -90 + (x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90] + angle = -angle # θ ∈ [-90, 0] theta = angle / 180 * pi # 转为pi制 # trans opencv format to longedge format θ ∈ [-pi/2, pi/2] if w != max(w, h): - x = x / img_w - y = y / img_h - w, h = h, w - w = w / img_h - h = h / img_w theta += pi/2 - - else: - w = w / img_w - h = h / img_h - - x = x / img_w - y = y / img_h - theta = regular_theta(theta) # limit theta ∈ [-pi/2, pi/2) angle = (theta * 180 / pi) + 90 # θ ∈ [0, 180) @@ -104,65 +83,7 @@ def poly2rbox(polys, img_size=(), num_cls_thata=180, radius=6.0, use_pi=False, u if use_gaussian: return np.array(rboxes), np.array(csl_labels) return np.array(rboxes) - - -def poly2rbox_new(polys, num_cls_thata=5,angle_w=36, radius=6.0, use_pi=False, use_gaussian=False): - """ - Trans poly format to rbox format. - Args: - polys (array): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4]) - num_cls_thata (int): [1], theta class num - radius (float32): [1], window radius for Circular Smooth Label - use_pi (bool): True θ∈[-pi/2, pi/2) , False θ∈[0, 180) - - Returns: - use_gaussian True: - rboxes (array): - csl_labels (array): (num_gts, num_cls_thata) - elif - rboxes (array): (num_gts, [cx cy l s θ]) - """ - assert polys.shape[-1] == 8 - if use_gaussian: - csl_labels = [] - rboxes = [] - for poly in polys: - poly = np.float32(poly.reshape(4, 2)) - (x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90] # opencv>=4.5.1 若是< -90到0 - angle = -angle # θ ∈ [-90, 0] # 故 rbbox2poly 中 角度再 负 了一次 - # # 两者的闭集位置进行了调换,所以在边界角度处的转换和非边界角度处的转换越有所不同。 - # if angle >= 90: - # angle = angle - 180 - # else: - # w, h = h, w - # angle = angle -90 - theta = angle / 180 * pi # 转为pi制 - - # trans opencv format to longedge format θ ∈ [-pi/2, pi/2] - if w != max(w, h): - w, h = h, w - theta += pi/2 - theta = regular_theta(theta) # limit theta ∈ [-pi/2, pi/2) - # while not pi / 2 > theta >= -pi / 2: - # if theta >= pi / 2: - # theta -= pi - # else: - # theta += pi - angle = (theta * 180 / pi) + 90 # θ ∈ [0, 180) - - if not use_pi: # 采用angle弧度制 θ ∈ [0, 180) - rboxes.append([x, y, w, h, angle]) - else: # 采用pi制 - rboxes.append([x, y, w, h, theta]) - if use_gaussian: - csl_label = gaussian_label_cpu(label=angle, num_class=num_cls_thata, u=0, sig=radius) - csl_labels.append(csl_label) - if use_gaussian: - return np.array(rboxes), np.array(csl_labels) - return np.array(rboxes) - - - + def rbox2poly(obboxes): """ Trans rbox format to poly format. @@ -258,9 +179,4 @@ def poly_filter(polys, h, w): y_min = np.amin(y, axis=1) x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0 # (num) keep_masks = (x_ctr > 0) & (x_ctr < w) & (y_ctr > 0) & (y_ctr < h) - return keep_masks - -if __name__ == "__main__": - #print(np.pi) - poly = np.array([[204., 197., 273., 154., 290., 170., 217., 218.]]) - print(poly2rbox_new(poly,use_pi=True)) \ No newline at end of file + return keep_masks \ No newline at end of file -- GitLab