From 0a49f80ccbfff7c571a9f920a069aa5bbc19624f Mon Sep 17 00:00:00 2001 From: huangjun12 <2399845970@qq.com> Date: Wed, 8 Feb 2023 21:36:33 +0800 Subject: [PATCH] add mainKD for ppyoloe distill (#7708) * add mainkd for ppyoloe distill * refine code style --- .../ppyoloe_plus_distill_m_distill_s.yml | 2 ++ ppdet/modeling/heads/ppyoloe_head.py | 1 + ppdet/slim/distill_loss.py | 31 ++++++++++++++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/configs/slim/distill/ppyoloe_plus_distill_m_distill_s.yml b/configs/slim/distill/ppyoloe_plus_distill_m_distill_s.yml index 869e1bc2d..8ee944e9b 100644 --- a/configs/slim/distill/ppyoloe_plus_distill_m_distill_s.yml +++ b/configs/slim/distill/ppyoloe_plus_distill_m_distill_s.yml @@ -43,6 +43,8 @@ DistillPPYOLOELoss: # M -> S loss_weight: {'logits': 4.0, 'feat': 1.0} logits_distill: True logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5} + logits_ld_distill: True + logits_ld_params: {'weight': 20000, 'T': 10} feat_distill: True feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] feat_distill_place: 'neck_feats' diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 1eb735194..201aa4812 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -323,6 +323,7 @@ class PPYOLOEHead(nn.Layer): assigned_bboxes, assigned_scores, assigned_scores_sum): # select positive samples mask mask_positive = (assigned_labels != self.num_classes) + self.distill_pairs['mask_positive_select'] = mask_positive num_pos = mask_positive.sum() # pos/neg loss if num_pos > 0: diff --git a/ppdet/slim/distill_loss.py b/ppdet/slim/distill_loss.py index a0539277f..6e94fd841 100644 --- a/ppdet/slim/distill_loss.py +++ b/ppdet/slim/distill_loss.py @@ -212,6 +212,9 @@ class DistillPPYOLOELoss(nn.Layer): logits_loss_weight={'class': 1.0, 'iou': 2.5, 'dfl': 0.5}, + logits_ld_distill=False, + logits_ld_params={'weight': 20000, + 'T': 10}, feat_distill=True, feat_distiller='fgd', feat_distill_place='neck_feats', @@ -222,6 +225,7 @@ class DistillPPYOLOELoss(nn.Layer): self.loss_weight_logits = loss_weight['logits'] self.loss_weight_feat = loss_weight['feat'] self.logits_distill = logits_distill + self.logits_ld_distill = logits_ld_distill self.feat_distill = feat_distill if logits_distill and self.loss_weight_logits > 0: @@ -230,6 +234,10 @@ class DistillPPYOLOELoss(nn.Layer): self.qfl_loss_weight = logits_loss_weight['class'] self.loss_bbox = GIoULoss() + if logits_ld_distill: + self.loss_kd = KnowledgeDistillationKLDivLoss( + loss_weight=logits_ld_params['weight'], T=logits_ld_params['T']) + if feat_distill and self.loss_weight_feat > 0: assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd', 'mimic'] assert feat_distill_place in ['backbone_feats', 'neck_feats'] @@ -334,6 +342,20 @@ class DistillPPYOLOELoss(nn.Layer): loss_dfl = loss_dfl.mean(-1) return loss_dfl / 4.0 # 4 direction + def main_kd(self, mask_positive, pred_scores, soft_cls, num_classes): + num_pos = mask_positive.sum() + if num_pos > 0: + cls_mask = mask_positive.unsqueeze(-1).tile([1, 1, num_classes]) + pred_scores_pos = paddle.masked_select( + pred_scores, cls_mask).reshape([-1, num_classes]) + soft_cls_pos = paddle.masked_select( + soft_cls, cls_mask).reshape([-1, num_classes]) + loss_kd = self.loss_kd( + pred_scores_pos, soft_cls_pos, avg_factor=num_pos) + else: + loss_kd = paddle.zeros([1]) + return loss_kd + def forward(self, teacher_model, student_model): teacher_distill_pairs = teacher_model.yolo_head.distill_pairs student_distill_pairs = student_model.yolo_head.distill_pairs @@ -373,8 +395,15 @@ class DistillPPYOLOELoss(nn.Layer): distill_cls_loss = paddle.add_n(distill_cls_loss) distill_bbox_loss = paddle.add_n(distill_bbox_loss) distill_dfl_loss = paddle.add_n(distill_dfl_loss) - logits_loss = distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight + + if self.logits_ld_distill: + loss_kd = self.main_kd( + student_distill_pairs['mask_positive_select'], + student_distill_pairs['pred_cls_scores'], + teacher_distill_pairs['pred_cls_scores'], + student_model.yolo_head.num_classes, ) + logits_loss += loss_kd else: logits_loss = paddle.zeros([1]) -- GitLab