未验证 提交 8c2af365 编写于 作者: H huangjun12 提交者: GitHub

only compute vlr region when distill (#7437)

上级 52645f8d
......@@ -15,3 +15,11 @@ ResNet:
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
TrainReader:
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2GFLTarget:
downsample_ratios: [8, 16, 32, 64, 128]
grid_cell_scale: 8
compute_vlr_region: True
\ No newline at end of file
......@@ -500,12 +500,14 @@ class Gt2GFLTarget(BaseOperator):
num_classes=80,
downsample_ratios=[8, 16, 32, 64, 128],
grid_cell_scale=4,
cell_offset=0):
cell_offset=0,
compute_vlr_region=False):
super(Gt2GFLTarget, self).__init__()
self.num_classes = num_classes
self.downsample_ratios = downsample_ratios
self.grid_cell_scale = grid_cell_scale
self.cell_offset = cell_offset
self.compute_vlr_region = compute_vlr_region
self.assigner = ATSSAssigner()
......@@ -585,9 +587,11 @@ class Gt2GFLTarget(BaseOperator):
gt_bboxes, gt_bboxes_ignore,
gt_labels)
vlr_region = self.assigner.get_vlr_region(
grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore,
gt_labels)
if self.compute_vlr_region:
vlr_region = self.assigner.get_vlr_region(
grid_cells, num_level_cells, gt_bboxes, gt_bboxes_ignore,
gt_labels)
sample['vlr_regions'] = vlr_region
pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample(
assign_gt_inds, gt_bboxes)
......@@ -615,7 +619,6 @@ class Gt2GFLTarget(BaseOperator):
sample['label_weights'] = label_weights
sample['bbox_targets'] = bbox_targets
sample['pos_num'] = max(pos_inds.size, 1)
sample['vlr_regions'] = vlr_region
sample.pop('is_crowd', None)
sample.pop('difficult', None)
sample.pop('gt_class', None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册