未验证 提交 9c680e5b 编写于 作者: G Guanghua Yu 提交者: GitHub

Fix handling of the optional `is_crowd` and `difficult` annotation fields in Mosaic (#6150)

上级 18c2099a
...@@ -3184,7 +3184,7 @@ class Mosaic(BaseOperator): ...@@ -3184,7 +3184,7 @@ class Mosaic(BaseOperator):
if np.random.uniform(0., 1.) > self.prob: if np.random.uniform(0., 1.) > self.prob:
return sample[0] return sample[0]
mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd = [], [], [] mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd, mosaic_difficult = [], [], [], []
input_h, input_w = self.input_dim input_h, input_w = self.input_dim
yc = int(random.uniform(0.5 * input_h, 1.5 * input_h)) yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
xc = int(random.uniform(0.5 * input_w, 1.5 * input_w)) xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
...@@ -3217,21 +3217,35 @@ class Mosaic(BaseOperator): ...@@ -3217,21 +3217,35 @@ class Mosaic(BaseOperator):
_gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw _gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw
_gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh _gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh
is_crowd = sp['is_crowd'] if 'is_crowd' in sp else np.zeros(
(len(_gt_bbox), 1), dtype=np.int32)
mosaic_gt_bbox.append(_gt_bbox) mosaic_gt_bbox.append(_gt_bbox)
mosaic_gt_class.append(sp['gt_class']) mosaic_gt_class.append(sp['gt_class'])
mosaic_is_crowd.append(is_crowd) if 'is_crowd' in sp:
mosaic_is_crowd.append(sp['is_crowd'])
if 'difficult' in sp:
mosaic_difficult.append(sp['difficult'])
# 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd]) # 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd])
if len(mosaic_gt_bbox): if len(mosaic_gt_bbox):
mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0) mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0)
mosaic_gt_class = np.concatenate(mosaic_gt_class, 0) mosaic_gt_class = np.concatenate(mosaic_gt_class, 0)
mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0) if mosaic_is_crowd:
mosaic_labels = np.concatenate([ mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0)
mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype), mosaic_labels = np.concatenate([
mosaic_is_crowd.astype(mosaic_gt_bbox.dtype) mosaic_gt_bbox,
], 1) mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_is_crowd.astype(mosaic_gt_bbox.dtype)
], 1)
elif mosaic_difficult:
mosaic_difficult = np.concatenate(mosaic_difficult, 0)
mosaic_labels = np.concatenate([
mosaic_gt_bbox,
mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_difficult.astype(mosaic_gt_bbox.dtype)
], 1)
else:
mosaic_labels = np.concatenate([
mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype)
], 1)
if self.remove_outside_box: if self.remove_outside_box:
# for MOT dataset # for MOT dataset
flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w
...@@ -3268,11 +3282,23 @@ class Mosaic(BaseOperator): ...@@ -3268,11 +3282,23 @@ class Mosaic(BaseOperator):
random.random() < self.mixup_prob): random.random() < self.mixup_prob):
sample_mixup = sample[4] sample_mixup = sample[4]
mixup_img = sample_mixup['image'] mixup_img = sample_mixup['image']
cp_labels = np.concatenate([ if 'is_crowd' in sample_mixup:
sample_mixup['gt_bbox'], cp_labels = np.concatenate([
sample_mixup['gt_class'].astype(mosaic_labels.dtype), sample_mixup['gt_bbox'],
sample_mixup['is_crowd'].astype(mosaic_labels.dtype) sample_mixup['gt_class'].astype(mosaic_labels.dtype),
], 1) sample_mixup['is_crowd'].astype(mosaic_labels.dtype)
], 1)
elif 'difficult' in sample_mixup:
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype),
sample_mixup['difficult'].astype(mosaic_labels.dtype)
], 1)
else:
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype)
], 1)
mosaic_img, mosaic_labels = self.mixup_augment( mosaic_img, mosaic_labels = self.mixup_augment(
mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img) mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img)
...@@ -3284,7 +3310,10 @@ class Mosaic(BaseOperator): ...@@ -3284,7 +3310,10 @@ class Mosaic(BaseOperator):
sample0['im_shape'][1] = sample0['w'] sample0['im_shape'][1] = sample0['w']
sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32) sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32)
sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32) sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32)
sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32) if 'is_crowd' in sample[0]:
sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32)
if 'difficult' in sample[0]:
sample0['difficult'] = mosaic_labels[:, 5:6].astype(np.float32)
return sample0 return sample0
def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels, def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels,
...@@ -3351,9 +3380,12 @@ class Mosaic(BaseOperator): ...@@ -3351,9 +3380,12 @@ class Mosaic(BaseOperator):
cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h) cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h)
cls_labels = cp_labels[:, 4:5].copy() cls_labels = cp_labels[:, 4:5].copy()
crd_labels = cp_labels[:, 5:6].copy()
box_labels = cp_bboxes_transformed_np box_labels = cp_bboxes_transformed_np
labels = np.hstack((box_labels, cls_labels, crd_labels)) if cp_labels.shape[-1] == 6:
crd_labels = cp_labels[:, 5:6].copy()
labels = np.hstack((box_labels, cls_labels, crd_labels))
else:
labels = np.hstack((box_labels, cls_labels))
if self.remove_outside_box: if self.remove_outside_box:
labels = labels[labels[:, 0] < target_w] labels = labels[labels[:, 0] < target_w]
labels = labels[labels[:, 2] > 0] labels = labels[labels[:, 2] > 0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册