From c612935d8d7431f3a730cf5e213159f6b20938d1 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Tue, 12 Apr 2022 11:58:14 +0800
Subject: [PATCH] Simplify picodet postprocess (#5650)

---
 configs/picodet/_base_/picodet_320_reader.yml |  13 +-
 configs/picodet/_base_/picodet_416_reader.yml |  13 +-
 configs/picodet/_base_/picodet_640_reader.yml |  13 +-
 .../_base_/picodet_320_reader.yml             |  13 +-
 .../_base_/picodet_416_reader.yml             |  13 +-
 .../_base_/picodet_640_reader.yml             |  13 +-
 ppdet/modeling/architectures/picodet.py       |   3 +-
 ppdet/modeling/heads/gfl_head.py              |   4 +-
 ppdet/modeling/heads/pico_head.py             | 317 ++++++++++++------
 9 files changed, 264 insertions(+), 138 deletions(-)

diff --git a/configs/picodet/_base_/picodet_320_reader.yml b/configs/picodet/_base_/picodet_320_reader.yml
index 6b0112469..7d6500679 100644
--- a/configs/picodet/_base_/picodet_320_reader.yml
+++ b/configs/picodet/_base_/picodet_320_reader.yml
@@ -1,4 +1,8 @@
 worker_num: 6
+eval_height: &eval_height 320
+eval_width: &eval_width 320
+eval_size: &eval_size [*eval_height, *eval_width]
+
 TrainReader:
   sample_transforms:
   - Decode: {}
@@ -18,7 +22,7 @@ TrainReader:
 EvalReader:
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
@@ -29,13 +33,10 @@ EvalReader:
 
 TestReader:
   inputs_def:
-    image_shape: [1, 3, 320, 320]
+    image_shape: [1, 3, *eval_height, *eval_width]
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/configs/picodet/_base_/picodet_416_reader.yml b/configs/picodet/_base_/picodet_416_reader.yml
index f98fe08e1..ee4ae9886 100644
--- a/configs/picodet/_base_/picodet_416_reader.yml
+++ b/configs/picodet/_base_/picodet_416_reader.yml
@@ -1,4 +1,8 @@
 worker_num: 6
+eval_height: &eval_height 416
+eval_width: &eval_width 416
+eval_size: &eval_size [*eval_height, *eval_width]
+
 TrainReader:
   sample_transforms:
   - Decode: {}
@@ -18,7 +22,7 @@ TrainReader:
 EvalReader:
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
@@ -29,13 +33,10 @@ EvalReader:
 
 TestReader:
   inputs_def:
-    image_shape: [1, 3, 416, 416]
+    image_shape: [1, 3, *eval_height, *eval_width]
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/configs/picodet/_base_/picodet_640_reader.yml b/configs/picodet/_base_/picodet_640_reader.yml
index d90fbeb97..5502026af 100644
--- a/configs/picodet/_base_/picodet_640_reader.yml
+++ b/configs/picodet/_base_/picodet_640_reader.yml
@@ -1,4 +1,8 @@
 worker_num: 6
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+
 TrainReader:
   sample_transforms:
   - Decode: {}
@@ -18,7 +22,7 @@ TrainReader:
 EvalReader:
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [640, 640], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
@@ -29,13 +33,10 @@ EvalReader:
 
 TestReader:
   inputs_def:
-    image_shape: [1, 3, 640, 640]
+    image_shape: [1, 3, *eval_height, *eval_width]
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [640, 640], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/configs/picodet/legacy_model/_base_/picodet_320_reader.yml b/configs/picodet/legacy_model/_base_/picodet_320_reader.yml
index 2ce5bca66..4d3f0cbd8 100644
--- a/configs/picodet/legacy_model/_base_/picodet_320_reader.yml
+++ b/configs/picodet/legacy_model/_base_/picodet_320_reader.yml
@@ -1,4 +1,8 @@
 worker_num: 6
+eval_height: &eval_height 320
+eval_width: &eval_width 320
+eval_size: &eval_size [*eval_height, *eval_width]
+
 TrainReader:
   sample_transforms:
   - Decode: {}
@@ -18,7 +22,7 @@ TrainReader:
 EvalReader:
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
@@ -29,13 +33,10 @@ EvalReader:
 
 TestReader:
   inputs_def:
-    image_shape: [1, 3, 320, 320]
+    image_shape: [1, 3, *eval_height, *eval_width]
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/configs/picodet/legacy_model/_base_/picodet_416_reader.yml b/configs/picodet/legacy_model/_base_/picodet_416_reader.yml
index 12070a4be..59433c645 100644
--- a/configs/picodet/legacy_model/_base_/picodet_416_reader.yml
+++ b/configs/picodet/legacy_model/_base_/picodet_416_reader.yml
@@ -1,4 +1,8 @@
 worker_num: 6
+eval_height: &eval_height 416
+eval_width: &eval_width 416
+eval_size: &eval_size [*eval_height, *eval_width]
+
 TrainReader:
   sample_transforms:
   - Decode: {}
@@ -18,7 +22,7 @@ TrainReader:
 EvalReader:
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
@@ -29,13 +33,10 @@ EvalReader:
 
 TestReader:
   inputs_def:
-    image_shape: [1, 3, 416, 416]
+    image_shape: [1, 3, *eval_height, *eval_width]
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/configs/picodet/legacy_model/_base_/picodet_640_reader.yml b/configs/picodet/legacy_model/_base_/picodet_640_reader.yml
index a931f2a76..60904fb6b 100644
--- a/configs/picodet/legacy_model/_base_/picodet_640_reader.yml
+++ b/configs/picodet/legacy_model/_base_/picodet_640_reader.yml
@@ -1,4 +1,8 @@
 worker_num: 6
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+
 TrainReader:
   sample_transforms:
   - Decode: {}
@@ -18,7 +22,7 @@ TrainReader:
 EvalReader:
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [640, 640], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
   batch_transforms:
@@ -29,13 +33,10 @@ EvalReader:
 
 TestReader:
   inputs_def:
-    image_shape: [1, 3, 640, 640]
+    image_shape: [1, 3, *eval_height, *eval_width]
   sample_transforms:
   - Decode: {}
-  - Resize: {interp: 2, target_size: [640, 640], keep_ratio: False}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/ppdet/modeling/architectures/picodet.py b/ppdet/modeling/architectures/picodet.py
index 760b8347b..0b87a4baa 100644
--- a/ppdet/modeling/architectures/picodet.py
+++ b/ppdet/modeling/architectures/picodet.py
@@ -67,10 +67,9 @@ class PicoDet(BaseArch):
         if self.training or not self.export_post_process:
             return head_outs, None
         else:
-            im_shape = self.inputs['im_shape']
             scale_factor = self.inputs['scale_factor']
             bboxes, bbox_num = self.head.post_process(
-                head_outs, im_shape, scale_factor, export_nms=self.export_nms)
+                head_outs, scale_factor, export_nms=self.export_nms)
             return bboxes, bbox_num
 
     def get_loss(self, ):
diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py
index 779d739b8..654c84fce 100644
--- a/ppdet/modeling/heads/gfl_head.py
+++ b/ppdet/modeling/heads/gfl_head.py
@@ -79,7 +79,9 @@ class Integral(nn.Layer):
                 offsets from the box center in four directions, shape (N, 4).
         """
         x = F.softmax(x.reshape([-1, self.reg_max + 1]), axis=1)
-        x = F.linear(x, self.project).reshape([-1, 4])
+        x = F.linear(x, self.project)
+        if self.training:
+            x = x.reshape([-1, 4])
         return x
 
 
diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py
index 98c8c8ef9..ecb4b9764 100644
--- a/ppdet/modeling/heads/pico_head.py
+++ b/ppdet/modeling/heads/pico_head.py
@@ -194,7 +194,7 @@ class PicoHead(OTAVFLHead):
         'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
         'assigner', 'nms'
     ]
-    __shared__ = ['num_classes']
+    __shared__ = ['num_classes', 'eval_size']
 
     def __init__(self,
                  conv_feat='PicoFeat',
@@ -210,7 +210,8 @@ class PicoHead(OTAVFLHead):
                  feat_in_chan=96,
                  nms=None,
                  nms_pre=1000,
-                 cell_offset=0):
+                 cell_offset=0,
+                 eval_size=None):
         super(PicoHead, self).__init__(
             conv_feat=conv_feat,
             dgqp_module=dgqp_module,
@@ -239,6 +240,7 @@ class PicoHead(OTAVFLHead):
         self.nms = nms
         self.nms_pre = nms_pre
         self.cell_offset = cell_offset
+        self.eval_size = eval_size
 
         self.use_sigmoid = self.loss_vfl.use_sigmoid
         if self.use_sigmoid:
@@ -282,12 +284,50 @@ class PicoHead(OTAVFLHead):
                     bias_attr=ParamAttr(initializer=Constant(value=0))))
             self.head_reg_list.append(head_reg)
 
+        # initialize the anchor points
+        if self.eval_size:
+            self.anchor_points, self.stride_tensor = self._generate_anchors()
+
     def forward(self, fpn_feats, export_post_process=True):
         assert len(fpn_feats) == len(
             self.fpn_stride
         ), "The size of fpn_feats is not equal to size of fpn_stride"
-        cls_logits_list = []
-        bboxes_reg_list = []
+
+        if self.training:
+            return self.forward_train(fpn_feats)
+        else:
+            return self.forward_eval(
+                fpn_feats, export_post_process=export_post_process)
+
+    def forward_train(self, fpn_feats):
+        cls_logits_list, bboxes_reg_list = [], []
+        for i, fpn_feat in enumerate(fpn_feats):
+            conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i)
+            if self.conv_feat.share_cls_reg:
+                cls_logits = self.head_cls_list[i](conv_cls_feat)
+                cls_score, bbox_pred = paddle.split(
+                    cls_logits,
+                    [self.cls_out_channels, 4 * (self.reg_max + 1)],
+                    axis=1)
+            else:
+                cls_score = self.head_cls_list[i](conv_cls_feat)
+                bbox_pred = self.head_reg_list[i](conv_reg_feat)
+
+            if self.dgqp_module:
+                quality_score = self.dgqp_module(bbox_pred)
+                cls_score = F.sigmoid(cls_score) * quality_score
+
+            cls_logits_list.append(cls_score)
+            bboxes_reg_list.append(bbox_pred)
+
+        return (cls_logits_list, bboxes_reg_list)
+
+    def forward_eval(self, fpn_feats, export_post_process=True):
+        if self.eval_size:
+            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
+        else:
+            anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
+        cls_logits_list, bboxes_reg_list = [], []
         for i, fpn_feat in enumerate(fpn_feats):
             conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i)
             if self.conv_feat.share_cls_reg:
@@ -307,50 +347,68 @@ class PicoHead(OTAVFLHead):
             if not export_post_process:
                 # Now only supports batch size = 1 in deploy
                 # TODO(ygh): support batch size > 1
-                cls_score = F.sigmoid(cls_score).reshape(
+                cls_score_out = F.sigmoid(cls_score).reshape(
                     [1, self.cls_out_channels, -1]).transpose([0, 2, 1])
                 bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4,
                                                -1]).transpose([0, 2, 1])
-            elif not self.training:
-                cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
+            else:
+                b, _, h, w = fpn_feat.shape
+                l = h * w
+                cls_score_out = F.sigmoid(
+                    cls_score.reshape([b, self.cls_out_channels, l]))
                 bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
-                stride = self.fpn_stride[i]
-                b, cell_h, cell_w, _ = paddle.shape(cls_score)
-                y, x = self.get_single_level_center_point(
-                    [cell_h, cell_w], stride, cell_offset=self.cell_offset)
-                center_points = paddle.stack([x, y], axis=-1)
-                cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
-                bbox_pred = self.distribution_project(bbox_pred) * stride
-                bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
-
-                # NOTE: If keep_ratio=False and image shape value that
-                # multiples of 32, distance2bbox not set max_shapes parameter
-                # to speed up model prediction. If need to set max_shapes,
-                # please use inputs['im_shape'].
-                bbox_pred = batch_distance2bbox(
-                    center_points, bbox_pred, max_shapes=None)
+                bbox_pred = self.distribution_project(bbox_pred)
+                bbox_pred = bbox_pred.reshape([b, l, 4])
 
-            cls_logits_list.append(cls_score)
+            cls_logits_list.append(cls_score_out)
             bboxes_reg_list.append(bbox_pred)
 
+        if export_post_process:
+            cls_logits_list = paddle.concat(cls_logits_list, axis=-1)
+            bboxes_reg_list = paddle.concat(bboxes_reg_list, axis=1)
+            bboxes_reg_list = batch_distance2bbox(anchor_points,
+                                                  bboxes_reg_list)
+            bboxes_reg_list *= stride_tensor
+
         return (cls_logits_list, bboxes_reg_list)
 
-    def post_process(self,
-                     gfl_head_outs,
-                     im_shape,
-                     scale_factor,
-                     export_nms=True):
-        cls_scores, bboxes_reg = gfl_head_outs
-        bboxes = paddle.concat(bboxes_reg, axis=1)
-        mlvl_scores = paddle.concat(cls_scores, axis=1)
-        mlvl_scores = mlvl_scores.transpose([0, 2, 1])
+    def _generate_anchors(self, feats=None):
+        # just use in eval time
+        anchor_points = []
+        stride_tensor = []
+        for i, stride in enumerate(self.fpn_stride):
+            if feats is not None:
+                _, _, h, w = feats[i].shape
+            else:
+                h = math.ceil(self.eval_size[0] / stride)
+                w = math.ceil(self.eval_size[1] / stride)
+            shift_x = paddle.arange(end=w) + self.cell_offset
+            shift_y = paddle.arange(end=h) + self.cell_offset
+            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
+            anchor_point = paddle.cast(
+                paddle.stack(
+                    [shift_x, shift_y], axis=-1), dtype='float32')
+            anchor_points.append(anchor_point.reshape([-1, 2]))
+            stride_tensor.append(
+                paddle.full(
+                    [h * w, 1], stride, dtype='float32'))
+        anchor_points = paddle.concat(anchor_points)
+        stride_tensor = paddle.concat(stride_tensor)
+        return anchor_points, stride_tensor
+
+    def post_process(self, head_outs, scale_factor, export_nms=True):
+        pred_scores, pred_bboxes = head_outs
         if not export_nms:
-            return bboxes, mlvl_scores
+            return pred_bboxes, pred_scores
         else:
             # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
-            im_scale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
-            bboxes /= im_scale
-            bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores)
+            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
+            scale_factor = paddle.concat(
+                [scale_x, scale_y, scale_x, scale_y],
+                axis=-1).reshape([-1, 1, 4])
+            # scale bbox to origin image size.
+            pred_bboxes /= scale_factor
+            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
             return bbox_pred, bbox_num
 
 
@@ -374,29 +432,29 @@
         'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
         'static_assigner', 'assigner', 'nms'
     ]
-    __shared__ = ['num_classes']
-
-    def __init__(
-            self,
-            conv_feat='PicoFeatV2',
-            dgqp_module=None,
-            num_classes=80,
-            fpn_stride=[8, 16, 32],
-            prior_prob=0.01,
-            use_align_head=True,
-            loss_class='VariFocalLoss',
-            loss_dfl='DistributionFocalLoss',
-            loss_bbox='GIoULoss',
-            static_assigner_epoch=60,
-            static_assigner='ATSSAssigner',
-            assigner='TaskAlignedAssigner',
-            reg_max=16,
-            feat_in_chan=96,
-            nms=None,
-            nms_pre=1000,
-            cell_offset=0,
-            act='hard_swish',
-            grid_cell_scale=5.0, ):
+    __shared__ = ['num_classes', 'eval_size']
+
+    def __init__(self,
+                 conv_feat='PicoFeatV2',
+                 dgqp_module=None,
+                 num_classes=80,
+                 fpn_stride=[8, 16, 32],
+                 prior_prob=0.01,
+                 use_align_head=True,
+                 loss_class='VariFocalLoss',
+                 loss_dfl='DistributionFocalLoss',
+                 loss_bbox='GIoULoss',
+                 static_assigner_epoch=60,
+                 static_assigner='ATSSAssigner',
+                 assigner='TaskAlignedAssigner',
+                 reg_max=16,
+                 feat_in_chan=96,
+                 nms=None,
+                 nms_pre=1000,
+                 cell_offset=0,
+                 act='hard_swish',
+                 grid_cell_scale=5.0,
+                 eval_size=None):
         super(PicoHeadV2, self).__init__(
             conv_feat=conv_feat,
             dgqp_module=dgqp_module,
@@ -432,6 +490,7 @@ class PicoHeadV2(GFLHead):
         self.grid_cell_scale = grid_cell_scale
         self.use_align_head = use_align_head
         self.cls_out_channels = self.num_classes
+        self.eval_size = eval_size
 
         bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
         # Clear the super class initialization
@@ -478,11 +537,22 @@ class PicoHeadV2(GFLHead):
                     act=self.act,
                     use_act_in_out=False))
 
+        # initialize the anchor points
+        if self.eval_size:
+            self.anchor_points, self.stride_tensor = self._generate_anchors()
+
     def forward(self, fpn_feats, export_post_process=True):
         assert len(fpn_feats) == len(
             self.fpn_stride
         ), "The size of fpn_feats is not equal to size of fpn_stride"
 
+        if self.training:
+            return self.forward_train(fpn_feats)
+        else:
+            return self.forward_eval(
+                fpn_feats, export_post_process=export_post_process)
+
+    def forward_train(self, fpn_feats):
         cls_score_list, reg_list, box_list = [], [], []
         for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)):
             b, _, h, w = get_static_shape(fpn_feat)
@@ -498,7 +568,48 @@ class PicoHeadV2(GFLHead):
             else:
                 cls_score = F.sigmoid(cls_logit)
 
-            if not export_post_process and not self.training:
+            cls_score_out = cls_score.transpose([0, 2, 3, 1])
+            bbox_pred = reg_pred.transpose([0, 2, 3, 1])
+            b, cell_h, cell_w, _ = paddle.shape(cls_score_out)
+            y, x = self.get_single_level_center_point(
+                [cell_h, cell_w], stride, cell_offset=self.cell_offset)
+            center_points = paddle.stack([x, y], axis=-1)
+            cls_score_out = cls_score_out.reshape(
+                [b, -1, self.cls_out_channels])
+            bbox_pred = self.distribution_project(bbox_pred) * stride
+            bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
+            bbox_pred = batch_distance2bbox(
+                center_points, bbox_pred, max_shapes=None)
+            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
+            reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
+            box_list.append(bbox_pred / stride)
+
+        cls_score_list = paddle.concat(cls_score_list, axis=1)
+        box_list = paddle.concat(box_list, axis=1)
+        reg_list = paddle.concat(reg_list, axis=1)
+        return cls_score_list, reg_list, box_list, fpn_feats
+
+    def forward_eval(self, fpn_feats, export_post_process=True):
+        if self.eval_size:
+            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
+        else:
+            anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
+        cls_score_list, box_list = [], []
+        for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)):
+            b, _, h, w = fpn_feat.shape
+            # task decomposition
+            conv_cls_feat, se_feat = self.conv_feat(fpn_feat, i)
+            cls_logit = self.head_cls_list[i](se_feat)
+            reg_pred = self.head_reg_list[i](se_feat)
+
+            # cls prediction and alignment
+            if self.use_align_head:
+                cls_prob = F.sigmoid(self.cls_align[i](conv_cls_feat))
+                cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt()
+            else:
+                cls_score = F.sigmoid(cls_logit)
+
+            if not export_post_process:
                 # Now only supports batch size = 1 in deploy
                 cls_score_list.append(
                     cls_score.reshape([1, self.cls_out_channels, -1]).transpose(
@@ -507,34 +618,21 @@ class PicoHeadV2(GFLHead):
                         [0, 2, 1]))
                 reg_list.append(
                     reg_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose(
                         [0, 2, 1]))
             else:
-                cls_score_out = cls_score.transpose([0, 2, 3, 1])
+                l = h * w
+                cls_score_out = cls_score.reshape([b, self.cls_out_channels, l])
                 bbox_pred = reg_pred.transpose([0, 2, 3, 1])
-                b, cell_h, cell_w, _ = paddle.shape(cls_score_out)
-                y, x = self.get_single_level_center_point(
-                    [cell_h, cell_w], stride, cell_offset=self.cell_offset)
-                center_points = paddle.stack([x, y], axis=-1)
-                cls_score_out = cls_score_out.reshape(
-                    [b, -1, self.cls_out_channels])
-                bbox_pred = self.distribution_project(bbox_pred) * stride
-                bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
-                bbox_pred = batch_distance2bbox(
-                    center_points, bbox_pred, max_shapes=None)
-                if not self.training:
-                    cls_score_list.append(cls_score_out)
-                    box_list.append(bbox_pred)
-                else:
-                    cls_score_list.append(
-                        cls_score.flatten(2).transpose([0, 2, 1]))
-                    reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
-                    box_list.append(bbox_pred / stride)
-
-        if not self.training:
-            return cls_score_list, box_list
-        else:
-            cls_score_list = paddle.concat(cls_score_list, axis=1)
+                bbox_pred = self.distribution_project(bbox_pred)
+                bbox_pred = bbox_pred.reshape([b, l, 4])
+            cls_score_list.append(cls_score_out)
+            box_list.append(bbox_pred)
+
+        if export_post_process:
+            cls_score_list = paddle.concat(cls_score_list, axis=-1)
             box_list = paddle.concat(box_list, axis=1)
-            reg_list = paddle.concat(reg_list, axis=1)
-            return cls_score_list, reg_list, box_list, fpn_feats
+            box_list = batch_distance2bbox(anchor_points, box_list)
+            box_list *= stride_tensor
+
+        return cls_score_list, box_list
 
     def get_loss(self, head_outs, gt_meta):
         pred_scores, pred_regs, pred_bboxes, fpn_feats = head_outs
@@ -644,20 +742,41 @@
 
         return loss_states
 
-    def post_process(self,
-                     gfl_head_outs,
-                     im_shape,
-                     scale_factor,
-                     export_nms=True):
-        cls_scores, bboxes_reg = gfl_head_outs
-        bboxes = paddle.concat(bboxes_reg, axis=1)
-        mlvl_scores = paddle.concat(cls_scores, axis=1)
-        mlvl_scores = mlvl_scores.transpose([0, 2, 1])
+    def _generate_anchors(self, feats=None):
+        # just use in eval time
+        anchor_points = []
+        stride_tensor = []
+        for i, stride in enumerate(self.fpn_stride):
+            if feats is not None:
+                _, _, h, w = feats[i].shape
+            else:
+                h = math.ceil(self.eval_size[0] / stride)
+                w = math.ceil(self.eval_size[1] / stride)
+            shift_x = paddle.arange(end=w) + self.cell_offset
+            shift_y = paddle.arange(end=h) + self.cell_offset
+            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
+            anchor_point = paddle.cast(
+                paddle.stack(
+                    [shift_x, shift_y], axis=-1), dtype='float32')
+            anchor_points.append(anchor_point.reshape([-1, 2]))
+            stride_tensor.append(
+                paddle.full(
+                    [h * w, 1], stride, dtype='float32'))
+        anchor_points = paddle.concat(anchor_points)
+        stride_tensor = paddle.concat(stride_tensor)
+        return anchor_points, stride_tensor
+
+    def post_process(self, head_outs, scale_factor, export_nms=True):
+        pred_scores, pred_bboxes = head_outs
         if not export_nms:
-            return bboxes, mlvl_scores
+            return pred_bboxes, pred_scores
         else:
             # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
-            im_scale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
-            bboxes /= im_scale
-            bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores)
+            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
+            scale_factor = paddle.concat(
+                [scale_x, scale_y, scale_x, scale_y],
+                axis=-1).reshape([-1, 1, 4])
+            # scale bbox to origin image size.
+            pred_bboxes /= scale_factor
+            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num
-- 
GitLab
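
Notes on the main changes follow, each with a short illustrative sketch. The sketches are not part of the patch; names introduced in them are illustrative only.

The reader configs now state the evaluation resolution exactly once: `&eval_height` / `&eval_width` define YAML anchors, and every `*eval_size` / `*eval_height` alias reuses them, so the Resize transform and the exported `image_shape` can no longer drift apart. A minimal sketch of how the anchors resolve, assuming PyYAML is available:

```python
# Illustrative only: shows that the reader YAML's anchors and aliases
# all resolve to one shared eval shape. Assumes PyYAML is installed.
import yaml

snippet = """
eval_height: &eval_height 320
eval_width: &eval_width 320
eval_size: &eval_size [*eval_height, *eval_width]
TestReader:
  inputs_def:
    image_shape: [1, 3, *eval_height, *eval_width]
"""
cfg = yaml.safe_load(snippet)
assert cfg["eval_size"] == [320, 320]
assert cfg["TestReader"]["inputs_def"]["image_shape"] == [1, 3, 320, 320]
```

Changing the two anchor values is now the only edit needed to retarget a reader to a new eval resolution.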
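With `eval_size` known ahead of time, both heads precompute one grid-cell center per feature-map location, plus a matching per-point stride column, once in `__init__`; the exported graph then carries them as constants instead of rebuilding them on every forward pass. A NumPy sketch of the same construction (not the repo's code; the head's default `cell_offset` is 0, and the stride list here is illustrative):

```python
# Mirrors the logic of _generate_anchors: for each FPN stride, one (x, y)
# cell center per feature-map location and a stride value for every point.
import math
import numpy as np

def generate_anchors(eval_size, fpn_strides, cell_offset=0):
    anchor_points, stride_tensor = [], []
    for stride in fpn_strides:
        h = math.ceil(eval_size[0] / stride)
        w = math.ceil(eval_size[1] / stride)
        shift_x = np.arange(w) + cell_offset
        shift_y = np.arange(h) + cell_offset
        # "ij" indexing matches paddle.meshgrid(shift_y, shift_x)
        yy, xx = np.meshgrid(shift_y, shift_x, indexing="ij")
        anchor_points.append(
            np.stack([xx, yy], axis=-1).reshape(-1, 2).astype("float32"))
        stride_tensor.append(np.full((h * w, 1), stride, dtype="float32"))
    return np.concatenate(anchor_points), np.concatenate(stride_tensor)

points, strides = generate_anchors((320, 320), [8, 16, 32])
assert points.shape == (40 * 40 + 20 * 20 + 10 * 10, 2)  # 2100 grid centers
```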
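Because the per-level predictions are now concatenated before decoding, a single `batch_distance2bbox` call converts every (left, top, right, bottom) distance prediction into corner coordinates around its anchor center; the result is then multiplied by `stride_tensor` to go from grid units to input pixels. A sketch of that decode under the same conventions (the `_np` function is illustrative, not the library's):

```python
# Sketch of the distance-to-box decode applied after the level-wise concat:
# (l, t, r, b) distances around each anchor center -> (x1, y1, x2, y2) boxes.
import numpy as np

def batch_distance2bbox_np(points, distances):
    # points: (N, 2) cell centers in grid units; distances: (B, N, 4).
    x1y1 = points - distances[..., :2]
    x2y2 = points + distances[..., 2:]
    return np.concatenate([x1y1, x2y2], axis=-1)

points = np.array([[20.0, 12.0]])
dist = np.array([[[2.0, 3.0, 4.0, 5.0]]])
boxes = batch_distance2bbox_np(points, dist)
print(boxes)  # [[[18.  9. 24. 17.]]] -- then scaled by the per-point stride
```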
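The `gfl_head.py` change keeps the `Integral` output flat at eval time because the new batched decode consumes it level-concatenated; only the training path still reshapes to (-1, 4). The layer itself computes, for each box side, the expectation of a discrete distance distribution over `reg_max + 1` bins: softmax over the bins, then a dot product with the projection vector [0, 1, ..., reg_max]. A NumPy sketch of that step:

```python
# Sketch of the GFL Integral step: per-side bin logits -> expected distance
# E[d] = sum_i softmax(logits)_i * i, with bin index i in [0, reg_max].
import numpy as np

def integral_np(logits, reg_max=16):
    x = logits.reshape(-1, reg_max + 1)
    x = np.exp(x - x.max(axis=1, keepdims=True))
    x = x / x.sum(axis=1, keepdims=True)          # softmax over the bins
    project = np.arange(reg_max + 1, dtype=np.float32)
    return x @ project                            # expected value per side

logits = np.zeros((1, 4 * 17), dtype=np.float32)  # uniform distribution
print(integral_np(logits))                        # each side -> 8.0
```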
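Finally, `post_process` drops the now-unused `im_shape` argument. As the in-diff comment notes, the per-image `scale_factor` arrives as [h_scale, w_scale] and is reordered into a [w, h, w, h] divisor so that one elementwise divide maps x1, y1, x2, y2 back to original-image pixels before NMS, replacing the old flip/tile/unsqueeze dance. A sketch of that rescale step (illustrative function name):

```python
# Sketch of the simplified post_process rescale: split [h_scale, w_scale],
# reorder to [w, h, w, h], and divide the decoded boxes once per image.
import numpy as np

def rescale_boxes(pred_bboxes, scale_factor):
    # pred_bboxes: (B, N, 4) in network-input pixels; scale_factor: (B, 2).
    scale_y, scale_x = scale_factor[:, 0:1], scale_factor[:, 1:2]
    factor = np.concatenate(
        [scale_x, scale_y, scale_x, scale_y], axis=-1).reshape(-1, 1, 4)
    return pred_bboxes / factor

boxes = np.array([[[64.0, 32.0, 128.0, 96.0]]])     # x1, y1, x2, y2
out = rescale_boxes(boxes, np.array([[2.0, 2.0]]))  # input was 2x the original
print(out)  # [[[32. 16. 64. 48.]]]
```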