From bc39df29c8f91e83258f382840d29e45dfd91209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E9=B9=AD=E5=85=88=E7=94=9F?= <766529835@qq.com> Date: Tue, 7 Feb 2023 18:03:36 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81hrsc=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hrsc_annotation.py | 161 +++++++++++++++++++++++++++++++ utils/dataloader.py | 110 ++++++++++----------- utils/utils_rbox.py | 230 +++++++++++++++++++++++--------------------- 3 files changed, 335 insertions(+), 166 deletions(-) create mode 100644 hrsc_annotation.py diff --git a/hrsc_annotation.py b/hrsc_annotation.py new file mode 100644 index 0000000..3e13663 --- /dev/null +++ b/hrsc_annotation.py @@ -0,0 +1,161 @@ +import os +import random +import xml.etree.ElementTree as ET + +import numpy as np +from utils.utils_rbox import * +from utils.utils import get_classes + +#--------------------------------------------------------------------------------------------------------------------------------# +# annotation_mode用于指定该文件运行时计算的内容 +# annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt +# annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt +# annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt +#--------------------------------------------------------------------------------------------------------------------------------# +annotation_mode = 0 +#-------------------------------------------------------------------# +# 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息 +# 与训练和预测所用的classes_path一致即可 +# 如果生成的2007_train.txt里面没有目标信息 +# 那么就是因为classes没有设定正确 +# 仅在annotation_mode为0和2的时候有效 +#-------------------------------------------------------------------# +classes_path = 'model_data/hrsc_classes.txt' +#--------------------------------------------------------------------------------------------------------------------------------# +# trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1 +# train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1 +# 仅在annotation_mode为0和1的时候有效 +#--------------------------------------------------------------------------------------------------------------------------------# +trainval_percent = 0.9 +train_percent = 0.9 +#-------------------------------------------------------# +# 指向VOC数据集所在的文件夹 +# 默认指向根目录下的VOC数据集 +#-------------------------------------------------------# +VOCdevkit_path = 'VOCdevkit' + +VOCdevkit_sets = [('2007_HRSC', 'train'), ('2007_HRSC', 'val')] +classes, _ = get_classes(classes_path) + +#-------------------------------------------------------# +# 统计目标数量 +#-------------------------------------------------------# +photo_nums = np.zeros(len(VOCdevkit_sets)) +nums = np.zeros(len(classes)) +def convert_annotation(year, image_id, list_file): + in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8') + tree=ET.parse(in_file) + root = tree.getroot().find('HRSC_Objects') + + for obj in root.iter('HRSC_Object'): + difficult = 0 + if obj.find('difficult')!=None: + difficult = obj.find('difficult').text + cls = obj.find('name').text + if cls not in classes or int(difficult)==1: + continue + if obj.find('mbox_cx')==None: + continue + cls_id = classes.index(cls) + cx = float(obj.find('mbox_cx').text) + cy = float(obj.find('mbox_cy').text) + w = float(obj.find('mbox_w').text) + h = float(obj.find('mbox_h').text) + angle = float(obj.find('mbox_ang').text) + b = np.array([[cx, cy, w, h, angle]], dtype=np.float32) + b = rbox2poly(b)[0] + b = (b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]) + list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) + + nums[classes.index(cls)] = nums[classes.index(cls)] + 1 + +if __name__ == "__main__": + random.seed(0) + if " " in os.path.abspath(VOCdevkit_path): + raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。") + + if annotation_mode == 0 or annotation_mode == 1: + print("Generate txt in ImageSets.") + xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007_HRSC/Annotations') + saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007_HRSC/ImageSets/Main') + temp_xml = os.listdir(xmlfilepath) + total_xml = [] + for xml in temp_xml: + if xml.endswith(".xml"): + total_xml.append(xml) + + num = len(total_xml) + list = range(num) + tv = int(num*trainval_percent) + tr = int(tv*train_percent) + trainval= random.sample(list,tv) + train = random.sample(trainval,tr) + + print("train and val size",tv) + print("train size",tr) + ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') + ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') + ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') + fval = open(os.path.join(saveBasePath,'val.txt'), 'w') + + for i in list: + name=total_xml[i][:-4]+'\n' + if i in trainval: + ftrainval.write(name) + if i in train: + ftrain.write(name) + else: + fval.write(name) + else: + ftest.write(name) + + ftrainval.close() + ftrain.close() + fval.close() + ftest.close() + print("Generate txt in ImageSets done.") + + if annotation_mode == 0 or annotation_mode == 2: + print("Generate 2007_train.txt and 2007_val.txt for train.") + type_index = 0 + for year, image_set in VOCdevkit_sets: + image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split() + list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8') + for image_id in image_ids: + list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id)) + + convert_annotation(year, image_id, list_file) + list_file.write('\n') + photo_nums[type_index] = len(image_ids) + type_index += 1 + list_file.close() + print("Generate 2007_train.txt and 2007_val.txt for train done.") + + def printTable(List1, List2): + for i in range(len(List1[0])): + print("|", end=' ') + for j in range(len(List1)): + print(List1[j][i].rjust(int(List2[j])), end=' ') + print("|", end=' ') + print() + + str_nums = [str(int(x)) for x in nums] + tableData = [ + classes, str_nums + ] + colWidths = [0]*len(tableData) + len1 = 0 + for i in range(len(tableData)): + for j in range(len(tableData[i])): + if len(tableData[i][j]) > colWidths[i]: + colWidths[i] = len(tableData[i][j]) + printTable(tableData, colWidths) + + if photo_nums[0] <= 500: + print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。") + + if np.sum(nums) == 0: + print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") + print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") + print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") + print("(重要的事情说三遍)。") diff --git a/utils/dataloader.py b/utils/dataloader.py index 024e91a..3a0c036 100644 --- a/utils/dataloader.py +++ b/utils/dataloader.py @@ -81,7 +81,7 @@ class YoloDataset(Dataset): def rand(self, a=0, b=1): return np.random.rand()*(b-a) + a - def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True): + def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True, show=False): line = annotation_line.split() #------------------------------# # 读取图像并转换成RGB图像 @@ -96,13 +96,7 @@ class YoloDataset(Dataset): #------------------------------# # 获得预测框 #------------------------------# - box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) - #------------------------------# - # 将polygon转换为rbox - #------------------------------# - rbox = np.zeros((box.shape[0], 6)) - rbox[..., :5] = poly2rbox(box[..., :8], use_pi=True) - rbox[..., 5] = box[..., 8] + box = np.array([np.array(list(map(float,box.split(',')))) for box in line[1:]]) if not random: scale = min(w/iw, h/ih) @@ -122,17 +116,20 @@ class YoloDataset(Dataset): #---------------------------------# # 对真实框进行调整 #---------------------------------# - if len(rbox)>0: - np.random.shuffle(rbox) - rbox[:, 0] = rbox[:, 0]*nw/iw + dx - rbox[:, 1] = rbox[:, 1]*nh/ih + dy - rbox[:, 2] = rbox[:, 2]*nw/iw - rbox[:, 3] = rbox[:, 3]*nh/ih + if len(box)>0: + np.random.shuffle(box) + box[:, [0,2,4,6]] = box[:, [0,2,4,6]]*nw/iw + dx + box[:, [1,3,5,7]] = box[:, [1,3,5,7]]*nh/ih + dy + #------------------------------# + # 将polygon转换为rbox + #------------------------------# + rbox = np.zeros((box.shape[0], 6)) + rbox[..., :5] = poly2rbox(box[..., :8]) + rbox[..., 5] = box[..., 8] keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \ & (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \ & (rbox[:, 2] > 5) | (rbox[:, 3] > 5) rbox = rbox[keep] - return image_data, rbox #------------------------------------------# @@ -186,25 +183,30 @@ class YoloDataset(Dataset): #---------------------------------# # 对真实框进行调整 #---------------------------------# - if len(rbox)>0: - np.random.shuffle(rbox) - rbox[:, 0] = rbox[:, 0]*nw/iw + dx - rbox[:, 1] = rbox[:, 1]*nh/ih + dy - rbox[:, 2] = rbox[:, 2]*nw/iw - rbox[:, 3] = rbox[:, 3]*nh/ih - if flip: - rbox[:, 0] = w - rbox[:, 0] - rbox[:, 4] *= -1 + if len(box)>0: + np.random.shuffle(box) + box[:, [0,2,4,6]] = box[:, [0,2,4,6]]*nw/iw + dx + box[:, [1,3,5,7]] = box[:, [1,3,5,7]]*nh/ih + dy + if flip: box[:, [0,2,4,6]] = w - box[:, [0,2,4,6]] + #------------------------------# + # 将polygon转换为rbox + #------------------------------# + rbox = np.zeros((box.shape[0], 6)) + rbox[..., :5] = poly2rbox(box[..., :8]) + rbox[..., 5] = box[..., 8] keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \ & (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \ & (rbox[:, 2] > 5) | (rbox[:, 3] > 5) rbox = rbox[keep] - # 查看旋转框是否正确 - # draw = ImageDraw.Draw(image) - # polys = rbox2poly(rbox[..., :5]) - # for poly in polys: - # draw.polygon(xy=list(poly)) - # image.show() + #------------------------------# + # 检查旋转框 + #------------------------------# + if show: + draw = ImageDraw.Draw(image) + polys = rbox2poly(rbox[..., :5]) + for poly in polys: + draw.polygon(xy=list(poly)) + image.show() return image_data, rbox def merge_rboxes(self, rboxes, cutx, cuty): @@ -222,7 +224,7 @@ class YoloDataset(Dataset): merge_rbox = np.array(merge_rbox) return merge_rbox - def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4): + def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4, show=False): h, w = input_shape min_offset_x = self.rand(0.3, 0.7) min_offset_y = self.rand(0.3, 0.7) @@ -248,21 +250,14 @@ class YoloDataset(Dataset): #---------------------------------# # 保存框的位置 #---------------------------------# - box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]]) - #------------------------------# - # 将polygon转换为rbox - #------------------------------# - rbox = np.zeros((box.shape[0], 6)) - rbox[..., :5] = poly2rbox(box[..., :8], use_pi=True) - rbox[..., 5] = box[..., 8] + box = np.array([np.array(list(map(float,box.split(',')))) for box in line_content[1:]]) #---------------------------------# # 是否翻转图片 #---------------------------------# flip = self.rand()<.5 - if flip and len(rbox)>0: + if flip and len(box)>0: image = image.transpose(Image.FLIP_LEFT_RIGHT) - rbox[:, 0] = iw - rbox[:, 0] - rbox[:, 4] *= -1 + box[:, [0,2,4,6]] = iw - box[:, [0,2,4,6]] #------------------------------------------# # 对图像进行缩放并且进行长和宽的扭曲 #------------------------------------------# @@ -301,12 +296,16 @@ class YoloDataset(Dataset): #---------------------------------# # 对rbox进行重新处理 #---------------------------------# - if len(rbox)>0: - np.random.shuffle(rbox) - rbox[:, 0] = rbox[:, 0]*nw/iw + dx - rbox[:, 1] = rbox[:, 1]*nh/ih + dy - rbox[:, 2] = rbox[:, 2]*nw/iw - rbox[:, 3] = rbox[:, 3]*nh/ih + if len(box)>0: + np.random.shuffle(box) + box[:, [0,2,4,6]] = box[:, [0,2,4,6]]*nw/iw + dx + box[:, [1,3,5,7]] = box[:, [1,3,5,7]]*nh/ih + dy + #------------------------------# + # 将polygon转换为rbox + #------------------------------# + rbox = np.zeros((box.shape[0], 6)) + rbox[..., :5] = poly2rbox(box[..., :8]) + rbox[..., 5] = box[..., 8] keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \ & (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \ & (rbox[:, 2] > 5) | (rbox[:, 3] > 5) @@ -355,13 +354,16 @@ class YoloDataset(Dataset): # 对框进行进一步的处理 #---------------------------------# new_rboxes = self.merge_rboxes(rbox_datas, cutx, cuty) - # 查看旋转框是否正确 - # newImage = Image.fromarray(new_image) - # draw = ImageDraw.Draw(newImage) - # polys = rbox2poly(new_rboxes[..., :5]) - # for poly in polys: - # draw.polygon(xy=list(poly)) - # newImage.show() + #---------------------------------# + # 检查旋转框 + #---------------------------------# + if show: + new_img = Image.fromarray(new_image) + draw = ImageDraw.Draw(new_img) + polys = rbox2poly(new_rboxes[..., :5]) + for poly in polys: + draw.polygon(xy=list(poly)) + new_img.show() return new_image, new_rboxes def get_random_data_with_MixUp(self, image_1, rbox_1, image_2, rbox_2): diff --git a/utils/utils_rbox.py b/utils/utils_rbox.py index 7daf91d..59ba75b 100644 --- a/utils/utils_rbox.py +++ b/utils/utils_rbox.py @@ -2,130 +2,75 @@ Author: [egrt] Date: 2023-01-30 19:00:28 LastEditors: [egrt] -LastEditTime: 2023-02-06 20:34:05 +LastEditTime: 2023-02-07 17:15:56 +Description: Oriented Bounding Boxes utils +''' + +''' +Author: [egrt] +Date: 2023-01-30 19:00:28 +LastEditors: Egrt +LastEditTime: 2023-02-07 14:39:16 Description: Oriented Bounding Boxes utils ''' import numpy as np +import math pi = np.pi import cv2 import torch -def gaussian_label_cpu(label, num_class, u=0, sig=4.0): - """ - 转换成CSL Labels: - 用高斯窗口函数根据角度θ的周期性赋予gt labels同样的周期性,使得损失函数在计算边界处时可以做到“差值很大但loss很小”; - 并且使得其labels具有环形特征,能够反映各个θ之间的角度距离 - Args: - label (float32):[1], theta class - num_theta_class (int): [1], theta class num - u (float32):[1], μ in gaussian function - sig (float32):[1], σ in gaussian function, which is window radius for Circular Smooth Label - Returns: - csl_label (array): [num_theta_class], gaussian function smooth label - """ - x = np.arange(-num_class/2, num_class/2) - y_sig = np.exp(-(x - u) ** 2 / (2 * sig ** 2)) - index = int(num_class/2 - label) - return np.concatenate([y_sig[index:], - y_sig[:index]], axis=0) - -def regular_theta(theta, mode='180', start=-pi/2): - """ - limit theta ∈ [-pi/2, pi/2) - """ - assert mode in ['360', '180'] - cycle = 2 * pi if mode == '360' else pi - - theta = theta - start - theta = theta % cycle - return theta + start - -def poly2rbox(polys, num_cls_thata=180, radius=6.0, use_pi=False, use_gaussian=False): +def poly2rbox(polys): """ Trans poly format to rbox format. Args: polys (array): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4]) - num_cls_thata (int): [1], theta class num - radius (float32): [1], window radius for Circular Smooth Label - use_pi (bool): True θ∈[-pi/2, pi/2) , False θ∈[0, 180) Returns: - use_gaussian True: - rboxes (array): - csl_labels (array): (num_gts, num_cls_thata) - elif - rboxes (array): (num_gts, [cx cy l s θ]) + rboxes (array): (num_gts, [cx cy l s θ]) """ assert polys.shape[-1] == 8 - if use_gaussian: - csl_labels = [] rboxes = [] for poly in polys: poly = np.float32(poly.reshape(4, 2)) (x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90] - angle = -angle # θ ∈ [-90, 0] theta = angle / 180 * pi # 转为pi制 - # trans opencv format to longedge format θ ∈ [-pi/2, pi/2] - if w != max(w, h): + if w < h: w, h = h, w - theta += pi/2 - theta = regular_theta(theta) # limit theta ∈ [-pi/2, pi/2) - angle = (theta * 180 / pi) + 90 # θ ∈ [0, 180) - - if not use_pi: # 采用angle弧度制 θ ∈ [0, 180) - rboxes.append([x, y, w, h, angle]) - else: # 采用pi制 - rboxes.append([x, y, w, h, theta]) - if use_gaussian: - csl_label = gaussian_label_cpu(label=angle, num_class=num_cls_thata, u=0, sig=radius) - csl_labels.append(csl_label) - if use_gaussian: - return np.array(rboxes), np.array(csl_labels) + theta += np.pi / 2 + while not np.pi / 2 > theta >= -np.pi / 2: + if theta >= np.pi / 2: + theta -= np.pi + else: + theta += np.pi + assert np.pi / 2 > theta >= -np.pi / 2 + rboxes.append([x, y, w, h, theta]) return np.array(rboxes) - -def rbox2poly(obboxes): - """ - Trans rbox format to poly format. - Args: - rboxes (array/tensor): (num_gts, [cx cy l s θ]) θ∈[-pi/2, pi/2) +def poly2obb_np_le90(poly): + """Convert polygons to oriented bounding boxes. + Args: + polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] Returns: - polys (array/tensor): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4]) + obbs (ndarray): [x_ctr,y_ctr,w,h,angle] """ - if isinstance(obboxes, torch.Tensor): - center, w, h, theta = obboxes[:, :2], obboxes[:, 2:3], obboxes[:, 3:4], obboxes[:, 4:5] - Cos, Sin = torch.cos(theta), torch.sin(theta) - - vector1 = torch.cat( - (w/2 * Cos, -w/2 * Sin), dim=-1) - vector2 = torch.cat( - (-h/2 * Sin, -h/2 * Cos), dim=-1) - point1 = center + vector1 + vector2 - point2 = center + vector1 - vector2 - point3 = center - vector1 - vector2 - point4 = center - vector1 + vector2 - order = obboxes.shape[:-1] - return torch.cat( - (point1, point2, point3, point4), dim=-1).reshape(*order, 8) - else: - center, w, h, theta = np.split(obboxes, (2, 3, 4), axis=-1) - Cos, Sin = np.cos(theta), np.sin(theta) - - vector1 = np.concatenate( - [w/2 * Cos, -w/2 * Sin], axis=-1) - vector2 = np.concatenate( - [-h/2 * Sin, -h/2 * Cos], axis=-1) - - point1 = center + vector1 + vector2 - point2 = center + vector1 - vector2 - point3 = center - vector1 - vector2 - point4 = center - vector1 + vector2 - order = obboxes.shape[:-1] - return np.concatenate( - [point1, point2, point3, point4], axis=-1).reshape(*order, 8) - - + bboxps = np.array(poly).reshape((4, 2)) + rbbox = cv2.minAreaRect(bboxps) + x, y, w, h, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2] + if w < 2 or h < 2: + return + a = a / 180 * np.pi + if w < h: + w, h = h, w + a += np.pi / 2 + while not np.pi / 2 > a >= -np.pi / 2: + if a >= np.pi / 2: + a -= np.pi + else: + a += np.pi + assert np.pi / 2 > a >= -np.pi / 2 + return x, y, w, h, a + def poly2hbb(polys): """ Trans poly format to hbb format @@ -162,21 +107,82 @@ def poly2hbb(polys): hbboxes = np.concatenate((x_ctr, y_ctr, w, h), axis=1) return hbboxes -def poly_filter(polys, h, w): +def rbox2poly(obboxes): + """Convert oriented bounding boxes to polygons. + Args: + obbs (ndarray): [x_ctr,y_ctr,w,h,angle] + Returns: + polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3] """ - Filter the poly labels which is out of the image. + try: + center, w, h, theta = np.split(obboxes, (2, 3, 4), axis=-1) + except: + results = np.stack([0., 0., 0., 0., 0., 0., 0., 0.], axis=-1) + return results.reshape(1, -1) + Cos, Sin = np.cos(theta), np.sin(theta) + vector1 = np.concatenate([w / 2 * Cos, w / 2 * Sin], axis=-1) + vector2 = np.concatenate([-h / 2 * Sin, h / 2 * Cos], axis=-1) + point1 = center - vector1 - vector2 + point2 = center + vector1 - vector2 + point3 = center + vector1 + vector2 + point4 = center - vector1 + vector2 + polys = np.concatenate([point1, point2, point3, point4], axis=-1) + polys = get_best_begin_point(polys) + return polys + +def cal_line_length(point1, point2): + """Calculate the length of line. Args: - polys (array): (num, 8) + point1 (List): [x,y] + point2 (List): [x,y] + Returns: + length (float) + """ + return math.sqrt( + math.pow(point1[0] - point2[0], 2) + + math.pow(point1[1] - point2[1], 2)) + - Return: - keep_masks (array): (num) +def get_best_begin_point_single(coordinate): + """Get the best begin point of the single polygon. + Args: + coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score] + Returns: + reorder coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score] + """ + x1, y1, x2, y2, x3, y3, x4, y4 = coordinate + xmin = min(x1, x2, x3, x4) + ymin = min(y1, y2, y3, y4) + xmax = max(x1, x2, x3, x4) + ymax = max(y1, y2, y3, y4) + combine = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], + [[x2, y2], [x3, y3], [x4, y4], [x1, y1]], + [[x3, y3], [x4, y4], [x1, y1], [x2, y2]], + [[x4, y4], [x1, y1], [x2, y2], [x3, y3]]] + dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + force = 100000000.0 + force_flag = 0 + for i in range(4): + temp_force = cal_line_length(combine[i][0], dst_coordinate[0]) \ + + cal_line_length(combine[i][1], dst_coordinate[1]) \ + + cal_line_length(combine[i][2], dst_coordinate[2]) \ + + cal_line_length(combine[i][3], dst_coordinate[3]) + if temp_force < force: + force = temp_force + force_flag = i + if force_flag != 0: + pass + return np.hstack( + (np.array(combine[force_flag]).reshape(8))) + + +def get_best_begin_point(coordinates): + """Get the best begin points of polygons. + Args: + coordinate (ndarray): shape(n, 8). + Returns: + reorder coordinate (ndarray): shape(n, 8). """ - x = polys[:, 0::2] # (num, 4) - y = polys[:, 1::2] - x_max = np.amax(x, axis=1) # (num) - x_min = np.amin(x, axis=1) - y_max = np.amax(y, axis=1) - y_min = np.amin(y, axis=1) - x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0 # (num) - keep_masks = (x_ctr > 0) & (x_ctr < w) & (y_ctr > 0) & (y_ctr < h) - return keep_masks \ No newline at end of file + coordinates = list(map(get_best_begin_point_single, coordinates.tolist())) + coordinates = np.array(coordinates) + return coordinates \ No newline at end of file -- GitLab