未验证 提交 b352ef88 编写于 作者: L LokeZhou 提交者: GitHub

add picodet ppyoloe_crn_s_300e_coco static training (#7859)

上级 bb1ba033
...@@ -142,6 +142,46 @@ TO_STATIC_SPEC = { ...@@ -142,6 +142,46 @@ TO_STATIC_SPEC = {
'centerness4': paddle.static.InputSpec( 'centerness4': paddle.static.InputSpec(
name='centerness4', shape=[-1, 10, 10, 1], dtype='float32'), name='centerness4', shape=[-1, 10, 10, 1], dtype='float32'),
}], }],
'picodet_s_320_coco_lcnet': [{
'im_id': paddle.static.InputSpec(
name='im_id', shape=[-1, 1], dtype='float32'),
'is_crowd': paddle.static.InputSpec(
name='is_crowd', shape=[-1, -1, 1], dtype='float32'),
'gt_class': paddle.static.InputSpec(
name='gt_class', shape=[-1, -1, 1], dtype='int32'),
'gt_bbox': paddle.static.InputSpec(
name='gt_bbox', shape=[-1, -1, 4], dtype='float32'),
'curr_iter': paddle.static.InputSpec(
name='curr_iter', shape=[-1], dtype='float32'),
'image': paddle.static.InputSpec(
name='image', shape=[-1, 3, -1, -1], dtype='float32'),
'im_shape': paddle.static.InputSpec(
name='im_shape', shape=[-1, 2], dtype='float32'),
'scale_factor': paddle.static.InputSpec(
name='scale_factor', shape=[-1, 2], dtype='float32'),
'pad_gt_mask': paddle.static.InputSpec(
name='pad_gt_mask', shape=[-1, -1, 1], dtype='float32'),
}],
'ppyoloe_crn_s_300e_coco': [{
'im_id': paddle.static.InputSpec(
name='im_id', shape=[-1, 1], dtype='float32'),
'is_crowd': paddle.static.InputSpec(
name='is_crowd', shape=[-1, -1, 1], dtype='float32'),
'gt_class': paddle.static.InputSpec(
name='gt_class', shape=[-1, -1, 1], dtype='int32'),
'gt_bbox': paddle.static.InputSpec(
name='gt_bbox', shape=[-1, -1, 4], dtype='float32'),
'curr_iter': paddle.static.InputSpec(
name='curr_iter', shape=[-1], dtype='float32'),
'image': paddle.static.InputSpec(
name='image', shape=[-1, 3, -1, -1], dtype='float32'),
'im_shape': paddle.static.InputSpec(
name='im_shape', shape=[-1, 2], dtype='float32'),
'scale_factor': paddle.static.InputSpec(
name='scale_factor', shape=[-1, 2], dtype='float32'),
'pad_gt_mask': paddle.static.InputSpec(
name='pad_gt_mask', shape=[-1, -1, 1], dtype='float32'),
}],
} }
......
...@@ -169,8 +169,9 @@ class ATSSAssigner(nn.Layer): ...@@ -169,8 +169,9 @@ class ATSSAssigner(nn.Layer):
# the one with the highest iou will be selected. # the one with the highest iou will be selected.
mask_positive_sum = mask_positive.sum(axis=-2) mask_positive_sum = mask_positive.sum(axis=-2)
if mask_positive_sum.max() > 1: if mask_positive_sum.max() > 1:
mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile( mask_multiple_gts = (
[1, num_max_boxes, 1]) mask_positive_sum.unsqueeze(1) > 1).astype('int32').tile(
[1, num_max_boxes, 1]).astype('bool')
if self.sm_use: if self.sm_use:
is_max_iou = compute_max_iou_anchor(ious * mask_positive) is_max_iou = compute_max_iou_anchor(ious * mask_positive)
else: else:
...@@ -221,4 +222,4 @@ class ATSSAssigner(nn.Layer): ...@@ -221,4 +222,4 @@ class ATSSAssigner(nn.Layer):
paddle.zeros_like(gather_scores)) paddle.zeros_like(gather_scores))
assigned_scores *= gather_scores.unsqueeze(-1) assigned_scores *= gather_scores.unsqueeze(-1)
return assigned_labels, assigned_bboxes, assigned_scores, mask_positive return assigned_labels, assigned_bboxes, assigned_scores
...@@ -190,4 +190,4 @@ class TaskAlignedAssigner(nn.Layer): ...@@ -190,4 +190,4 @@ class TaskAlignedAssigner(nn.Layer):
alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1) alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
assigned_scores = assigned_scores * alignment_metrics assigned_scores = assigned_scores * alignment_metrics
return assigned_labels, assigned_bboxes, assigned_scores, mask_positive return assigned_labels, assigned_bboxes, assigned_scores
...@@ -178,4 +178,4 @@ class TaskAlignedAssigner_CR(nn.Layer): ...@@ -178,4 +178,4 @@ class TaskAlignedAssigner_CR(nn.Layer):
alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1) alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
assigned_scores = assigned_scores * alignment_metrics assigned_scores = assigned_scores * alignment_metrics
return assigned_labels, assigned_bboxes, assigned_scores, mask_positive return assigned_labels, assigned_bboxes, assigned_scores
...@@ -651,7 +651,7 @@ class PicoHeadV2(GFLHead): ...@@ -651,7 +651,7 @@ class PicoHeadV2(GFLHead):
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner( assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
gt_labels, gt_labels,
...@@ -662,7 +662,7 @@ class PicoHeadV2(GFLHead): ...@@ -662,7 +662,7 @@ class PicoHeadV2(GFLHead):
pred_bboxes=pred_bboxes.detach() * stride_tensor_list) pred_bboxes=pred_bboxes.detach() * stride_tensor_list)
else: else:
assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner( assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor_list, pred_bboxes.detach() * stride_tensor_list,
centers, centers,
......
...@@ -121,7 +121,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -121,7 +121,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
pad_gt_mask = gt_meta['pad_gt_mask'] pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores, _ = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.static_assigner( self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
...@@ -133,7 +133,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -133,7 +133,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
alpha_l = 0.25 alpha_l = 0.25
else: else:
if self.sm_use: if self.sm_use:
assigned_labels, assigned_bboxes, assigned_scores, _ = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
...@@ -144,7 +144,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -144,7 +144,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
pad_gt_mask, pad_gt_mask,
bg_index=self.num_classes) bg_index=self.num_classes)
else: else:
assigned_labels, assigned_bboxes, assigned_scores, _ = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
......
...@@ -337,7 +337,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -337,7 +337,8 @@ class PPYOLOEHead(nn.Layer):
# pos/neg loss # pos/neg loss
if num_pos > 0: if num_pos > 0:
# l1 + iou # l1 + iou
bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4]) bbox_mask = mask_positive.astype('int32').unsqueeze(-1).tile(
[1, 1, 4]).astype('bool')
pred_bboxes_pos = paddle.masked_select(pred_bboxes, pred_bboxes_pos = paddle.masked_select(pred_bboxes,
bbox_mask).reshape([-1, 4]) bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = paddle.masked_select( assigned_bboxes_pos = paddle.masked_select(
...@@ -351,8 +352,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -351,8 +352,8 @@ class PPYOLOEHead(nn.Layer):
assigned_bboxes_pos) * bbox_weight assigned_bboxes_pos) * bbox_weight
loss_iou = loss_iou.sum() / assigned_scores_sum loss_iou = loss_iou.sum() / assigned_scores_sum
dist_mask = mask_positive.unsqueeze(-1).tile( dist_mask = mask_positive.unsqueeze(-1).astype('int32').tile(
[1, 1, self.reg_channels * 4]) [1, 1, self.reg_channels * 4]).astype('bool')
pred_dist_pos = paddle.masked_select( pred_dist_pos = paddle.masked_select(
pred_dist, dist_mask).reshape([-1, 4, self.reg_channels]) pred_dist, dist_mask).reshape([-1, 4, self.reg_channels])
assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes) assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
...@@ -387,7 +388,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -387,7 +388,7 @@ class PPYOLOEHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask'] pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.static_assigner( self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
...@@ -400,7 +401,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -400,7 +401,7 @@ class PPYOLOEHead(nn.Layer):
else: else:
if self.sm_use: if self.sm_use:
# only used in smalldet of PPYOLOE-SOD model # only used in smalldet of PPYOLOE-SOD model
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
...@@ -413,7 +414,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -413,7 +414,7 @@ class PPYOLOEHead(nn.Layer):
else: else:
if aux_pred is None: if aux_pred is None:
if not hasattr(self, "assigned_labels"): if not hasattr(self, "assigned_labels"):
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
...@@ -427,15 +428,15 @@ class PPYOLOEHead(nn.Layer): ...@@ -427,15 +428,15 @@ class PPYOLOEHead(nn.Layer):
self.assigned_labels = assigned_labels self.assigned_labels = assigned_labels
self.assigned_bboxes = assigned_bboxes self.assigned_bboxes = assigned_bboxes
self.assigned_scores = assigned_scores self.assigned_scores = assigned_scores
self.mask_positive = mask_positive
else: else:
# only used in distill # only used in distill
assigned_labels = self.assigned_labels assigned_labels = self.assigned_labels
assigned_bboxes = self.assigned_bboxes assigned_bboxes = self.assigned_bboxes
assigned_scores = self.assigned_scores assigned_scores = self.assigned_scores
mask_positive = self.mask_positive
else: else:
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \ assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner( self.assigner(
pred_scores_aux.detach(), pred_scores_aux.detach(),
pred_bboxes_aux.detach() * stride_tensor, pred_bboxes_aux.detach() * stride_tensor,
...@@ -451,14 +452,12 @@ class PPYOLOEHead(nn.Layer): ...@@ -451,14 +452,12 @@ class PPYOLOEHead(nn.Layer):
assign_out_dict = self.get_loss_from_assign( assign_out_dict = self.get_loss_from_assign(
pred_scores, pred_distri, pred_bboxes, anchor_points_s, pred_scores, pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, mask_positive, assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
alpha_l)
if aux_pred is not None: if aux_pred is not None:
assign_out_dict_aux = self.get_loss_from_assign( assign_out_dict_aux = self.get_loss_from_assign(
aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s, aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
mask_positive, alpha_l)
loss = {} loss = {}
for key in assign_out_dict.keys(): for key in assign_out_dict.keys():
loss[key] = assign_out_dict[key] + assign_out_dict_aux[key] loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
...@@ -469,7 +468,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -469,7 +468,7 @@ class PPYOLOEHead(nn.Layer):
def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes, def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
anchor_points_s, assigned_labels, assigned_bboxes, anchor_points_s, assigned_labels, assigned_bboxes,
assigned_scores, mask_positive, alpha_l): assigned_scores, alpha_l):
# cls loss # cls loss
if self.use_varifocal_loss: if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels, one_hot_label = F.one_hot(assigned_labels,
...@@ -490,7 +489,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -490,7 +489,7 @@ class PPYOLOEHead(nn.Layer):
self.distill_pairs['pred_cls_scores'] = pred_scores self.distill_pairs['pred_cls_scores'] = pred_scores
self.distill_pairs['pos_num'] = assigned_scores_sum self.distill_pairs['pos_num'] = assigned_scores_sum
self.distill_pairs['assigned_scores'] = assigned_scores self.distill_pairs['assigned_scores'] = assigned_scores
self.distill_pairs['mask_positive'] = mask_positive
one_hot_label = F.one_hot(assigned_labels, one_hot_label = F.one_hot(assigned_labels,
self.num_classes + 1)[..., :-1] self.num_classes + 1)[..., :-1]
self.distill_pairs['target_labels'] = one_hot_label self.distill_pairs['target_labels'] = one_hot_label
......
...@@ -57,4 +57,6 @@ repeat:25 ...@@ -57,4 +57,6 @@ repeat:25
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:null flags:null
===========================infer_benchmark_params=========================== ===========================infer_benchmark_params===========================
numpy_infer_input:3x320x320_2.npy numpy_infer_input:3x320x320_2.npy
\ No newline at end of file ===========================to_static_train_benchmark_params===========================
to_static_train:--to_static
\ No newline at end of file
...@@ -57,4 +57,4 @@ repeat:12 ...@@ -57,4 +57,4 @@ repeat:12
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:null flags:null
===========================infer_benchmark_params=========================== ===========================infer_benchmark_params===========================
numpy_infer_input:3x640x640_2.npy numpy_infer_input:3x640x640_2.npy
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册