未验证 提交 c612935d 编写于 作者: G Guanghua Yu 提交者: GitHub

Simplify picodet postprocess (#5650)

上级 df4a27c6
worker_num: 6 worker_num: 6
eval_height: &eval_height 320
eval_width: &eval_width 320
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -18,7 +22,7 @@ TrainReader: ...@@ -18,7 +22,7 @@ TrainReader:
EvalReader: EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms: batch_transforms:
...@@ -29,13 +33,10 @@ EvalReader: ...@@ -29,13 +33,10 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 320, 320] image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false
worker_num: 6 worker_num: 6
eval_height: &eval_height 416
eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -18,7 +22,7 @@ TrainReader: ...@@ -18,7 +22,7 @@ TrainReader:
EvalReader: EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms: batch_transforms:
...@@ -29,13 +33,10 @@ EvalReader: ...@@ -29,13 +33,10 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 416, 416] image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false
worker_num: 6 worker_num: 6
eval_height: &eval_height 640
eval_width: &eval_width 640
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -18,7 +22,7 @@ TrainReader: ...@@ -18,7 +22,7 @@ TrainReader:
EvalReader: EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms: batch_transforms:
...@@ -29,13 +33,10 @@ EvalReader: ...@@ -29,13 +33,10 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 640, 640] image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false
worker_num: 6 worker_num: 6
eval_height: &eval_height 320
eval_width: &eval_width 320
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -18,7 +22,7 @@ TrainReader: ...@@ -18,7 +22,7 @@ TrainReader:
EvalReader: EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms: batch_transforms:
...@@ -29,13 +33,10 @@ EvalReader: ...@@ -29,13 +33,10 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 320, 320] image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false
worker_num: 6 worker_num: 6
eval_height: &eval_height 416
eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -18,7 +22,7 @@ TrainReader: ...@@ -18,7 +22,7 @@ TrainReader:
EvalReader: EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms: batch_transforms:
...@@ -29,13 +33,10 @@ EvalReader: ...@@ -29,13 +33,10 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 416, 416] image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false
worker_num: 6 worker_num: 6
eval_height: &eval_height 640
eval_width: &eval_width 640
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -18,7 +22,7 @@ TrainReader: ...@@ -18,7 +22,7 @@ TrainReader:
EvalReader: EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms: batch_transforms:
...@@ -29,13 +33,10 @@ EvalReader: ...@@ -29,13 +33,10 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [1, 3, 640, 640] image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: False} - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1 batch_size: 1
shuffle: false
...@@ -67,10 +67,9 @@ class PicoDet(BaseArch): ...@@ -67,10 +67,9 @@ class PicoDet(BaseArch):
if self.training or not self.export_post_process: if self.training or not self.export_post_process:
return head_outs, None return head_outs, None
else: else:
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor'] scale_factor = self.inputs['scale_factor']
bboxes, bbox_num = self.head.post_process( bboxes, bbox_num = self.head.post_process(
head_outs, im_shape, scale_factor, export_nms=self.export_nms) head_outs, scale_factor, export_nms=self.export_nms)
return bboxes, bbox_num return bboxes, bbox_num
def get_loss(self, ): def get_loss(self, ):
......
...@@ -79,7 +79,9 @@ class Integral(nn.Layer): ...@@ -79,7 +79,9 @@ class Integral(nn.Layer):
offsets from the box center in four directions, shape (N, 4). offsets from the box center in four directions, shape (N, 4).
""" """
x = F.softmax(x.reshape([-1, self.reg_max + 1]), axis=1) x = F.softmax(x.reshape([-1, self.reg_max + 1]), axis=1)
x = F.linear(x, self.project).reshape([-1, 4]) x = F.linear(x, self.project)
if self.training:
x = x.reshape([-1, 4])
return x return x
......
...@@ -194,7 +194,7 @@ class PicoHead(OTAVFLHead): ...@@ -194,7 +194,7 @@ class PicoHead(OTAVFLHead):
'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
'assigner', 'nms' 'assigner', 'nms'
] ]
__shared__ = ['num_classes'] __shared__ = ['num_classes', 'eval_size']
def __init__(self, def __init__(self,
conv_feat='PicoFeat', conv_feat='PicoFeat',
...@@ -210,7 +210,8 @@ class PicoHead(OTAVFLHead): ...@@ -210,7 +210,8 @@ class PicoHead(OTAVFLHead):
feat_in_chan=96, feat_in_chan=96,
nms=None, nms=None,
nms_pre=1000, nms_pre=1000,
cell_offset=0): cell_offset=0,
eval_size=None):
super(PicoHead, self).__init__( super(PicoHead, self).__init__(
conv_feat=conv_feat, conv_feat=conv_feat,
dgqp_module=dgqp_module, dgqp_module=dgqp_module,
...@@ -239,6 +240,7 @@ class PicoHead(OTAVFLHead): ...@@ -239,6 +240,7 @@ class PicoHead(OTAVFLHead):
self.nms = nms self.nms = nms
self.nms_pre = nms_pre self.nms_pre = nms_pre
self.cell_offset = cell_offset self.cell_offset = cell_offset
self.eval_size = eval_size
self.use_sigmoid = self.loss_vfl.use_sigmoid self.use_sigmoid = self.loss_vfl.use_sigmoid
if self.use_sigmoid: if self.use_sigmoid:
...@@ -282,12 +284,50 @@ class PicoHead(OTAVFLHead): ...@@ -282,12 +284,50 @@ class PicoHead(OTAVFLHead):
bias_attr=ParamAttr(initializer=Constant(value=0)))) bias_attr=ParamAttr(initializer=Constant(value=0))))
self.head_reg_list.append(head_reg) self.head_reg_list.append(head_reg)
# initialize the anchor points
if self.eval_size:
self.anchor_points, self.stride_tensor = self._generate_anchors()
def forward(self, fpn_feats, export_post_process=True): def forward(self, fpn_feats, export_post_process=True):
assert len(fpn_feats) == len( assert len(fpn_feats) == len(
self.fpn_stride self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride" ), "The size of fpn_feats is not equal to size of fpn_stride"
cls_logits_list = []
bboxes_reg_list = [] if self.training:
return self.forward_train(fpn_feats)
else:
return self.forward_eval(
fpn_feats, export_post_process=export_post_process)
def forward_train(self, fpn_feats):
cls_logits_list, bboxes_reg_list = [], []
for i, fpn_feat in enumerate(fpn_feats):
conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i)
if self.conv_feat.share_cls_reg:
cls_logits = self.head_cls_list[i](conv_cls_feat)
cls_score, bbox_pred = paddle.split(
cls_logits,
[self.cls_out_channels, 4 * (self.reg_max + 1)],
axis=1)
else:
cls_score = self.head_cls_list[i](conv_cls_feat)
bbox_pred = self.head_reg_list[i](conv_reg_feat)
if self.dgqp_module:
quality_score = self.dgqp_module(bbox_pred)
cls_score = F.sigmoid(cls_score) * quality_score
cls_logits_list.append(cls_score)
bboxes_reg_list.append(bbox_pred)
return (cls_logits_list, bboxes_reg_list)
def forward_eval(self, fpn_feats, export_post_process=True):
if self.eval_size:
anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
else:
anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
cls_logits_list, bboxes_reg_list = [], []
for i, fpn_feat in enumerate(fpn_feats): for i, fpn_feat in enumerate(fpn_feats):
conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i) conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i)
if self.conv_feat.share_cls_reg: if self.conv_feat.share_cls_reg:
...@@ -307,50 +347,68 @@ class PicoHead(OTAVFLHead): ...@@ -307,50 +347,68 @@ class PicoHead(OTAVFLHead):
if not export_post_process: if not export_post_process:
# Now only supports batch size = 1 in deploy # Now only supports batch size = 1 in deploy
# TODO(ygh): support batch size > 1 # TODO(ygh): support batch size > 1
cls_score = F.sigmoid(cls_score).reshape( cls_score_out = F.sigmoid(cls_score).reshape(
[1, self.cls_out_channels, -1]).transpose([0, 2, 1]) [1, self.cls_out_channels, -1]).transpose([0, 2, 1])
bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4, bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4,
-1]).transpose([0, 2, 1]) -1]).transpose([0, 2, 1])
elif not self.training: else:
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) b, _, h, w = fpn_feat.shape
l = h * w
cls_score_out = F.sigmoid(
cls_score.reshape([b, self.cls_out_channels, l]))
bbox_pred = bbox_pred.transpose([0, 2, 3, 1]) bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
stride = self.fpn_stride[i] bbox_pred = self.distribution_project(bbox_pred)
b, cell_h, cell_w, _ = paddle.shape(cls_score) bbox_pred = bbox_pred.reshape([b, l, 4])
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
bbox_pred = self.distribution_project(bbox_pred) * stride
bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
# NOTE: If keep_ratio=False and image shape value that
# multiples of 32, distance2bbox not set max_shapes parameter
# to speed up model prediction. If need to set max_shapes,
# please use inputs['im_shape'].
bbox_pred = batch_distance2bbox(
center_points, bbox_pred, max_shapes=None)
cls_logits_list.append(cls_score) cls_logits_list.append(cls_score_out)
bboxes_reg_list.append(bbox_pred) bboxes_reg_list.append(bbox_pred)
if export_post_process:
cls_logits_list = paddle.concat(cls_logits_list, axis=-1)
bboxes_reg_list = paddle.concat(bboxes_reg_list, axis=1)
bboxes_reg_list = batch_distance2bbox(anchor_points,
bboxes_reg_list)
bboxes_reg_list *= stride_tensor
return (cls_logits_list, bboxes_reg_list) return (cls_logits_list, bboxes_reg_list)
def post_process(self, def _generate_anchors(self, feats=None):
gfl_head_outs, # just use in eval time
im_shape, anchor_points = []
scale_factor, stride_tensor = []
export_nms=True): for i, stride in enumerate(self.fpn_stride):
cls_scores, bboxes_reg = gfl_head_outs if feats is not None:
bboxes = paddle.concat(bboxes_reg, axis=1) _, _, h, w = feats[i].shape
mlvl_scores = paddle.concat(cls_scores, axis=1) else:
mlvl_scores = mlvl_scores.transpose([0, 2, 1]) h = math.ceil(self.eval_size[0] / stride)
w = math.ceil(self.eval_size[1] / stride)
shift_x = paddle.arange(end=w) + self.cell_offset
shift_y = paddle.arange(end=h) + self.cell_offset
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
anchor_point = paddle.cast(
paddle.stack(
[shift_x, shift_y], axis=-1), dtype='float32')
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(
paddle.full(
[h * w, 1], stride, dtype='float32'))
anchor_points = paddle.concat(anchor_points)
stride_tensor = paddle.concat(stride_tensor)
return anchor_points, stride_tensor
def post_process(self, head_outs, scale_factor, export_nms=True):
pred_scores, pred_bboxes = head_outs
if not export_nms: if not export_nms:
return bboxes, mlvl_scores return pred_bboxes, pred_scores
else: else:
# rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale] # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
im_scale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1) scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
bboxes /= im_scale scale_factor = paddle.concat(
bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores) [scale_x, scale_y, scale_x, scale_y],
axis=-1).reshape([-1, 1, 4])
# scale bbox to origin image size.
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num return bbox_pred, bbox_num
...@@ -374,29 +432,29 @@ class PicoHeadV2(GFLHead): ...@@ -374,29 +432,29 @@ class PicoHeadV2(GFLHead):
'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
'static_assigner', 'assigner', 'nms' 'static_assigner', 'assigner', 'nms'
] ]
__shared__ = ['num_classes'] __shared__ = ['num_classes', 'eval_size']
def __init__( def __init__(self,
self, conv_feat='PicoFeatV2',
conv_feat='PicoFeatV2', dgqp_module=None,
dgqp_module=None, num_classes=80,
num_classes=80, fpn_stride=[8, 16, 32],
fpn_stride=[8, 16, 32], prior_prob=0.01,
prior_prob=0.01, use_align_head=True,
use_align_head=True, loss_class='VariFocalLoss',
loss_class='VariFocalLoss', loss_dfl='DistributionFocalLoss',
loss_dfl='DistributionFocalLoss', loss_bbox='GIoULoss',
loss_bbox='GIoULoss', static_assigner_epoch=60,
static_assigner_epoch=60, static_assigner='ATSSAssigner',
static_assigner='ATSSAssigner', assigner='TaskAlignedAssigner',
assigner='TaskAlignedAssigner', reg_max=16,
reg_max=16, feat_in_chan=96,
feat_in_chan=96, nms=None,
nms=None, nms_pre=1000,
nms_pre=1000, cell_offset=0,
cell_offset=0, act='hard_swish',
act='hard_swish', grid_cell_scale=5.0,
grid_cell_scale=5.0, ): eval_size=None):
super(PicoHeadV2, self).__init__( super(PicoHeadV2, self).__init__(
conv_feat=conv_feat, conv_feat=conv_feat,
dgqp_module=dgqp_module, dgqp_module=dgqp_module,
...@@ -432,6 +490,7 @@ class PicoHeadV2(GFLHead): ...@@ -432,6 +490,7 @@ class PicoHeadV2(GFLHead):
self.grid_cell_scale = grid_cell_scale self.grid_cell_scale = grid_cell_scale
self.use_align_head = use_align_head self.use_align_head = use_align_head
self.cls_out_channels = self.num_classes self.cls_out_channels = self.num_classes
self.eval_size = eval_size
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
# Clear the super class initialization # Clear the super class initialization
...@@ -478,11 +537,22 @@ class PicoHeadV2(GFLHead): ...@@ -478,11 +537,22 @@ class PicoHeadV2(GFLHead):
act=self.act, act=self.act,
use_act_in_out=False)) use_act_in_out=False))
# initialize the anchor points
if self.eval_size:
self.anchor_points, self.stride_tensor = self._generate_anchors()
def forward(self, fpn_feats, export_post_process=True): def forward(self, fpn_feats, export_post_process=True):
assert len(fpn_feats) == len( assert len(fpn_feats) == len(
self.fpn_stride self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride" ), "The size of fpn_feats is not equal to size of fpn_stride"
if self.training:
return self.forward_train(fpn_feats)
else:
return self.forward_eval(
fpn_feats, export_post_process=export_post_process)
def forward_train(self, fpn_feats):
cls_score_list, reg_list, box_list = [], [], [] cls_score_list, reg_list, box_list = [], [], []
for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)): for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)):
b, _, h, w = get_static_shape(fpn_feat) b, _, h, w = get_static_shape(fpn_feat)
...@@ -498,7 +568,48 @@ class PicoHeadV2(GFLHead): ...@@ -498,7 +568,48 @@ class PicoHeadV2(GFLHead):
else: else:
cls_score = F.sigmoid(cls_logit) cls_score = F.sigmoid(cls_logit)
if not export_post_process and not self.training: cls_score_out = cls_score.transpose([0, 2, 3, 1])
bbox_pred = reg_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = paddle.shape(cls_score_out)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
cls_score_out = cls_score_out.reshape(
[b, -1, self.cls_out_channels])
bbox_pred = self.distribution_project(bbox_pred) * stride
bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
bbox_pred = batch_distance2bbox(
center_points, bbox_pred, max_shapes=None)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
box_list.append(bbox_pred / stride)
cls_score_list = paddle.concat(cls_score_list, axis=1)
box_list = paddle.concat(box_list, axis=1)
reg_list = paddle.concat(reg_list, axis=1)
return cls_score_list, reg_list, box_list, fpn_feats
def forward_eval(self, fpn_feats, export_post_process=True):
if self.eval_size:
anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
else:
anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
cls_score_list, box_list = [], []
for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)):
b, _, h, w = fpn_feat.shape
# task decomposition
conv_cls_feat, se_feat = self.conv_feat(fpn_feat, i)
cls_logit = self.head_cls_list[i](se_feat)
reg_pred = self.head_reg_list[i](se_feat)
# cls prediction and alignment
if self.use_align_head:
cls_prob = F.sigmoid(self.cls_align[i](conv_cls_feat))
cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt()
else:
cls_score = F.sigmoid(cls_logit)
if not export_post_process:
# Now only supports batch size = 1 in deploy # Now only supports batch size = 1 in deploy
cls_score_list.append( cls_score_list.append(
cls_score.reshape([1, self.cls_out_channels, -1]).transpose( cls_score.reshape([1, self.cls_out_channels, -1]).transpose(
...@@ -507,34 +618,21 @@ class PicoHeadV2(GFLHead): ...@@ -507,34 +618,21 @@ class PicoHeadV2(GFLHead):
reg_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose( reg_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose(
[0, 2, 1])) [0, 2, 1]))
else: else:
cls_score_out = cls_score.transpose([0, 2, 3, 1]) l = h * w
cls_score_out = cls_score.reshape([b, self.cls_out_channels, l])
bbox_pred = reg_pred.transpose([0, 2, 3, 1]) bbox_pred = reg_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = paddle.shape(cls_score_out) bbox_pred = self.distribution_project(bbox_pred)
y, x = self.get_single_level_center_point( bbox_pred = bbox_pred.reshape([b, l, 4])
[cell_h, cell_w], stride, cell_offset=self.cell_offset) cls_score_list.append(cls_score_out)
center_points = paddle.stack([x, y], axis=-1) box_list.append(bbox_pred)
cls_score_out = cls_score_out.reshape(
[b, -1, self.cls_out_channels]) if export_post_process:
bbox_pred = self.distribution_project(bbox_pred) * stride cls_score_list = paddle.concat(cls_score_list, axis=-1)
bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
bbox_pred = batch_distance2bbox(
center_points, bbox_pred, max_shapes=None)
if not self.training:
cls_score_list.append(cls_score_out)
box_list.append(bbox_pred)
else:
cls_score_list.append(
cls_score.flatten(2).transpose([0, 2, 1]))
reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
box_list.append(bbox_pred / stride)
if not self.training:
return cls_score_list, box_list
else:
cls_score_list = paddle.concat(cls_score_list, axis=1)
box_list = paddle.concat(box_list, axis=1) box_list = paddle.concat(box_list, axis=1)
reg_list = paddle.concat(reg_list, axis=1) box_list = batch_distance2bbox(anchor_points, box_list)
return cls_score_list, reg_list, box_list, fpn_feats box_list *= stride_tensor
return cls_score_list, box_list
def get_loss(self, head_outs, gt_meta): def get_loss(self, head_outs, gt_meta):
pred_scores, pred_regs, pred_bboxes, fpn_feats = head_outs pred_scores, pred_regs, pred_bboxes, fpn_feats = head_outs
...@@ -644,20 +742,41 @@ class PicoHeadV2(GFLHead): ...@@ -644,20 +742,41 @@ class PicoHeadV2(GFLHead):
return loss_states return loss_states
def post_process(self, def _generate_anchors(self, feats=None):
gfl_head_outs, # just use in eval time
im_shape, anchor_points = []
scale_factor, stride_tensor = []
export_nms=True): for i, stride in enumerate(self.fpn_stride):
cls_scores, bboxes_reg = gfl_head_outs if feats is not None:
bboxes = paddle.concat(bboxes_reg, axis=1) _, _, h, w = feats[i].shape
mlvl_scores = paddle.concat(cls_scores, axis=1) else:
mlvl_scores = mlvl_scores.transpose([0, 2, 1]) h = math.ceil(self.eval_size[0] / stride)
w = math.ceil(self.eval_size[1] / stride)
shift_x = paddle.arange(end=w) + self.cell_offset
shift_y = paddle.arange(end=h) + self.cell_offset
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
anchor_point = paddle.cast(
paddle.stack(
[shift_x, shift_y], axis=-1), dtype='float32')
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(
paddle.full(
[h * w, 1], stride, dtype='float32'))
anchor_points = paddle.concat(anchor_points)
stride_tensor = paddle.concat(stride_tensor)
return anchor_points, stride_tensor
def post_process(self, head_outs, scale_factor, export_nms=True):
pred_scores, pred_bboxes = head_outs
if not export_nms: if not export_nms:
return bboxes, mlvl_scores return pred_bboxes, pred_scores
else: else:
# rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale] # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
im_scale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1) scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
bboxes /= im_scale scale_factor = paddle.concat(
bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores) [scale_x, scale_y, scale_x, scale_y],
axis=-1).reshape([-1, 1, 4])
# scale bbox to origin image size.
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num return bbox_pred, bbox_num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册