From 9c680e5b42d89d158ad17bd92a5956141c41dd62 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 7 Jun 2022 21:57:29 +0800 Subject: [PATCH] fix is_crowd and difficult in Mosaic (#6150) --- ppdet/data/transform/operators.py | 66 +++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 052e35101..09a87b128 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -3184,7 +3184,7 @@ class Mosaic(BaseOperator): if np.random.uniform(0., 1.) > self.prob: return sample[0] - mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd = [], [], [] + mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd, mosaic_difficult = [], [], [], [] input_h, input_w = self.input_dim yc = int(random.uniform(0.5 * input_h, 1.5 * input_h)) xc = int(random.uniform(0.5 * input_w, 1.5 * input_w)) @@ -3217,21 +3217,35 @@ class Mosaic(BaseOperator): _gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw _gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh - is_crowd = sp['is_crowd'] if 'is_crowd' in sp else np.zeros( - (len(_gt_bbox), 1), dtype=np.int32) mosaic_gt_bbox.append(_gt_bbox) mosaic_gt_class.append(sp['gt_class']) - mosaic_is_crowd.append(is_crowd) + if 'is_crowd' in sp: + mosaic_is_crowd.append(sp['is_crowd']) + if 'difficult' in sp: + mosaic_difficult.append(sp['difficult']) # 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd]) if len(mosaic_gt_bbox): mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0) mosaic_gt_class = np.concatenate(mosaic_gt_class, 0) - mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0) - mosaic_labels = np.concatenate([ - mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype), - mosaic_is_crowd.astype(mosaic_gt_bbox.dtype) - ], 1) + if mosaic_is_crowd: + mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0) + mosaic_labels = np.concatenate([ + mosaic_gt_bbox, + mosaic_gt_class.astype(mosaic_gt_bbox.dtype), + mosaic_is_crowd.astype(mosaic_gt_bbox.dtype) + ], 1) + elif mosaic_difficult: + mosaic_difficult = np.concatenate(mosaic_difficult, 0) + mosaic_labels = np.concatenate([ + mosaic_gt_bbox, + mosaic_gt_class.astype(mosaic_gt_bbox.dtype), + mosaic_difficult.astype(mosaic_gt_bbox.dtype) + ], 1) + else: + mosaic_labels = np.concatenate([ + mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype) + ], 1) if self.remove_outside_box: # for MOT dataset flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w @@ -3268,11 +3282,23 @@ class Mosaic(BaseOperator): random.random() < self.mixup_prob): sample_mixup = sample[4] mixup_img = sample_mixup['image'] - cp_labels = np.concatenate([ - sample_mixup['gt_bbox'], - sample_mixup['gt_class'].astype(mosaic_labels.dtype), - sample_mixup['is_crowd'].astype(mosaic_labels.dtype) - ], 1) + if 'is_crowd' in sample_mixup: + cp_labels = np.concatenate([ + sample_mixup['gt_bbox'], + sample_mixup['gt_class'].astype(mosaic_labels.dtype), + sample_mixup['is_crowd'].astype(mosaic_labels.dtype) + ], 1) + elif 'difficult' in sample_mixup: + cp_labels = np.concatenate([ + sample_mixup['gt_bbox'], + sample_mixup['gt_class'].astype(mosaic_labels.dtype), + sample_mixup['difficult'].astype(mosaic_labels.dtype) + ], 1) + else: + cp_labels = np.concatenate([ + sample_mixup['gt_bbox'], + sample_mixup['gt_class'].astype(mosaic_labels.dtype) + ], 1) mosaic_img, mosaic_labels = self.mixup_augment( mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img) @@ -3284,7 +3310,10 @@ class Mosaic(BaseOperator): sample0['im_shape'][1] = sample0['w'] sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32) sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32) - sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32) + if 'is_crowd' in sample[0]: + sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32) + if 'difficult' in sample[0]: + sample0['difficult'] = mosaic_labels[:, 5:6].astype(np.float32) return sample0 def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels, @@ -3351,9 +3380,12 @@ class Mosaic(BaseOperator): cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h) cls_labels = cp_labels[:, 4:5].copy() - crd_labels = cp_labels[:, 5:6].copy() box_labels = cp_bboxes_transformed_np - labels = np.hstack((box_labels, cls_labels, crd_labels)) + if cp_labels.shape[-1] == 6: + crd_labels = cp_labels[:, 5:6].copy() + labels = np.hstack((box_labels, cls_labels, crd_labels)) + else: + labels = np.hstack((box_labels, cls_labels)) if self.remove_outside_box: labels = labels[labels[:, 0] < target_w] labels = labels[labels[:, 2] > 0] -- GitLab