提交 bc39df29 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

支持hrsc数据集

上级 a6b827ed
import os
import random
import xml.etree.ElementTree as ET
import numpy as np
from utils.utils_rbox import *
from utils.utils import get_classes
#--------------------------------------------------------------------------------------------------------------------------------#
# annotation_mode用于指定该文件运行时计算的内容
# annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
# annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
# annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode = 0
#-------------------------------------------------------------------#
# 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息
# 与训练和预测所用的classes_path一致即可
# 如果生成的2007_train.txt里面没有目标信息
# 那么就是因为classes没有设定正确
# 仅在annotation_mode为0和2的时候有效
#-------------------------------------------------------------------#
classes_path = 'model_data/hrsc_classes.txt'
#--------------------------------------------------------------------------------------------------------------------------------#
# trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1
# train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1
# 仅在annotation_mode为0和1的时候有效
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 0.9
#-------------------------------------------------------#
# 指向VOC数据集所在的文件夹
# 默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit'
VOCdevkit_sets = [('2007_HRSC', 'train'), ('2007_HRSC', 'val')]
classes, _ = get_classes(classes_path)
#-------------------------------------------------------#
# 统计目标数量
#-------------------------------------------------------#
photo_nums = np.zeros(len(VOCdevkit_sets))
nums = np.zeros(len(classes))
def convert_annotation(year, image_id, list_file):
in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
tree=ET.parse(in_file)
root = tree.getroot().find('HRSC_Objects')
for obj in root.iter('HRSC_Object'):
difficult = 0
if obj.find('difficult')!=None:
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult)==1:
continue
if obj.find('mbox_cx')==None:
continue
cls_id = classes.index(cls)
cx = float(obj.find('mbox_cx').text)
cy = float(obj.find('mbox_cy').text)
w = float(obj.find('mbox_w').text)
h = float(obj.find('mbox_h').text)
angle = float(obj.find('mbox_ang').text)
b = np.array([[cx, cy, w, h, angle]], dtype=np.float32)
b = rbox2poly(b)[0]
b = (b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7])
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
nums[classes.index(cls)] = nums[classes.index(cls)] + 1
if __name__ == "__main__":
random.seed(0)
if " " in os.path.abspath(VOCdevkit_path):
raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")
if annotation_mode == 0 or annotation_mode == 1:
print("Generate txt in ImageSets.")
xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007_HRSC/Annotations')
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007_HRSC/ImageSets/Main')
temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
if xml.endswith(".xml"):
total_xml.append(xml)
num = len(total_xml)
list = range(num)
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
trainval= random.sample(list,tv)
train = random.sample(trainval,tr)
print("train and val size",tv)
print("train size",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
for i in list:
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print("Generate txt in ImageSets done.")
if annotation_mode == 0 or annotation_mode == 2:
print("Generate 2007_train.txt and 2007_val.txt for train.")
type_index = 0
for year, image_set in VOCdevkit_sets:
image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
for image_id in image_ids:
list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
convert_annotation(year, image_id, list_file)
list_file.write('\n')
photo_nums[type_index] = len(image_ids)
type_index += 1
list_file.close()
print("Generate 2007_train.txt and 2007_val.txt for train done.")
def printTable(List1, List2):
for i in range(len(List1[0])):
print("|", end=' ')
for j in range(len(List1)):
print(List1[j][i].rjust(int(List2[j])), end=' ')
print("|", end=' ')
print()
str_nums = [str(int(x)) for x in nums]
tableData = [
classes, str_nums
]
colWidths = [0]*len(tableData)
len1 = 0
for i in range(len(tableData)):
for j in range(len(tableData[i])):
if len(tableData[i][j]) > colWidths[i]:
colWidths[i] = len(tableData[i][j])
printTable(tableData, colWidths)
if photo_nums[0] <= 500:
print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")
if np.sum(nums) == 0:
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("(重要的事情说三遍)。")
......@@ -81,7 +81,7 @@ class YoloDataset(Dataset):
def rand(self, a=0, b=1):
return np.random.rand()*(b-a) + a
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True, show=False):
line = annotation_line.split()
#------------------------------#
# 读取图像并转换成RGB图像
......@@ -96,13 +96,7 @@ class YoloDataset(Dataset):
#------------------------------#
# 获得预测框
#------------------------------#
box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
#------------------------------#
# 将polygon转换为rbox
#------------------------------#
rbox = np.zeros((box.shape[0], 6))
rbox[..., :5] = poly2rbox(box[..., :8], use_pi=True)
rbox[..., 5] = box[..., 8]
box = np.array([np.array(list(map(float,box.split(',')))) for box in line[1:]])
if not random:
scale = min(w/iw, h/ih)
......@@ -122,17 +116,20 @@ class YoloDataset(Dataset):
#---------------------------------#
# 对真实框进行调整
#---------------------------------#
if len(rbox)>0:
np.random.shuffle(rbox)
rbox[:, 0] = rbox[:, 0]*nw/iw + dx
rbox[:, 1] = rbox[:, 1]*nh/ih + dy
rbox[:, 2] = rbox[:, 2]*nw/iw
rbox[:, 3] = rbox[:, 3]*nh/ih
if len(box)>0:
np.random.shuffle(box)
box[:, [0,2,4,6]] = box[:, [0,2,4,6]]*nw/iw + dx
box[:, [1,3,5,7]] = box[:, [1,3,5,7]]*nh/ih + dy
#------------------------------#
# 将polygon转换为rbox
#------------------------------#
rbox = np.zeros((box.shape[0], 6))
rbox[..., :5] = poly2rbox(box[..., :8])
rbox[..., 5] = box[..., 8]
keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \
& (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \
& (rbox[:, 2] > 5) | (rbox[:, 3] > 5)
rbox = rbox[keep]
return image_data, rbox
#------------------------------------------#
......@@ -186,25 +183,30 @@ class YoloDataset(Dataset):
#---------------------------------#
# 对真实框进行调整
#---------------------------------#
if len(rbox)>0:
np.random.shuffle(rbox)
rbox[:, 0] = rbox[:, 0]*nw/iw + dx
rbox[:, 1] = rbox[:, 1]*nh/ih + dy
rbox[:, 2] = rbox[:, 2]*nw/iw
rbox[:, 3] = rbox[:, 3]*nh/ih
if flip:
rbox[:, 0] = w - rbox[:, 0]
rbox[:, 4] *= -1
if len(box)>0:
np.random.shuffle(box)
box[:, [0,2,4,6]] = box[:, [0,2,4,6]]*nw/iw + dx
box[:, [1,3,5,7]] = box[:, [1,3,5,7]]*nh/ih + dy
if flip: box[:, [0,2,4,6]] = w - box[:, [0,2,4,6]]
#------------------------------#
# 将polygon转换为rbox
#------------------------------#
rbox = np.zeros((box.shape[0], 6))
rbox[..., :5] = poly2rbox(box[..., :8])
rbox[..., 5] = box[..., 8]
keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \
& (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \
& (rbox[:, 2] > 5) | (rbox[:, 3] > 5)
rbox = rbox[keep]
# 查看旋转框是否正确
# draw = ImageDraw.Draw(image)
# polys = rbox2poly(rbox[..., :5])
# for poly in polys:
# draw.polygon(xy=list(poly))
# image.show()
#------------------------------#
# 检查旋转框
#------------------------------#
if show:
draw = ImageDraw.Draw(image)
polys = rbox2poly(rbox[..., :5])
for poly in polys:
draw.polygon(xy=list(poly))
image.show()
return image_data, rbox
def merge_rboxes(self, rboxes, cutx, cuty):
......@@ -222,7 +224,7 @@ class YoloDataset(Dataset):
merge_rbox = np.array(merge_rbox)
return merge_rbox
def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4, show=False):
h, w = input_shape
min_offset_x = self.rand(0.3, 0.7)
min_offset_y = self.rand(0.3, 0.7)
......@@ -248,21 +250,14 @@ class YoloDataset(Dataset):
#---------------------------------#
# 保存框的位置
#---------------------------------#
box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]])
#------------------------------#
# 将polygon转换为rbox
#------------------------------#
rbox = np.zeros((box.shape[0], 6))
rbox[..., :5] = poly2rbox(box[..., :8], use_pi=True)
rbox[..., 5] = box[..., 8]
box = np.array([np.array(list(map(float,box.split(',')))) for box in line_content[1:]])
#---------------------------------#
# 是否翻转图片
#---------------------------------#
flip = self.rand()<.5
if flip and len(rbox)>0:
if flip and len(box)>0:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
rbox[:, 0] = iw - rbox[:, 0]
rbox[:, 4] *= -1
box[:, [0,2,4,6]] = iw - box[:, [0,2,4,6]]
#------------------------------------------#
# 对图像进行缩放并且进行长和宽的扭曲
#------------------------------------------#
......@@ -301,12 +296,16 @@ class YoloDataset(Dataset):
#---------------------------------#
# 对rbox进行重新处理
#---------------------------------#
if len(rbox)>0:
np.random.shuffle(rbox)
rbox[:, 0] = rbox[:, 0]*nw/iw + dx
rbox[:, 1] = rbox[:, 1]*nh/ih + dy
rbox[:, 2] = rbox[:, 2]*nw/iw
rbox[:, 3] = rbox[:, 3]*nh/ih
if len(box)>0:
np.random.shuffle(box)
box[:, [0,2,4,6]] = box[:, [0,2,4,6]]*nw/iw + dx
box[:, [1,3,5,7]] = box[:, [1,3,5,7]]*nh/ih + dy
#------------------------------#
# 将polygon转换为rbox
#------------------------------#
rbox = np.zeros((box.shape[0], 6))
rbox[..., :5] = poly2rbox(box[..., :8])
rbox[..., 5] = box[..., 8]
keep = (rbox[:, 0] >= 0) & (rbox[:, 0] < w) \
& (rbox[:, 1] >= 0) & (rbox[:, 0] < h) \
& (rbox[:, 2] > 5) | (rbox[:, 3] > 5)
......@@ -355,13 +354,16 @@ class YoloDataset(Dataset):
# 对框进行进一步的处理
#---------------------------------#
new_rboxes = self.merge_rboxes(rbox_datas, cutx, cuty)
# 查看旋转框是否正确
# newImage = Image.fromarray(new_image)
# draw = ImageDraw.Draw(newImage)
# polys = rbox2poly(new_rboxes[..., :5])
# for poly in polys:
# draw.polygon(xy=list(poly))
# newImage.show()
#---------------------------------#
# 检查旋转框
#---------------------------------#
if show:
new_img = Image.fromarray(new_image)
draw = ImageDraw.Draw(new_img)
polys = rbox2poly(new_rboxes[..., :5])
for poly in polys:
draw.polygon(xy=list(poly))
new_img.show()
return new_image, new_rboxes
def get_random_data_with_MixUp(self, image_1, rbox_1, image_2, rbox_2):
......
......@@ -2,130 +2,75 @@
Author: [egrt]
Date: 2023-01-30 19:00:28
LastEditors: [egrt]
LastEditTime: 2023-02-06 20:34:05
LastEditTime: 2023-02-07 17:15:56
Description: Oriented Bounding Boxes utils
'''
'''
Author: [egrt]
Date: 2023-01-30 19:00:28
LastEditors: Egrt
LastEditTime: 2023-02-07 14:39:16
Description: Oriented Bounding Boxes utils
'''
import numpy as np
import math
pi = np.pi
import cv2
import torch
def gaussian_label_cpu(label, num_class, u=0, sig=4.0):
"""
转换成CSL Labels:
用高斯窗口函数根据角度θ的周期性赋予gt labels同样的周期性,使得损失函数在计算边界处时可以做到“差值很大但loss很小”;
并且使得其labels具有环形特征,能够反映各个θ之间的角度距离
Args:
label (float32):[1], theta class
num_theta_class (int): [1], theta class num
u (float32):[1], μ in gaussian function
sig (float32):[1], σ in gaussian function, which is window radius for Circular Smooth Label
Returns:
csl_label (array): [num_theta_class], gaussian function smooth label
"""
x = np.arange(-num_class/2, num_class/2)
y_sig = np.exp(-(x - u) ** 2 / (2 * sig ** 2))
index = int(num_class/2 - label)
return np.concatenate([y_sig[index:],
y_sig[:index]], axis=0)
def regular_theta(theta, mode='180', start=-pi/2):
"""
limit theta ∈ [-pi/2, pi/2)
"""
assert mode in ['360', '180']
cycle = 2 * pi if mode == '360' else pi
theta = theta - start
theta = theta % cycle
return theta + start
def poly2rbox(polys, num_cls_thata=180, radius=6.0, use_pi=False, use_gaussian=False):
def poly2rbox(polys):
"""
Trans poly format to rbox format.
Args:
polys (array): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4])
num_cls_thata (int): [1], theta class num
radius (float32): [1], window radius for Circular Smooth Label
use_pi (bool): True θ∈[-pi/2, pi/2) , False θ∈[0, 180)
Returns:
use_gaussian True:
rboxes (array):
csl_labels (array): (num_gts, num_cls_thata)
elif
rboxes (array): (num_gts, [cx cy l s θ])
rboxes (array): (num_gts, [cx cy l s θ])
"""
assert polys.shape[-1] == 8
if use_gaussian:
csl_labels = []
rboxes = []
for poly in polys:
poly = np.float32(poly.reshape(4, 2))
(x, y), (w, h), angle = cv2.minAreaRect(poly) # θ ∈ [0, 90]
angle = -angle # θ ∈ [-90, 0]
theta = angle / 180 * pi # 转为pi制
# trans opencv format to longedge format θ ∈ [-pi/2, pi/2]
if w != max(w, h):
if w < h:
w, h = h, w
theta += pi/2
theta = regular_theta(theta) # limit theta ∈ [-pi/2, pi/2)
angle = (theta * 180 / pi) + 90 # θ ∈ [0, 180)
if not use_pi: # 采用angle弧度制 θ ∈ [0, 180)
rboxes.append([x, y, w, h, angle])
else: # 采用pi制
rboxes.append([x, y, w, h, theta])
if use_gaussian:
csl_label = gaussian_label_cpu(label=angle, num_class=num_cls_thata, u=0, sig=radius)
csl_labels.append(csl_label)
if use_gaussian:
return np.array(rboxes), np.array(csl_labels)
theta += np.pi / 2
while not np.pi / 2 > theta >= -np.pi / 2:
if theta >= np.pi / 2:
theta -= np.pi
else:
theta += np.pi
assert np.pi / 2 > theta >= -np.pi / 2
rboxes.append([x, y, w, h, theta])
return np.array(rboxes)
def rbox2poly(obboxes):
"""
Trans rbox format to poly format.
Args:
rboxes (array/tensor): (num_gts, [cx cy l s θ]) θ∈[-pi/2, pi/2)
def poly2obb_np_le90(poly):
"""Convert polygons to oriented bounding boxes.
Args:
polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
Returns:
polys (array/tensor): (num_gts, [x1 y1 x2 y2 x3 y3 x4 y4])
obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
"""
if isinstance(obboxes, torch.Tensor):
center, w, h, theta = obboxes[:, :2], obboxes[:, 2:3], obboxes[:, 3:4], obboxes[:, 4:5]
Cos, Sin = torch.cos(theta), torch.sin(theta)
vector1 = torch.cat(
(w/2 * Cos, -w/2 * Sin), dim=-1)
vector2 = torch.cat(
(-h/2 * Sin, -h/2 * Cos), dim=-1)
point1 = center + vector1 + vector2
point2 = center + vector1 - vector2
point3 = center - vector1 - vector2
point4 = center - vector1 + vector2
order = obboxes.shape[:-1]
return torch.cat(
(point1, point2, point3, point4), dim=-1).reshape(*order, 8)
else:
center, w, h, theta = np.split(obboxes, (2, 3, 4), axis=-1)
Cos, Sin = np.cos(theta), np.sin(theta)
vector1 = np.concatenate(
[w/2 * Cos, -w/2 * Sin], axis=-1)
vector2 = np.concatenate(
[-h/2 * Sin, -h/2 * Cos], axis=-1)
point1 = center + vector1 + vector2
point2 = center + vector1 - vector2
point3 = center - vector1 - vector2
point4 = center - vector1 + vector2
order = obboxes.shape[:-1]
return np.concatenate(
[point1, point2, point3, point4], axis=-1).reshape(*order, 8)
bboxps = np.array(poly).reshape((4, 2))
rbbox = cv2.minAreaRect(bboxps)
x, y, w, h, a = rbbox[0][0], rbbox[0][1], rbbox[1][0], rbbox[1][1], rbbox[2]
if w < 2 or h < 2:
return
a = a / 180 * np.pi
if w < h:
w, h = h, w
a += np.pi / 2
while not np.pi / 2 > a >= -np.pi / 2:
if a >= np.pi / 2:
a -= np.pi
else:
a += np.pi
assert np.pi / 2 > a >= -np.pi / 2
return x, y, w, h, a
def poly2hbb(polys):
"""
Trans poly format to hbb format
......@@ -162,21 +107,82 @@ def poly2hbb(polys):
hbboxes = np.concatenate((x_ctr, y_ctr, w, h), axis=1)
return hbboxes
def poly_filter(polys, h, w):
def rbox2poly(obboxes):
"""Convert oriented bounding boxes to polygons.
Args:
obbs (ndarray): [x_ctr,y_ctr,w,h,angle]
Returns:
polys (ndarray): [x0,y0,x1,y1,x2,y2,x3,y3]
"""
Filter the poly labels which is out of the image.
try:
center, w, h, theta = np.split(obboxes, (2, 3, 4), axis=-1)
except:
results = np.stack([0., 0., 0., 0., 0., 0., 0., 0.], axis=-1)
return results.reshape(1, -1)
Cos, Sin = np.cos(theta), np.sin(theta)
vector1 = np.concatenate([w / 2 * Cos, w / 2 * Sin], axis=-1)
vector2 = np.concatenate([-h / 2 * Sin, h / 2 * Cos], axis=-1)
point1 = center - vector1 - vector2
point2 = center + vector1 - vector2
point3 = center + vector1 + vector2
point4 = center - vector1 + vector2
polys = np.concatenate([point1, point2, point3, point4], axis=-1)
polys = get_best_begin_point(polys)
return polys
def cal_line_length(point1, point2):
"""Calculate the length of line.
Args:
polys (array): (num, 8)
point1 (List): [x,y]
point2 (List): [x,y]
Returns:
length (float)
"""
return math.sqrt(
math.pow(point1[0] - point2[0], 2) +
math.pow(point1[1] - point2[1], 2))
Return:
keep_masks (array): (num)
def get_best_begin_point_single(coordinate):
"""Get the best begin point of the single polygon.
Args:
coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score]
Returns:
reorder coordinate (List): [x1, y1, x2, y2, x3, y3, x4, y4, score]
"""
x1, y1, x2, y2, x3, y3, x4, y4 = coordinate
xmin = min(x1, x2, x3, x4)
ymin = min(y1, y2, y3, y4)
xmax = max(x1, x2, x3, x4)
ymax = max(y1, y2, y3, y4)
combine = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
[[x2, y2], [x3, y3], [x4, y4], [x1, y1]],
[[x3, y3], [x4, y4], [x1, y1], [x2, y2]],
[[x4, y4], [x1, y1], [x2, y2], [x3, y3]]]
dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
force = 100000000.0
force_flag = 0
for i in range(4):
temp_force = cal_line_length(combine[i][0], dst_coordinate[0]) \
+ cal_line_length(combine[i][1], dst_coordinate[1]) \
+ cal_line_length(combine[i][2], dst_coordinate[2]) \
+ cal_line_length(combine[i][3], dst_coordinate[3])
if temp_force < force:
force = temp_force
force_flag = i
if force_flag != 0:
pass
return np.hstack(
(np.array(combine[force_flag]).reshape(8)))
def get_best_begin_point(coordinates):
"""Get the best begin points of polygons.
Args:
coordinate (ndarray): shape(n, 8).
Returns:
reorder coordinate (ndarray): shape(n, 8).
"""
x = polys[:, 0::2] # (num, 4)
y = polys[:, 1::2]
x_max = np.amax(x, axis=1) # (num)
x_min = np.amin(x, axis=1)
y_max = np.amax(y, axis=1)
y_min = np.amin(y, axis=1)
x_ctr, y_ctr = (x_max + x_min) / 2.0, (y_max + y_min) / 2.0 # (num)
keep_masks = (x_ctr > 0) & (x_ctr < w) & (y_ctr > 0) & (y_ctr < h)
return keep_masks
\ No newline at end of file
coordinates = list(map(get_best_begin_point_single, coordinates.tolist()))
coordinates = np.array(coordinates)
return coordinates
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册