Unverified commit 0a49f80c, authored by huangjun12, committed by GitHub

add mainKD for ppyoloe distill (#7708)

* add mainkd for ppyoloe distill

* refine code style
Parent 3b15b2e7
@@ -43,6 +43,8 @@ DistillPPYOLOELoss: # M -> S
   loss_weight: {'logits': 4.0, 'feat': 1.0}
   logits_distill: True
   logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
+  logits_ld_distill: True
+  logits_ld_params: {'weight': 20000, 'T': 10}
   feat_distill: True
   feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
   feat_distill_place: 'neck_feats'
......
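Note on the two new config keys: logits_ld_distill switches on an extra logit-distillation term on the classification scores, and logits_ld_params is passed straight to KnowledgeDistillationKLDivLoss in __init__ below ('weight' becomes the loss weight, 'T' the softmax temperature). A minimal, self-contained sketch of a temperature-scaled KL-divergence distillation loss with that interface (an illustration of the idea only, not the ppdet implementation):

import paddle
import paddle.nn.functional as F

def kd_kldiv_loss(pred_logits, soft_logits, T=10, loss_weight=20000, avg_factor=None):
    # Teacher probabilities act as soft labels; the student contributes log-probabilities.
    target = F.softmax(soft_logits / T, axis=-1)
    target.stop_gradient = True  # never backprop into the teacher
    log_pred = F.log_softmax(pred_logits / T, axis=-1)
    # Per-sample KL(target || pred), rescaled by T^2 so gradient magnitude
    # stays comparable across temperatures.
    kd = F.kl_div(log_pred, target, reduction='none').sum(-1) * (T * T)
    kd = kd.mean() if avg_factor is None else kd.sum() / avg_factor
    return loss_weight * kd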
@@ -323,6 +323,7 @@ class PPYOLOEHead(nn.Layer):
                    assigned_bboxes, assigned_scores, assigned_scores_sum):
         # select positive samples mask
         mask_positive = (assigned_labels != self.num_classes)
+        self.distill_pairs['mask_positive_select'] = mask_positive
         num_pos = mask_positive.sum()
         # pos/neg loss
         if num_pos > 0:
......
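The only head change is caching the positive-anchor mask in distill_pairs, so the distillation loss can later select exactly the positive predictions the detection loss used. A small shape-level sketch of how such a mask is consumed (dummy tensors; the sizes are assumptions for illustration):

import paddle

batch, num_anchors, num_classes = 2, 8, 80  # assumed toy sizes
pred_scores = paddle.rand([batch, num_anchors, num_classes])
mask_positive = paddle.rand([batch, num_anchors]) > 0.7  # bool mask of positive anchors

# Broadcast the anchor mask over the class dimension, then gather positives.
cls_mask = mask_positive.unsqueeze(-1).tile([1, 1, num_classes])
pos_scores = paddle.masked_select(pred_scores, cls_mask).reshape([-1, num_classes])
print(pos_scores.shape)  # [num_positive_anchors, num_classes]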
@@ -212,6 +212,9 @@ class DistillPPYOLOELoss(nn.Layer):
                  logits_loss_weight={'class': 1.0,
                                      'iou': 2.5,
                                      'dfl': 0.5},
+                 logits_ld_distill=False,
+                 logits_ld_params={'weight': 20000,
+                                   'T': 10},
                  feat_distill=True,
                  feat_distiller='fgd',
                  feat_distill_place='neck_feats',
@@ -222,6 +225,7 @@ class DistillPPYOLOELoss(nn.Layer):
         self.loss_weight_logits = loss_weight['logits']
         self.loss_weight_feat = loss_weight['feat']
         self.logits_distill = logits_distill
+        self.logits_ld_distill = logits_ld_distill
         self.feat_distill = feat_distill
         if logits_distill and self.loss_weight_logits > 0:
@@ -230,6 +234,10 @@ class DistillPPYOLOELoss(nn.Layer):
             self.qfl_loss_weight = logits_loss_weight['class']
             self.loss_bbox = GIoULoss()
+        if logits_ld_distill:
+            self.loss_kd = KnowledgeDistillationKLDivLoss(
+                loss_weight=logits_ld_params['weight'], T=logits_ld_params['T'])
         if feat_distill and self.loss_weight_feat > 0:
             assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
             assert feat_distill_place in ['backbone_feats', 'neck_feats']
@@ -334,6 +342,20 @@ class DistillPPYOLOELoss(nn.Layer):
         loss_dfl = loss_dfl.mean(-1)
         return loss_dfl / 4.0  # 4 direction

+    def main_kd(self, mask_positive, pred_scores, soft_cls, num_classes):
+        num_pos = mask_positive.sum()
+        if num_pos > 0:
+            cls_mask = mask_positive.unsqueeze(-1).tile([1, 1, num_classes])
+            pred_scores_pos = paddle.masked_select(
+                pred_scores, cls_mask).reshape([-1, num_classes])
+            soft_cls_pos = paddle.masked_select(
+                soft_cls, cls_mask).reshape([-1, num_classes])
+            loss_kd = self.loss_kd(
+                pred_scores_pos, soft_cls_pos, avg_factor=num_pos)
+        else:
+            loss_kd = paddle.zeros([1])
+        return loss_kd
+
     def forward(self, teacher_model, student_model):
         teacher_distill_pairs = teacher_model.yolo_head.distill_pairs
         student_distill_pairs = student_model.yolo_head.distill_pairs
@@ -373,8 +395,15 @@ class DistillPPYOLOELoss(nn.Layer):
             distill_cls_loss = paddle.add_n(distill_cls_loss)
             distill_bbox_loss = paddle.add_n(distill_bbox_loss)
             distill_dfl_loss = paddle.add_n(distill_dfl_loss)
             logits_loss = distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight
+            if self.logits_ld_distill:
+                loss_kd = self.main_kd(
+                    student_distill_pairs['mask_positive_select'],
+                    student_distill_pairs['pred_cls_scores'],
+                    teacher_distill_pairs['pred_cls_scores'],
+                    student_model.yolo_head.num_classes, )
+                logits_loss += loss_kd
         else:
             logits_loss = paddle.zeros([1])
......
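Usage note: in forward, the new mainKD term is added inside the existing logits-distillation branch, so it only contributes when logits distillation is active. main_kd reuses mask_positive_select cached by the student head, treats the teacher's class scores as soft labels, averages by the number of positive anchors (avg_factor=num_pos), and returns a zero tensor when a batch has no positives, so the added loss never breaks batches without matched ground truth.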