diff --git a/configs/slim/distill/gfl_ld_distill.yml b/configs/slim/distill/gfl_ld_distill.yml index ae0ddba9c318bf24195fd23810dd65b0514bc6ee..2601e99f319e089d34caf912495c87a8fe0fd98c 100644 --- a/configs/slim/distill/gfl_ld_distill.yml +++ b/configs/slim/distill/gfl_ld_distill.yml @@ -15,3 +15,11 @@ ResNet: freeze_at: 0 return_idx: [1,2,3] num_stages: 4 + +TrainReader: + batch_transforms: + - PadBatch: {pad_to_stride: 32} + - Gt2GFLTarget: + downsample_ratios: [8, 16, 32, 64, 128] + grid_cell_scale: 8 + compute_vlr_region: True \ No newline at end of file diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 0c48ffbd00816e273360c7f49b17a7dd53614144..f111adda45bef1265af76e1b5c5316e47f64f966 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -500,12 +500,14 @@ class Gt2GFLTarget(BaseOperator): num_classes=80, downsample_ratios=[8, 16, 32, 64, 128], grid_cell_scale=4, - cell_offset=0): + cell_offset=0, + compute_vlr_region=False): super(Gt2GFLTarget, self).__init__() self.num_classes = num_classes self.downsample_ratios = downsample_ratios self.grid_cell_scale = grid_cell_scale self.cell_offset = cell_offset + self.compute_vlr_region = compute_vlr_region self.assigner = ATSSAssigner() @@ -585,9 +587,11 @@ class Gt2GFLTarget(BaseOperator): gt_bboxes, gt_bboxes_ignore, gt_labels) - vlr_region = self.assigner.get_vlr_region( - grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore, - gt_labels) + if self.compute_vlr_region: + vlr_region = self.assigner.get_vlr_region( + grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore, + gt_labels) + sample['vlr_regions'] = vlr_region pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample( assign_gt_inds, gt_bboxes) @@ -615,7 +619,6 @@ class Gt2GFLTarget(BaseOperator): sample['label_weights'] = label_weights sample['bbox_targets'] = bbox_targets sample['pos_num'] = max(pos_inds.size, 1) - sample['vlr_regions'] = vlr_region sample.pop('is_crowd', None) sample.pop('difficult', None) sample.pop('gt_class', None)