未验证 提交 9c680e5b 编写于 作者: G Guanghua Yu 提交者: GitHub

fix is_crowd and difficult in Mosaic (#6150)

上级 18c2099a
......@@ -3184,7 +3184,7 @@ class Mosaic(BaseOperator):
if np.random.uniform(0., 1.) > self.prob:
return sample[0]
mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd = [], [], []
mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd, mosaic_difficult = [], [], [], []
input_h, input_w = self.input_dim
yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
......@@ -3217,21 +3217,35 @@ class Mosaic(BaseOperator):
_gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw
_gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh
is_crowd = sp['is_crowd'] if 'is_crowd' in sp else np.zeros(
(len(_gt_bbox), 1), dtype=np.int32)
mosaic_gt_bbox.append(_gt_bbox)
mosaic_gt_class.append(sp['gt_class'])
mosaic_is_crowd.append(is_crowd)
if 'is_crowd' in sp:
mosaic_is_crowd.append(sp['is_crowd'])
if 'difficult' in sp:
mosaic_difficult.append(sp['difficult'])
# 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd])
if len(mosaic_gt_bbox):
mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0)
mosaic_gt_class = np.concatenate(mosaic_gt_class, 0)
mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0)
mosaic_labels = np.concatenate([
mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_is_crowd.astype(mosaic_gt_bbox.dtype)
], 1)
if mosaic_is_crowd:
mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0)
mosaic_labels = np.concatenate([
mosaic_gt_bbox,
mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_is_crowd.astype(mosaic_gt_bbox.dtype)
], 1)
elif mosaic_difficult:
mosaic_difficult = np.concatenate(mosaic_difficult, 0)
mosaic_labels = np.concatenate([
mosaic_gt_bbox,
mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_difficult.astype(mosaic_gt_bbox.dtype)
], 1)
else:
mosaic_labels = np.concatenate([
mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype)
], 1)
if self.remove_outside_box:
# for MOT dataset
flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w
......@@ -3268,11 +3282,23 @@ class Mosaic(BaseOperator):
random.random() < self.mixup_prob):
sample_mixup = sample[4]
mixup_img = sample_mixup['image']
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype),
sample_mixup['is_crowd'].astype(mosaic_labels.dtype)
], 1)
if 'is_crowd' in sample_mixup:
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype),
sample_mixup['is_crowd'].astype(mosaic_labels.dtype)
], 1)
elif 'difficult' in sample_mixup:
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype),
sample_mixup['difficult'].astype(mosaic_labels.dtype)
], 1)
else:
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype)
], 1)
mosaic_img, mosaic_labels = self.mixup_augment(
mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img)
......@@ -3284,7 +3310,10 @@ class Mosaic(BaseOperator):
sample0['im_shape'][1] = sample0['w']
sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32)
sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32)
sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32)
if 'is_crowd' in sample[0]:
sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32)
if 'difficult' in sample[0]:
sample0['difficult'] = mosaic_labels[:, 5:6].astype(np.float32)
return sample0
def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels,
......@@ -3351,9 +3380,12 @@ class Mosaic(BaseOperator):
cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h)
cls_labels = cp_labels[:, 4:5].copy()
crd_labels = cp_labels[:, 5:6].copy()
box_labels = cp_bboxes_transformed_np
labels = np.hstack((box_labels, cls_labels, crd_labels))
if cp_labels.shape[-1] == 6:
crd_labels = cp_labels[:, 5:6].copy()
labels = np.hstack((box_labels, cls_labels, crd_labels))
else:
labels = np.hstack((box_labels, cls_labels))
if self.remove_outside_box:
labels = labels[labels[:, 0] < target_w]
labels = labels[labels[:, 2] > 0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册