提交 72f1c6ff 编写于 作者: _白鹭先生_'s avatar _白鹭先生_

修复马赛克数据增强

上级 e7dcbb0a
...@@ -110,7 +110,7 @@ if __name__ == "__main__": ...@@ -110,7 +110,7 @@ if __name__ == "__main__":
# 如果不设置model_path,pretrained = True,此时仅加载主干开始训练。 # 如果不设置model_path,pretrained = True,此时仅加载主干开始训练。
# 如果不设置model_path,pretrained = False,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。 # 如果不设置model_path,pretrained = False,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。
#----------------------------------------------------------------------------------------------------------------------------# #----------------------------------------------------------------------------------------------------------------------------#
pretrained = False pretrained = True
#------------------------------------------------------------------# #------------------------------------------------------------------#
# mosaic 马赛克数据增强。 # mosaic 马赛克数据增强。
# mosaic_prob 每个step有多少概率使用mosaic数据增强,默认50%。 # mosaic_prob 每个step有多少概率使用mosaic数据增强,默认50%。
...@@ -124,7 +124,7 @@ if __name__ == "__main__": ...@@ -124,7 +124,7 @@ if __name__ == "__main__":
# 当mosaic=True时,本代码会在special_aug_ratio范围内开启mosaic。 # 当mosaic=True时,本代码会在special_aug_ratio范围内开启mosaic。
# 默认为前70%个epoch,100个世代会开启70个世代。 # 默认为前70%个epoch,100个世代会开启70个世代。
#------------------------------------------------------------------# #------------------------------------------------------------------#
mosaic = False mosaic = True
mosaic_prob = 0.5 mosaic_prob = 0.5
mixup = False mixup = False
mixup_prob = 0.5 mixup_prob = 0.5
...@@ -186,7 +186,7 @@ if __name__ == "__main__": ...@@ -186,7 +186,7 @@ if __name__ == "__main__":
# Freeze_Train 是否进行冻结训练 # Freeze_Train 是否进行冻结训练
# 默认先冻结主干训练后解冻训练。 # 默认先冻结主干训练后解冻训练。
#------------------------------------------------------------------# #------------------------------------------------------------------#
Freeze_Train = False Freeze_Train = True
#------------------------------------------------------------------# #------------------------------------------------------------------#
# 其它训练参数:学习率、优化器、学习率下降有关 # 其它训练参数:学习率、优化器、学习率下降有关
......
...@@ -41,18 +41,18 @@ class YoloDataset(Dataset): ...@@ -41,18 +41,18 @@ class YoloDataset(Dataset):
# 训练时进行数据的随机增强 # 训练时进行数据的随机增强
# 验证时不进行数据的随机增强 # 验证时不进行数据的随机增强
#---------------------------------------------------# #---------------------------------------------------#
if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio: # if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio:
lines = sample(self.annotation_lines, 3) lines = sample(self.annotation_lines, 3)
lines.append(self.annotation_lines[index]) lines.append(self.annotation_lines[index])
shuffle(lines) shuffle(lines)
image, rbox = self.get_random_data_with_Mosaic(lines, self.input_shape) image, rbox = self.get_random_data_with_Mosaic(lines, self.input_shape)
if self.mixup and self.rand() < self.mixup_prob: # if self.mixup and self.rand() < self.mixup_prob:
lines = sample(self.annotation_lines, 1) # lines = sample(self.annotation_lines, 1)
image_2, rbox_2 = self.get_random_data(lines[0], self.input_shape, random = self.train) # image_2, rbox_2 = self.get_random_data(lines[0], self.input_shape, random = self.train)
image, rbox = self.get_random_data_with_MixUp(image, rbox, image_2, rbox_2) # image, rbox = self.get_random_data_with_MixUp(image, rbox, image_2, rbox_2)
else: # else:
image, rbox = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train) # image, rbox = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1)) image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
rbox = np.array(rbox, dtype=np.float32) rbox = np.array(rbox, dtype=np.float32)
...@@ -193,51 +193,20 @@ class YoloDataset(Dataset): ...@@ -193,51 +193,20 @@ class YoloDataset(Dataset):
image.show() image.show()
return image_data, rbox return image_data, rbox
def merge_bboxes(self, bboxes, cutx, cuty): def merge_rboxes(self, rboxes, cutx, cuty):
merge_bbox = [] merge_rbox = []
for i in range(len(bboxes)): for i in range(len(rboxes)):
for box in bboxes[i]: for rbox in rboxes[i]:
tmp_box = [] tmp_rbox = []
x1, y1, x2, y2 = box[0], box[1], box[2], box[3] xc, yc, w, h = rbox[0], rbox[1], rbox[2], rbox[3]
tmp_rbox.append(xc)
if i == 0: tmp_rbox.append(yc)
if y1 > cuty or x1 > cutx: tmp_rbox.append(h)
continue tmp_rbox.append(w)
if y2 >= cuty and y1 <= cuty: tmp_rbox.append(rbox[-1])
y2 = cuty merge_rbox.append(rbox)
if x2 >= cutx and x1 <= cutx: merge_rbox = np.array(merge_rbox)
x2 = cutx return merge_rbox
if i == 1:
if y2 < cuty or x1 > cutx:
continue
if y2 >= cuty and y1 <= cuty:
y1 = cuty
if x2 >= cutx and x1 <= cutx:
x2 = cutx
if i == 2:
if y2 < cuty or x2 < cutx:
continue
if y2 >= cuty and y1 <= cuty:
y1 = cuty
if x2 >= cutx and x1 <= cutx:
x1 = cutx
if i == 3:
if y1 > cuty or x2 < cutx:
continue
if y2 >= cuty and y1 <= cuty:
y2 = cuty
if x2 >= cutx and x1 <= cutx:
x1 = cutx
tmp_box.append(x1)
tmp_box.append(y1)
tmp_box.append(x2)
tmp_box.append(y2)
tmp_box.append(box[-1])
merge_bbox.append(tmp_box)
return merge_bbox
def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4): def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
h, w = input_shape h, w = input_shape
...@@ -245,7 +214,7 @@ class YoloDataset(Dataset): ...@@ -245,7 +214,7 @@ class YoloDataset(Dataset):
min_offset_y = self.rand(0.3, 0.7) min_offset_y = self.rand(0.3, 0.7)
image_datas = [] image_datas = []
box_datas = [] rbox_datas = []
index = 0 index = 0
for line in annotation_line: for line in annotation_line:
#---------------------------------# #---------------------------------#
...@@ -314,25 +283,21 @@ class YoloDataset(Dataset): ...@@ -314,25 +283,21 @@ class YoloDataset(Dataset):
image_data = np.array(new_image) image_data = np.array(new_image)
index = index + 1 index = index + 1
box_data = [] rbox_data = []
#---------------------------------# #---------------------------------#
# 对box进行重新处理 # 对rbox进行重新处理
#---------------------------------# #---------------------------------#
if len(box)>0: if len(rbox)>0:
np.random.shuffle(box) np.random.shuffle(rbox)
box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx rbox[:, 0] = rbox[:, 0]*nw/w + dx/w
box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy rbox[:, 1] = rbox[:, 1]*nh/h + dy/h
box[:, 0:2][box[:, 0:2]<0] = 0 rbox[:, 2] = rbox[:, 2]*nw/w
box[:, 2][box[:, 2]>w] = w rbox[:, 3] = rbox[:, 3]*nh/h
box[:, 3][box[:, 3]>h] = h rbox_data = np.zeros((len(rbox),6))
box_w = box[:, 2] - box[:, 0] rbox_data[:len(rbox)] = rbox
box_h = box[:, 3] - box[:, 1]
box = box[np.logical_and(box_w>1, box_h>1)]
box_data = np.zeros((len(box),5))
box_data[:len(box)] = box
image_datas.append(image_data) image_datas.append(image_data)
box_datas.append(box_data) rbox_datas.append(rbox_data)
#---------------------------------# #---------------------------------#
# 将图片分割,放在一起 # 将图片分割,放在一起
...@@ -371,9 +336,15 @@ class YoloDataset(Dataset): ...@@ -371,9 +336,15 @@ class YoloDataset(Dataset):
#---------------------------------# #---------------------------------#
# 对框进行进一步的处理 # 对框进行进一步的处理
#---------------------------------# #---------------------------------#
new_boxes = self.merge_bboxes(box_datas, cutx, cuty) new_rboxes = self.merge_rboxes(rbox_datas, cutx, cuty)
# 查看旋转框是否正确
return new_image, new_boxes # newImage = Image.fromarray(new_image)
# draw = ImageDraw.Draw(newImage)
# polys = rbox2poly(new_rboxes[..., :5])*w
# for poly in polys:
# draw.polygon(xy=list(poly))
# newImage.show()
return new_image, new_rboxes
def get_random_data_with_MixUp(self, image_1, rbox_1, image_2, rbox_2): def get_random_data_with_MixUp(self, image_1, rbox_1, image_2, rbox_2):
new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5 new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册