Unverified commit b2e0acf1, authored by Feng Ni, committed by GitHub

Add PPYOLOE distillation (#7671)

* add ppyoloe distill with soft loss and feature loss

* refine distill codes

* fix configs and docs

* clean codes

* merge cam, fix export
Parent 2b5fd266
# PPYOLOE+ Distillation

PaddleDetection provides a model distillation scheme for PPYOLOE+ that combines logits distillation and feature distillation.

## Model Zoo

## Getting Started

### Training
```shell
# single GPU
python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
# multiple GPUs
python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
```
- `-c`: the model config file, which is also the student config.
- `--slim_config`: the compression strategy config file, which is also the teacher config.
### Evaluation
```shell
python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams
```
- `-c`: the model config file, which is also the student config.
- `--slim_config`: the compression strategy config file, which is also the teacher config.
- `-o weights`: the path of the model weights trained with the compression algorithm.
### Inference
```shell
python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg
```
- `-c`: the model config file.
- `--slim_config`: the compression strategy config file.
- `-o weights`: the path of the model weights trained with the compression algorithm.
- `--infer_img`: the path of the test image.
_BASE_: [
  '../ppyoloe_plus_crn_l_80e_coco.yml',
]
for_distill: True

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_l_80e_coco/model_final

pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
_BASE_: [
  '../ppyoloe_plus_crn_m_80e_coco.yml',
]
for_distill: True

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_m_80e_coco/model_final

pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_m_obj365_pretrained.pdparams
depth_mult: 0.67
width_mult: 0.75
_BASE_: [
  '../ppyoloe_plus_crn_s_80e_coco.yml',
]
for_distill: True

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_s_80e_coco/model_final

pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
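
The three `*_distill.yml` files above are plain student configs; pairing them with a teacher happens through `ppdet.slim.build_slim_model`, whose new `PPYOLOEDistill` branch appears later in this commit. A minimal sketch of that pairing, with the optimizer, dataloader, and training loop omitted:

```python
# Minimal sketch: how the -c (student) and --slim_config (teacher/slim) files
# are combined. build_slim_model is from ppdet/slim/__init__.py in this commit;
# everything around it (optimizer, dataloader, train loop) is omitted here.
from ppdet.core.workspace import load_config
from ppdet.slim import build_slim_model

cfg = load_config(
    'configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml')
cfg = build_slim_model(
    cfg, 'configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml', mode='train')
model = cfg['model']  # a PPYOLOEDistillModel wrapping teacher and student
```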
@@ -38,6 +38,42 @@ CWD stands for [Channel-wise Knowledge Distillation for Dense Prediction*](https://
|gfl_r50_fpn_1x| student | 41.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) |
|gfl_r50_fpn_2x + CWD| student | 44.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) |

## PPYOLOE+ Model Distillation

## Getting Started

### Training
```shell
# single GPU
python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml

# multiple GPUs
python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
```

- `-c`: the model config file, which is also the student config.
- `--slim_config`: the compression strategy config file, which is also the teacher config.

### Evaluation
```shell
python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams
```

- `-c`: the model config file, which is also the student config.
- `--slim_config`: the compression strategy config file, which is also the teacher config.
- `-o weights`: the path of the model weights trained with the compression algorithm.

### Inference
```shell
python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg
```

- `-c`: the model config file.
- `--slim_config`: the compression strategy config file.
- `-o weights`: the path of the model weights trained with the compression algorithm.
- `--infer_img`: the path of the test image.

## Citations
```
...
@@ -6,10 +6,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_
slim: Distill
slim_method: CWD
distill_loss: CWDFeatureLoss
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']

CWDFeatureLoss:
  student_channels: 80
  teacher_channels: 80
  tau: 1.0
...
# teacher and slim config
_BASE_: [
  '../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
depth_mult: 1.0
width_mult: 1.0

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~
  for_distill: True

pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams
find_unused_parameters: True
for_distill: True

slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss

DistillPPYOLOELoss: # L -> M
  loss_weight: {'logits': 4.0, 'feat': 1.0}
  logits_distill: True
  logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
  feat_distill: True
  feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
  feat_distill_place: 'neck_feats'
  teacher_width_mult: 1.0 # L
  student_width_mult: 0.75 # M
  feat_out_channels: [768, 384, 192] # actual channels are these values multiplied by width_mult
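
The `feat_out_channels` comment is easy to misread: the values are the neck output channels at `width_mult=1.0`, and `DistillPPYOLOELoss` (shown later in this commit) scales them per model. A small sketch of that arithmetic for the L -> M pair above:

```python
# Channel bookkeeping for the L -> M pair: base neck channels are scaled by
# each model's width_mult before building the per-level feature-loss modules.
feat_out_channels = [768, 384, 192]
teacher_width_mult = 1.0   # L
student_width_mult = 0.75  # M

t_channel_list = [int(c * teacher_width_mult) for c in feat_out_channels]
s_channel_list = [int(c * student_width_mult) for c in feat_out_channels]
print(t_channel_list)  # [768, 384, 192]
print(s_channel_list)  # [576, 288, 144]
```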
# teacher and slim config
_BASE_: [
  '../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
depth_mult: 0.67
width_mult: 0.75

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~
  for_distill: True

pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_m_80e_coco.pdparams
find_unused_parameters: True
for_distill: True

slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss

DistillPPYOLOELoss: # M -> S
  loss_weight: {'logits': 4.0, 'feat': 1.0}
  logits_distill: True
  logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
  feat_distill: True
  feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
  feat_distill_place: 'neck_feats'
  teacher_width_mult: 0.75 # M
  student_width_mult: 0.5 # S
  feat_out_channels: [768, 384, 192] # actual channels are these values multiplied by width_mult
# teacher and slim config
_BASE_: [
  '../../ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml',
]
depth_mult: 1.33
width_mult: 1.25

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~
  for_distill: True

pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams
find_unused_parameters: True
for_distill: True

slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss

DistillPPYOLOELoss: # X -> L
  loss_weight: {'logits': 4.0, 'feat': 1.0}
  logits_distill: True
  logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
  feat_distill: True
  feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
  feat_distill_place: 'neck_feats'
  teacher_width_mult: 1.25 # X
  student_width_mult: 1.0 # L
  feat_out_channels: [768, 384, 192] # actual channels are these values multiplied by width_mult
@@ -7,12 +7,11 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_c
slim: Distill
slim_method: CWD
distill_loss: CWDFeatureLoss
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']

CWDFeatureLoss:
  student_channels: 80
  teacher_channels: 80
  tau: 1.0
  weight: 5.0
@@ -22,13 +22,14 @@ from ppdet.core.workspace import register, create
from .meta_arch import BaseArch

__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially when using distillation or an aux head
# PP-YOLOE and PP-YOLOE+ can also use the same architecture as YOLOv3 in yolo.py when not using distillation or an aux head


@register
class PPYOLOE(BaseArch):
    __category__ = 'architecture'
    __shared__ = ['for_distill']
    __inject__ = ['post_process']

    def __init__(self,
@@ -36,6 +37,8 @@ class PPYOLOE(BaseArch):
                 neck='CustomCSPPAN',
                 yolo_head='PPYOLOEHead',
                 post_process='BBoxPostProcess',
                 for_distill=False,
                 feat_distill_place='neck_feats',
                 for_mot=False):
        """
        PPYOLOE network, see https://arxiv.org/abs/2203.16250
@@ -54,6 +57,10 @@ class PPYOLOE(BaseArch):
        self.yolo_head = yolo_head
        self.post_process = post_process
        self.for_mot = for_mot
        self.for_distill = for_distill
        self.feat_distill_place = feat_distill_place
        if for_distill:
            assert feat_distill_place in ['backbone_feats', 'neck_feats']

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
@@ -80,17 +87,31 @@ class PPYOLOE(BaseArch):
        if self.training:
            yolo_losses = self.yolo_head(neck_feats, self.inputs)

            if self.for_distill:
                if self.feat_distill_place == 'backbone_feats':
                    self.yolo_head.distill_pairs['backbone_feats'] = body_feats
                elif self.feat_distill_place == 'neck_feats':
                    self.yolo_head.distill_pairs['neck_feats'] = neck_feats
                else:
                    raise ValueError
            return yolo_losses
        else:
            cam_data = {}  # record bbox scores and index before nms
            yolo_head_outs = self.yolo_head(neck_feats)
            cam_data['scores'] = yolo_head_outs[0]

            if self.post_process is not None:
                bbox, bbox_num, before_nms_indexes = self.post_process(
                    yolo_head_outs, self.yolo_head.mask_anchors,
                    self.inputs['im_shape'], self.inputs['scale_factor'])
                cam_data['before_nms_indexes'] = before_nms_indexes
            else:
                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                    yolo_head_outs, self.inputs['scale_factor'])
                # data for cam
                cam_data['before_nms_indexes'] = before_nms_indexes
            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}

            return output
@@ -180,15 +201,21 @@ class PPYOLOEWithAuxHead(BaseArch):
                aux_pred=[aux_cls_scores, aux_bbox_preds])
            return loss
        else:
            cam_data = {}  # record bbox scores and index before nms
            yolo_head_outs = self.yolo_head(neck_feats)
            cam_data['scores'] = yolo_head_outs[0]

            if self.post_process is not None:
                bbox, bbox_num, before_nms_indexes = self.post_process(
                    yolo_head_outs, self.yolo_head.mask_anchors,
                    self.inputs['im_shape'], self.inputs['scale_factor'])
                cam_data['before_nms_indexes'] = before_nms_indexes
            else:
                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                    yolo_head_outs, self.inputs['scale_factor'])
                # data for cam
                cam_data['before_nms_indexes'] = before_nms_indexes
            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}

            return output
...
@@ -22,7 +22,7 @@ from ..post_process import JDEBBoxPostProcess
__all__ = ['YOLOv3']

# YOLOv3, PP-YOLO, PP-YOLOv2, PP-YOLOE and PP-YOLOE+ use the same architecture as YOLOv3
# PP-YOLOE and PP-YOLOE+ are recommended to use the PPYOLOE architecture in ppyoloe.py, especially when using distillation or an aux head
@register
...
@@ -221,4 +221,4 @@ class ATSSAssigner(nn.Layer):
                                          paddle.zeros_like(gather_scores))
            assigned_scores *= gather_scores.unsqueeze(-1)

        return assigned_labels, assigned_bboxes, assigned_scores, mask_positive
@@ -190,4 +190,4 @@ class TaskAlignedAssigner(nn.Layer):
        alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
        assigned_scores = assigned_scores * alignment_metrics

        return assigned_labels, assigned_bboxes, assigned_scores, mask_positive
@@ -651,7 +651,7 @@ class PicoHeadV2(GFLHead):
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner(
                anchors,
                num_anchors_list,
                gt_labels,
@@ -662,7 +662,7 @@ class PicoHeadV2(GFLHead):
                pred_bboxes=pred_bboxes.detach() * stride_tensor_list)
        else:
            assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner(
                pred_scores.detach(),
                pred_bboxes.detach() * stride_tensor_list,
                centers,
...
@@ -136,7 +136,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
        pad_gt_mask = gt_meta['pad_gt_mask']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores, _ = \
                self.static_assigner(
                    anchors,
                    num_anchors_list,
@@ -148,7 +148,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
            alpha_l = 0.25
        else:
            if self.sm_use:
                assigned_labels, assigned_bboxes, assigned_scores, _ = \
                    self.assigner(
                        pred_scores.detach(),
                        pred_bboxes.detach() * stride_tensor,
@@ -159,7 +159,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
                        pad_gt_mask,
                        bg_index=self.num_classes)
            else:
                assigned_labels, assigned_bboxes, assigned_scores, _ = \
                    self.assigner(
                        pred_scores.detach(),
                        pred_bboxes.detach() * stride_tensor,
...
@@ -53,7 +53,7 @@ class ESEAttn(nn.Layer):
class PPYOLOEHead(nn.Layer):
    __shared__ = [
        'num_classes', 'eval_size', 'trt', 'exclude_nms',
        'exclude_post_process', 'use_shared_conv', 'for_distill'
    ]
    __inject__ = ['static_assigner', 'assigner', 'nms']

@@ -81,7 +81,8 @@ class PPYOLOEHead(nn.Layer):
                 attn_conv='convbn',
                 exclude_nms=False,
                 exclude_post_process=False,
                 use_shared_conv=True,
                 for_distill=False):
        super(PPYOLOEHead, self).__init__()
        assert len(in_channels) > 0, "len(in_channels) should > 0"
        self.in_channels = in_channels
@@ -110,6 +111,7 @@ class PPYOLOEHead(nn.Layer):
        self.exclude_nms = exclude_nms
        self.exclude_post_process = exclude_post_process
        self.use_shared_conv = use_shared_conv
        self.for_distill = for_distill

        # stem
        self.stem_cls = nn.LayerList()
@@ -135,6 +137,9 @@ class PPYOLOEHead(nn.Layer):
        self.proj_conv.skip_quant = True
        self._init_weights()

        if self.for_distill:
            self.distill_pairs = {}

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }
@@ -321,6 +326,10 @@ class PPYOLOEHead(nn.Layer):
            loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos,
                                     self.reg_range[0]) * bbox_weight
            loss_dfl = loss_dfl.sum() / assigned_scores_sum

            if self.for_distill:
                self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos
                self.distill_pairs['pred_dist_pos'] = pred_dist_pos
                self.distill_pairs['bbox_weight'] = bbox_weight
        else:
            loss_l1 = paddle.zeros([1])
            loss_iou = paddle.zeros([1])
@@ -343,7 +352,7 @@ class PPYOLOEHead(nn.Layer):
        pad_gt_mask = gt_meta['pad_gt_mask']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                self.static_assigner(
                    anchors,
                    num_anchors_list,
@@ -356,7 +365,7 @@ class PPYOLOEHead(nn.Layer):
        else:
            if self.sm_use:
                # only used in smalldet of PPYOLOE-SOD model
                assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                    self.assigner(
                        pred_scores.detach(),
                        pred_bboxes.detach() * stride_tensor,
@@ -368,18 +377,28 @@ class PPYOLOEHead(nn.Layer):
                        bg_index=self.num_classes)
            else:
                if aux_pred is None:
                    if not hasattr(self, "assigned_labels"):
                        assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                            self.assigner(
                                pred_scores.detach(),
                                pred_bboxes.detach() * stride_tensor,
                                anchor_points,
                                num_anchors_list,
                                gt_labels,
                                gt_bboxes,
                                pad_gt_mask,
                                bg_index=self.num_classes)
                        self.assigned_labels = assigned_labels
                        self.assigned_bboxes = assigned_bboxes
                        self.assigned_scores = assigned_scores
                        self.mask_positive = mask_positive
                    else:
                        assigned_labels = self.assigned_labels
                        assigned_bboxes = self.assigned_bboxes
                        assigned_scores = self.assigned_scores
                        mask_positive = self.mask_positive
                else:
                    assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                        self.assigner(
                            pred_scores_aux.detach(),
                            pred_bboxes_aux.detach() * stride_tensor,
@@ -395,12 +414,14 @@ class PPYOLOEHead(nn.Layer):
        assign_out_dict = self.get_loss_from_assign(
            pred_scores, pred_distri, pred_bboxes, anchor_points_s,
            assigned_labels, assigned_bboxes, assigned_scores, mask_positive,
            alpha_l)

        if aux_pred is not None:
            assign_out_dict_aux = self.get_loss_from_assign(
                aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
                assigned_labels, assigned_bboxes, assigned_scores,
                mask_positive, alpha_l)
            loss = {}
            for key in assign_out_dict.keys():
                loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
@@ -411,7 +432,7 @@ class PPYOLOEHead(nn.Layer):

    def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
                             anchor_points_s, assigned_labels, assigned_bboxes,
                             assigned_scores, mask_positive, alpha_l):
        # cls loss
        if self.use_varifocal_loss:
            one_hot_label = F.one_hot(assigned_labels,
@@ -428,6 +449,15 @@ class PPYOLOEHead(nn.Layer):
        assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
        loss_cls /= assigned_scores_sum

        if self.for_distill:
            self.distill_pairs['pred_cls_scores'] = pred_scores
            self.distill_pairs['pos_num'] = assigned_scores_sum
            self.distill_pairs['assigned_scores'] = assigned_scores
            self.distill_pairs['mask_positive'] = mask_positive
            one_hot_label = F.one_hot(assigned_labels,
                                      self.num_classes + 1)[..., :-1]
            self.distill_pairs['target_labels'] = one_hot_label

        loss_l1, loss_iou, loss_dfl = \
            self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
                            assigned_labels, assigned_bboxes, assigned_scores,
@@ -450,7 +480,8 @@ class PPYOLOEHead(nn.Layer):
        pred_bboxes *= stride_tensor
        if self.exclude_post_process:
            return paddle.concat(
                [pred_bboxes, pred_scores.transpose([0, 2, 1])],
                axis=-1), None, None
        else:
            # scale bbox to origin
            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
@@ -460,9 +491,10 @@ class PPYOLOEHead(nn.Layer):
            pred_bboxes /= scale_factor
            if self.exclude_nms:
                # `exclude_nms=True` just use in benchmark
                return pred_bboxes, pred_scores, None
            else:
                bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
                                                                   pred_scores)
                return bbox_pred, bbox_num, before_nms_indexes
...
@@ -258,7 +258,7 @@ class PPYOLOERHead(nn.Layer):
        pad_gt_mask = gt_meta['pad_gt_mask']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores, _ = \
                self.static_assigner(
                    anchor_points,
                    stride_tensor,
@@ -271,7 +271,7 @@ class PPYOLOERHead(nn.Layer):
                    pred_bboxes.detach()
                )
        else:
            assigned_labels, assigned_bboxes, assigned_scores, _ = \
                self.assigner(
                    pred_scores.detach(),
                    pred_bboxes.detach(),
...
@@ -293,7 +293,7 @@ class TOODHead(nn.Layer):
        pad_gt_mask = gt_meta['pad_gt_mask']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner(
                anchors,
                num_anchors_list,
                gt_labels,
@@ -302,7 +302,7 @@ class TOODHead(nn.Layer):
                bg_index=self.num_classes)
            alpha_l = 0.25
        else:
            assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner(
                pred_scores.detach(),
                pred_bboxes.detach() * stride_tensor,
                bbox_center(anchors),
...
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import distill_loss
from . import distill_model
from . import ofa
from . import prune
from . import quant
from . import unstructured_prune

from .distill_loss import *
from .distill_model import *
from .ofa import *
from .prune import *
from .quant import *
from .unstructured_prune import *

import yaml
from ppdet.core.workspace import load_config
@@ -45,7 +48,11 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
        elif "slim_method" in slim_load_cfg and slim_load_cfg[
                'slim_method'] == "CWD":
            model = CWDDistillModel(cfg, slim_cfg)
        elif "slim_method" in slim_load_cfg and slim_load_cfg[
                'slim_method'] == "PPYOLOEDistill":
            model = PPYOLOEDistillModel(cfg, slim_cfg)
        else:
            # common distillation model
            model = DistillModel(cfg, slim_cfg)
        cfg['model'] = model
        cfg['slim_type'] = cfg.slim
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,302 +16,398 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr

from ppdet.core.workspace import register, create
from ppdet.modeling import ops
from ppdet.modeling.losses.iou_loss import GIoULoss
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

__all__ = [
    'DistillYOLOv3Loss',
    'KnowledgeDistillationKLDivLoss',
    'DistillPPYOLOELoss',
    'FGDFeatureLoss',
    'CWDFeatureLoss',
    'PKDFeatureLoss',
    'MGDFeatureLoss',
]
def parameter_init(mode="kaiming", value=0.):
    if mode == "kaiming":
        weight_attr = paddle.nn.initializer.KaimingUniform()
    elif mode == "constant":
        weight_attr = paddle.nn.initializer.Constant(value=value)
    else:
        weight_attr = paddle.nn.initializer.KaimingUniform()
    weight_init = ParamAttr(initializer=weight_attr)
    return weight_init


def feature_norm(feat):
    # Normalize the feature maps to have zero mean and unit variances.
    assert len(feat.shape) == 4
    N, C, H, W = feat.shape
    feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1])
    mean = feat.mean(axis=-1, keepdim=True)
    std = feat.std(axis=-1, keepdim=True)
    feat = (feat - mean) / (std + 1e-6)
    return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3])
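
A quick sanity check for `feature_norm`: after normalization, every channel should have roughly zero mean and unit standard deviation over the whole batch. A sketch with illustrative shapes:

```python
# Sanity check for feature_norm: per-channel statistics over (N, H, W)
# should come out near 0 mean and 1 std after normalization.
import paddle

x = paddle.randn([2, 8, 16, 16])
y = feature_norm(x)
per_channel = y.transpose([1, 0, 2, 3]).reshape([8, -1])
print(per_channel.mean(axis=-1))  # ~0 for every channel
print(per_channel.std(axis=-1))   # ~1 for every channel
```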
@register
class DistillYOLOv3Loss(nn.Layer):
    def __init__(self, weight=1000):
        super(DistillYOLOv3Loss, self).__init__()
        self.loss_weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
        loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
        loss_w = paddle.abs(sw - tw)
        loss_h = paddle.abs(sh - th)
        loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
        return weighted_loss

    def obj_weighted_cls(self, scls, tcls, tobj):
        loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_loss(self, sobj, tobj):
        obj_mask = paddle.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True
        loss = paddle.mean(
            ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    def forward(self, teacher_model, student_model):
        teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
        student_distill_pairs = student_model.yolo_head.loss.distill_pairs
        distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
        for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
            distill_reg_loss.append(
                self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[
                    3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4]))
            distill_cls_loss.append(
                self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
            distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
        distill_reg_loss = paddle.add_n(distill_reg_loss)
        distill_cls_loss = paddle.add_n(distill_cls_loss)
        distill_obj_loss = paddle.add_n(distill_obj_loss)
        loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                ) * self.loss_weight
        return loss
@register
class KnowledgeDistillationKLDivLoss(nn.Layer):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
        T (int): Temperature for distillation.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
        super(KnowledgeDistillationKLDivLoss, self).__init__()
        assert reduction in ('none', 'mean', 'sum')
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def knowledge_distillation_kl_div_loss(self,
                                           pred,
                                           soft_label,
                                           T,
                                           detach_target=True):
        r"""Loss function for knowledge distilling using KL divergence.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, N + 1).
            T (int): Temperature for distillation.
            detach_target (bool): Remove soft_label from automatic differentiation
        """
        assert pred.shape == soft_label.shape
        target = F.softmax(soft_label / T, axis=1)
        if detach_target:
            target = target.detach()

        kd_loss = F.kl_div(
            F.log_softmax(
                pred / T, axis=1), target, reduction='none').mean(1) * (T * T)

        return kd_loss

    def forward(self,
                pred,
                soft_label,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, N + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')

        reduction = (reduction_override
                     if reduction_override else self.reduction)

        loss_kd_out = self.knowledge_distillation_kl_div_loss(
            pred, soft_label, T=self.T)

        if weight is not None:
            loss_kd_out = weight * loss_kd_out

        if avg_factor is None:
            if reduction == 'none':
                loss = loss_kd_out
            elif reduction == 'mean':
                loss = loss_kd_out.mean()
            elif reduction == 'sum':
                loss = loss_kd_out.sum()
        else:
            # if reduction is mean, then average the loss by avg_factor
            if reduction == 'mean':
                loss = loss_kd_out.sum() / avg_factor
            # if reduction is 'none', then do nothing, otherwise raise an error
            elif reduction != 'none':
                raise ValueError(
                    'avg_factor can not be used with reduction="sum"')

        loss_kd = self.loss_weight * loss
        return loss_kd
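
A hedged usage sketch for `KnowledgeDistillationKLDivLoss` on dummy logits (the shapes are illustrative, not taken from a real head):

```python
# Usage sketch: soften both sets of logits with temperature T, then KL-match them.
import paddle

kd = KnowledgeDistillationKLDivLoss(reduction='mean', loss_weight=1.0, T=10)
student_logits = paddle.randn([32, 81])  # e.g. 80 classes + 1 background
teacher_logits = paddle.randn([32, 81])
loss = kd(student_logits, teacher_logits)  # scalar tensor after 'mean' reduction
```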
@register
class DistillPPYOLOELoss(nn.Layer):
    def __init__(
            self,
            loss_weight={'logits': 4.0,
                         'feat': 1.0},
            logits_distill=True,
            logits_loss_weight={'class': 1.0,
                                'iou': 2.5,
                                'dfl': 0.5},
            feat_distill=True,
            feat_distiller='fgd',
            feat_distill_place='neck_feats',
            teacher_width_mult=1.0,  # L
            student_width_mult=0.75,  # M
            feat_out_channels=[768, 384, 192]):
        super(DistillPPYOLOELoss, self).__init__()
        self.loss_weight_logits = loss_weight['logits']
        self.loss_weight_feat = loss_weight['feat']
        self.logits_distill = logits_distill
        self.feat_distill = feat_distill

        if logits_distill and self.loss_weight_logits > 0:
            self.bbox_loss_weight = logits_loss_weight['iou']
            self.dfl_loss_weight = logits_loss_weight['dfl']
            self.qfl_loss_weight = logits_loss_weight['class']
            self.loss_bbox = GIoULoss()

        if feat_distill and self.loss_weight_feat > 0:
            assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
            assert feat_distill_place in ['backbone_feats', 'neck_feats']
            self.feat_distill_place = feat_distill_place
            self.t_channel_list = [
                int(c * teacher_width_mult) for c in feat_out_channels
            ]
            self.s_channel_list = [
                int(c * student_width_mult) for c in feat_out_channels
            ]
            self.distill_feat_loss_modules = []
            for i in range(len(feat_out_channels)):
                if feat_distiller == 'cwd':
                    feat_loss_module = CWDFeatureLoss(
                        student_channels=self.s_channel_list[i],
                        teacher_channels=self.t_channel_list[i],
                        normalize=True)
                elif feat_distiller == 'fgd':
                    feat_loss_module = FGDFeatureLoss(
                        student_channels=self.s_channel_list[i],
                        teacher_channels=self.t_channel_list[i],
                        normalize=True,
                        alpha_fgd=0.00001,
                        beta_fgd=0.000005,
                        gamma_fgd=0.00001,
                        lambda_fgd=0.00000005)
                elif feat_distiller == 'pkd':
                    feat_loss_module = PKDFeatureLoss(
                        student_channels=self.s_channel_list[i],
                        teacher_channels=self.t_channel_list[i],
                        normalize=True,
                        resize_stu=True)
                elif feat_distiller == 'mgd':
                    feat_loss_module = MGDFeatureLoss(
                        student_channels=self.s_channel_list[i],
                        teacher_channels=self.t_channel_list[i],
                        normalize=True,
                        loss_func='ssim')
                elif feat_distiller == 'mimic':
                    feat_loss_module = MimicFeatureLoss(
                        student_channels=self.s_channel_list[i],
                        teacher_channels=self.t_channel_list[i],
                        normalize=True)
                else:
                    raise ValueError
                self.distill_feat_loss_modules.append(feat_loss_module)

    def quality_focal_loss(self,
                           pred_logits,
                           soft_target_logits,
                           beta=2.0,
                           use_sigmoid=False,
                           num_total_pos=None):
        if use_sigmoid:
            func = F.binary_cross_entropy_with_logits
            soft_target = F.sigmoid(soft_target_logits)
            pred_sigmoid = F.sigmoid(pred_logits)
            preds = pred_logits
        else:
            func = F.binary_cross_entropy
            soft_target = soft_target_logits
            pred_sigmoid = pred_logits
            preds = pred_sigmoid

        scale_factor = pred_sigmoid - soft_target
        loss = func(
            preds, soft_target, reduction='none') * scale_factor.abs().pow(beta)

        loss = loss.sum(1)
        if num_total_pos is not None:
            loss = loss.sum() / num_total_pos
        else:
            loss = loss.mean()
        return loss
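
Written out, `quality_focal_loss` above is the soft-label quality focal loss from Generalized Focal Loss: with student probability p, teacher soft target q, and beta = 2, each class term is

$$\mathrm{QFL}(p, q) = -\,\lvert p - q\rvert^{\beta}\,\bigl[(1 - q)\log(1 - p) + q\log p\bigr],$$

summed over classes and, when `num_total_pos` is provided, normalized by that positive count instead of the batch mean.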
    def bbox_loss(self, s_bbox, t_bbox, weight_targets=None):
        # [x,y,w,h]
        if weight_targets is not None:
            loss = paddle.sum(self.loss_bbox(s_bbox, t_bbox) * weight_targets)
            avg_factor = weight_targets.sum()
            loss = loss / avg_factor
        else:
            loss = paddle.mean(self.loss_bbox(s_bbox, t_bbox))
        return loss

    def distribution_focal_loss(self,
                                pred_corners,
                                target_corners,
                                weight_targets=None):
        target_corners_label = F.softmax(target_corners, axis=-1)
        loss_dfl = F.cross_entropy(
            pred_corners,
            target_corners_label,
            soft_label=True,
            reduction='none')
        loss_dfl = loss_dfl.sum(1)
        if weight_targets is not None:
            loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1]))
            loss_dfl = loss_dfl.sum(-1) / weight_targets.sum()
        else:
            loss_dfl = loss_dfl.mean(-1)
        return loss_dfl / 4.0  # 4 direction
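
`distribution_focal_loss` here is the distilled variant of DFL: the teacher's corner logits are softmaxed into a soft label, and the student's corner distribution is matched with soft-label cross-entropy, averaged over the four ltrb directions:

$$\mathrm{DFL}_{\mathrm{distill}} = \frac{1}{4}\sum_{d \in \{l,t,r,b\}} \mathrm{CE}\bigl(s_d,\ \mathrm{softmax}(t_d)\bigr)$$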
    def forward(self, teacher_model, student_model):
        teacher_distill_pairs = teacher_model.yolo_head.distill_pairs
        student_distill_pairs = student_model.yolo_head.distill_pairs
        if self.logits_distill and self.loss_weight_logits > 0:
            distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], []

            distill_cls_loss.append(
                self.quality_focal_loss(
                    student_distill_pairs['pred_cls_scores'].reshape(
                        (-1, student_distill_pairs['pred_cls_scores'].shape[-1]
                         )),
                    teacher_distill_pairs['pred_cls_scores'].detach().reshape(
                        (-1, teacher_distill_pairs['pred_cls_scores'].shape[-1]
                         )),
                    num_total_pos=student_distill_pairs['pos_num'],
                    use_sigmoid=False))
            distill_bbox_loss.append(
                self.bbox_loss(student_distill_pairs['pred_bboxes_pos'],
                               teacher_distill_pairs['pred_bboxes_pos'].detach(),
                               weight_targets=student_distill_pairs['bbox_weight']
                               ) if 'pred_bboxes_pos' in student_distill_pairs and \
                    'pred_bboxes_pos' in teacher_distill_pairs and \
                    'bbox_weight' in student_distill_pairs
                else paddle.zeros([1]))
            distill_dfl_loss.append(
                self.distribution_focal_loss(
                    student_distill_pairs['pred_dist_pos'].reshape(
                        (-1, student_distill_pairs['pred_dist_pos'].shape[-1])),
                    teacher_distill_pairs['pred_dist_pos'].detach().reshape(
                        (-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])),
                    weight_targets=student_distill_pairs['bbox_weight']
                ) if 'pred_dist_pos' in student_distill_pairs and \
                    'pred_dist_pos' in teacher_distill_pairs and \
                    'bbox_weight' in student_distill_pairs
                else paddle.zeros([1]))

            distill_cls_loss = paddle.add_n(distill_cls_loss)
            distill_bbox_loss = paddle.add_n(distill_bbox_loss)
            distill_dfl_loss = paddle.add_n(distill_dfl_loss)
            logits_loss = distill_bbox_loss * self.bbox_loss_weight + \
                distill_cls_loss * self.qfl_loss_weight + \
                distill_dfl_loss * self.dfl_loss_weight
        else:
            logits_loss = paddle.zeros([1])

        if self.feat_distill and self.loss_weight_feat > 0:
            feat_loss_list = []
            inputs = student_model.inputs
            assert 'gt_bbox' in inputs
            assert self.feat_distill_place in student_distill_pairs
            assert self.feat_distill_place in teacher_distill_pairs
            stu_feats = student_distill_pairs[self.feat_distill_place]
            tea_feats = teacher_distill_pairs[self.feat_distill_place]
            for i, loss_module in enumerate(self.distill_feat_loss_modules):
                feat_loss_list.append(
                    loss_module(stu_feats[i], tea_feats[i], inputs))
            feat_loss = paddle.add_n(feat_loss_list)
        else:
            feat_loss = paddle.zeros([1])

        student_model.yolo_head.distill_pairs.clear()
        teacher_model.yolo_head.distill_pairs.clear()
        return logits_loss * self.loss_weight_logits, feat_loss * self.loss_weight_feat
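
Unlike the other losses in this file, `forward` returns a `(logits_loss, feat_loss)` pair with the top-level weights already applied. A hedged sketch of the consuming side; the real wrapper, `PPYOLOEDistillModel`, lives in `distill_model.py`, which is not part of this excerpt:

```python
# Hypothetical consumer fragment: distill_loss is a DistillPPYOLOELoss instance
# and teacher_model/student_model are the wrapped pair (see distill_model.py,
# not shown in this diff).
logits_loss, feat_loss = distill_loss(teacher_model, student_model)
total_distill_loss = logits_loss + feat_loss  # weights already applied in forward()
```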
@register
class CWDFeatureLoss(nn.Layer):
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 normalize=False,
                 tau=1.0,
                 weight=1.0):
        super(CWDFeatureLoss, self).__init__()
        self.normalize = normalize
        self.tau = tau
        self.loss_weight = weight
@@ -325,20 +421,23 @@ class ChannelWiseDivergence(nn.Layer):
        else:
            self.align = None

    def distill_softmax(self, x, tau):
        _, _, w, h = paddle.shape(x)
        x = paddle.reshape(x, [-1, w * h])
        x /= tau
        return F.softmax(x, axis=1)

    def forward(self, preds_s, preds_t, inputs):
        assert preds_s.shape[-2:] == preds_t.shape[-2:]
        N, C, H, W = preds_s.shape
        eps = 1e-5
        if self.align is not None:
            preds_s = self.align(preds_s)

        if self.normalize:
            preds_s = feature_norm(preds_s)
            preds_t = feature_norm(preds_t)

        softmax_pred_s = self.distill_softmax(preds_s, self.tau)
        softmax_pred_t = self.distill_softmax(preds_t, self.tau)
@@ -347,73 +446,16 @@ class ChannelWiseDivergence(nn.Layer):
        return self.loss_weight * loss / (C * N)
@register @register
class FGDFeatureLoss(nn.Layer): class FGDFeatureLoss(nn.Layer):
""" """
Focal and Global Knowledge Distillation for Detectors
The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
Paddle version of `Focal and Global Knowledge Distillation for Detectors`
Args: Args:
student_channels(int): The number of channels in the student's FPN feature map. Default to 256. student_channels (int): The number of channels in the student's FPN feature map. Default to 256.
teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256. teacher_channels (int): The number of channels in the teacher's FPN feature map. Default to 256.
normalize (bool): Whether to normalize the feature maps.
temp (float, optional): The temperature coefficient. Defaults to 0.5. temp (float, optional): The temperature coefficient. Defaults to 0.5.
alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001
beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005
...@@ -422,20 +464,23 @@ class FGDFeatureLoss(nn.Layer): ...@@ -422,20 +464,23 @@ class FGDFeatureLoss(nn.Layer):
""" """
def __init__(self, def __init__(self,
student_channels=256, student_channels,
teacher_channels=256, teacher_channels,
normalize=False,
loss_weight=1.0,
temp=0.5, temp=0.5,
alpha_fgd=0.001, alpha_fgd=0.001,
beta_fgd=0.0005, beta_fgd=0.0005,
gamma_fgd=0.001, gamma_fgd=0.001,
lambda_fgd=0.000005): lambda_fgd=0.000005):
super(FGDFeatureLoss, self).__init__() super(FGDFeatureLoss, self).__init__()
self.normalize = normalize
self.loss_weight = loss_weight
self.temp = temp self.temp = temp
self.alpha_fgd = alpha_fgd self.alpha_fgd = alpha_fgd
self.beta_fgd = beta_fgd self.beta_fgd = beta_fgd
self.gamma_fgd = gamma_fgd self.gamma_fgd = gamma_fgd
self.lambda_fgd = lambda_fgd self.lambda_fgd = lambda_fgd
kaiming_init = parameter_init("kaiming") kaiming_init = parameter_init("kaiming")
zeros_init = parameter_init("constant", 0.0) zeros_init = parameter_init("constant", 0.0)
...@@ -486,7 +531,6 @@ class FGDFeatureLoss(nn.Layer): ...@@ -486,7 +531,6 @@ class FGDFeatureLoss(nn.Layer):
    def spatial_channel_attention(self, x, t=0.5):
        shape = paddle.shape(x)
        N, C, H, W = shape
        _f = paddle.abs(x)
        spatial_map = paddle.reshape(
            paddle.mean(
@@ -515,7 +559,6 @@ class FGDFeatureLoss(nn.Layer):
        context_mask = context_mask.unsqueeze(-1)
        context = paddle.matmul(x_copy, context_mask)
        context = paddle.reshape(context, [batch, channel, 1, 1])
        return context

    def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
@@ -525,44 +568,35 @@ class FGDFeatureLoss(nn.Layer):
        mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
            stu_spatial_att, tea_spatial_att)
        return mask_loss

    def feature_loss(self, stu_feature, tea_feature, mask_fg, mask_bg,
                     tea_channel_att, tea_spatial_att):
        mask_fg = mask_fg.unsqueeze(axis=1)
        mask_bg = mask_bg.unsqueeze(axis=1)
        tea_channel_att = tea_channel_att.unsqueeze(axis=-1).unsqueeze(axis=-1)
        tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)

        fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
        fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
        fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_fg))
        bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_bg))

        fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
        fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
        fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_fg))
        bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_bg))

        fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(mask_fg)
        bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(mask_bg)
        return fg_loss, bg_loss
    def relation_loss(self, stu_feature, tea_feature):
        context_s = self.spatial_pool(stu_feature, "student")
        context_t = self.spatial_pool(tea_feature, "teacher")
        out_s = stu_feature + self.stu_conv_block(context_s)
        out_t = tea_feature + self.tea_conv_block(context_t)
        rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
        return rela_loss

    def mask_value(self, mask, xl, xr, yl, yr, value):
@@ -570,21 +604,12 @@ class FGDFeatureLoss(nn.Layer):
        return mask
    def forward(self, stu_feature, tea_feature, inputs):
        # NOTE: the original assert compared stu_feature's shape with itself;
        # it is fixed here to check the student against the teacher.
        assert stu_feature.shape[-2:] == tea_feature.shape[-2:]
        assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys()
        gt_bboxes = inputs['gt_bbox']
        ins_shape = [
            inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
        ]

        index_gt = []
        for i in range(len(gt_bboxes)):
            if gt_bboxes[i].size > 2:
@@ -592,35 +617,41 @@ class FGDFeatureLoss(nn.Layer):
        # only distill feature with labeled GTbox
        if len(index_gt) != len(gt_bboxes):
            index_gt_t = paddle.to_tensor(index_gt)
            stu_feature = paddle.index_select(stu_feature, index_gt_t)
            tea_feature = paddle.index_select(tea_feature, index_gt_t)
            ins_shape = [ins_shape[c] for c in index_gt]
            gt_bboxes = [gt_bboxes[c] for c in index_gt]
            assert len(gt_bboxes) == tea_feature.shape[0]

        if self.align is not None:
            stu_feature = self.align(stu_feature)

        if self.normalize:
            stu_feature = feature_norm(stu_feature)
            tea_feature = feature_norm(tea_feature)

        tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
            tea_feature, self.temp)
        stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
            stu_feature, self.temp)

        mask_fg = paddle.zeros(tea_spatial_att.shape)
        mask_bg = paddle.ones_like(tea_spatial_att)
        one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
        zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
        mask_fg.stop_gradient = True
        mask_bg.stop_gradient = True
        one_tmp.stop_gradient = True
        zero_tmp.stop_gradient = True
        wmin, wmax, hmin, hmax = [], [], [], []

        if gt_bboxes.shape[1] == 0:
            loss = self.relation_loss(stu_feature, tea_feature)
            return self.lambda_fgd * loss

        N, _, H, W = stu_feature.shape
        for i in range(N):
            tmp_box = paddle.ones_like(gt_bboxes[i])
            tmp_box.stop_gradient = True
@@ -633,7 +664,6 @@ class FGDFeatureLoss(nn.Layer):
            ones = paddle.ones_like(tmp_box[:, 2], dtype="int32")
            zero.stop_gradient = True
            ones.stop_gradient = True
            wmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
            wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
@@ -646,164 +676,217 @@ class FGDFeatureLoss(nn.Layer):
                wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))

            for j in range(len(gt_bboxes[i])):
                if gt_bboxes[i][j].sum() > 0:
                    mask_fg[i] = self.mask_value(
                        mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j],
                        wmax[i][j] + 1, area_recip[0][j])

            mask_bg[i] = paddle.where(mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
            if paddle.sum(mask_bg[i]):
                mask_bg[i] /= paddle.sum(mask_bg[i])

        fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, mask_fg,
                                             mask_bg, tea_channel_att,
                                             tea_spatial_att)
        mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
                                   stu_spatial_att, tea_spatial_att)
        rela_loss = self.relation_loss(stu_feature, tea_feature)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
               + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

        return loss * self.loss_weight
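# A minimal usage sketch (illustrative only; the channel sizes are assumptions
# and the real wiring happens through the slim config): one FGDFeatureLoss per
# FPN level, fed with same-resolution student/teacher features plus the batch
# dict carrying 'gt_bbox' and 'im_shape':
#
#   fgd = FGDFeatureLoss(student_channels=96, teacher_channels=256)
#   level_loss = fgd(stu_feat, tea_feat, inputs)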
@register
class PKDFeatureLoss(nn.Layer):
    """
    PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient.

    Args:
        loss_weight (float): Weight of loss. Defaults to 1.0.
        resize_stu (bool): If True, we'll down/up sample the features of the
            student model to the spatial size of those of the teacher model if
            their spatial sizes are different. And vice versa. Defaults to
            True.
    """

    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 normalize=True,
                 loss_weight=1.0,
                 resize_stu=True):
        super(PKDFeatureLoss, self).__init__()
        self.normalize = normalize
        self.loss_weight = loss_weight
        self.resize_stu = resize_stu

    def forward(self, stu_feature, tea_feature, inputs):
        size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:]
        if size_s[0] != size_t[0]:
            if self.resize_stu:
                stu_feature = F.interpolate(
                    stu_feature, size_t, mode='bilinear')
            else:
                tea_feature = F.interpolate(
                    tea_feature, size_s, mode='bilinear')
        assert stu_feature.shape == tea_feature.shape

        if self.normalize:
            stu_feature = feature_norm(stu_feature)
            tea_feature = feature_norm(tea_feature)

        loss = F.mse_loss(stu_feature, tea_feature) / 2
        return loss * self.loss_weight
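# With normalize=True, and assuming feature_norm standardizes each channel to
# zero mean / unit variance (the PKD formulation), MSE(x, y) = 2 - 2 * rho for
# standardized x and y, so the division by 2 above makes the loss equal to
# 1 - Pearson correlation coefficient.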
@register
class MimicFeatureLoss(nn.Layer):
    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 normalize=True,
                 loss_weight=1.0):
        super(MimicFeatureLoss, self).__init__()
        self.normalize = normalize
        self.loss_weight = loss_weight
        self.mse_loss = nn.MSELoss()

        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

    def forward(self, stu_feature, tea_feature, inputs):
        if self.align is not None:
            stu_feature = self.align(stu_feature)

        if self.normalize:
            stu_feature = feature_norm(stu_feature)
            tea_feature = feature_norm(tea_feature)

        loss = self.mse_loss(stu_feature, tea_feature)
        return loss * self.loss_weight
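# MimicFeatureLoss is plain feature imitation: an optional 1x1 conv aligns
# student channels to the teacher's when they differ, then a mean-reduced MSE
# is applied to the (optionally normalized) feature maps.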
@register
class MGDFeatureLoss(nn.Layer):
    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 normalize=True,
                 loss_weight=1.0,
                 loss_func='mse'):
        super(MGDFeatureLoss, self).__init__()
        self.normalize = normalize
        self.loss_weight = loss_weight
        assert loss_func in ['mse', 'ssim']
        self.loss_func = loss_func
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.ssim_loss = SSIM(11)

        kaiming_init = parameter_init("kaiming")
        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=kaiming_init,
                bias_attr=False)
        else:
            self.align = None

        self.generation = nn.Sequential(
            nn.Conv2D(
                teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels, teacher_channels, kernel_size=3, padding=1))

    def forward(self, stu_feature, tea_feature, inputs):
        N = stu_feature.shape[0]
        if self.align is not None:
            stu_feature = self.align(stu_feature)
        stu_feature = self.generation(stu_feature)

        if self.normalize:
            stu_feature = feature_norm(stu_feature)
            tea_feature = feature_norm(tea_feature)

        if self.loss_func == 'mse':
            loss = self.mse_loss(stu_feature, tea_feature) / N
        elif self.loss_func == 'ssim':
            ssim_loss = self.ssim_loss(stu_feature, tea_feature)
            loss = paddle.clip((1 - ssim_loss) / 2, 0, 1)
        else:
            raise ValueError
        return loss * self.loss_weight
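# MGD feeds the (aligned) student feature through the `generation` block
# (3x3 conv -> ReLU -> 3x3 conv) and asks it to reconstruct the teacher
# feature; note this variant omits the random masking used in the original
# MGD paper. loss_func='ssim' swaps the sum-MSE for a structural-similarity
# objective, mapped into [0, 1] via (1 - ssim) / 2.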
class SSIM(nn.Layer):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = self.create_window(window_size, self.channel)

    def gaussian(self, window_size, sigma):
        gauss = paddle.to_tensor([
            math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
            for x in range(window_size)
        ])
        return gauss / gauss.sum()

    def create_window(self, window_size, channel):
        _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand([channel, 1, window_size, window_size])
        return window

    def _ssim(self, img1, img2, window, window_size, channel,
              size_average=True):
        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        sigma1_sq = F.conv2d(
            img1 * img1, window, padding=window_size // 2,
            groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(
            img2 * img2, window, padding=window_size // 2,
            groups=channel) - mu2_sq
        sigma12 = F.conv2d(
            img1 * img2, window, padding=window_size // 2,
            groups=channel) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
            1e-12 + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        else:
            return ssim_map.mean([1, 2, 3])

    def forward(self, img1, img2):
        channel = img1.shape[1]
        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = self.create_window(self.window_size, channel)
            self.window = window
            self.channel = channel

        return self._ssim(img1, img2, window, self.window_size, channel,
                          self.size_average)
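# C1 = (0.01 * L)^2 and C2 = (0.03 * L)^2 with dynamic range L = 1 are the
# standard SSIM stabilizers; the extra 1e-12 in the denominator guards
# against division by zero on constant inputs.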
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, create, load_config
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
    'DistillModel',
    'FGDDistillModel',
    'CWDDistillModel',
    'LDDistillModel',
    'PPYOLOEDistillModel',
]
@register
class DistillModel(nn.Layer):
    """
    Build common distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(DistillModel, self).__init__()
        self.arch = cfg.architecture

        self.stu_cfg = cfg
        self.student_model = create(self.stu_cfg.architecture)
        if 'pretrain_weights' in self.stu_cfg and self.stu_cfg.pretrain_weights:
            stu_pretrain = self.stu_cfg.pretrain_weights
        else:
            stu_pretrain = None

        slim_cfg = load_config(slim_cfg)
        self.tea_cfg = slim_cfg
        self.teacher_model = create(self.tea_cfg.architecture)
        if 'pretrain_weights' in self.tea_cfg and self.tea_cfg.pretrain_weights:
            tea_pretrain = self.tea_cfg.pretrain_weights
        else:
            tea_pretrain = None
        self.distill_cfg = slim_cfg

        # load pretrain weights
        self.is_inherit = False
        if stu_pretrain:
            if self.is_inherit and tea_pretrain:
                load_pretrain_weight(self.student_model, tea_pretrain)
                logger.debug(
                    "Inheriting! loading teacher weights to student model!")
            load_pretrain_weight(self.student_model, stu_pretrain)
            logger.info("Student model has loaded pretrain weights!")

        if tea_pretrain:
            load_pretrain_weight(self.teacher_model, tea_pretrain)
            logger.info("Teacher model has loaded pretrain weights!")

        self.teacher_model.eval()
        for param in self.teacher_model.parameters():
            param.trainable = False

        self.distill_loss = self.build_loss(self.distill_cfg)

    def build_loss(self, distill_cfg):
        if 'distill_loss' in distill_cfg and distill_cfg.distill_loss:
            return create(distill_cfg.distill_loss)
        else:
            return None

    def parameters(self):
        return self.student_model.parameters()

    def forward(self, inputs):
        if self.training:
            student_loss = self.student_model(inputs)
            with paddle.no_grad():
                teacher_loss = self.teacher_model(inputs)

            loss = self.distill_loss(self.teacher_model, self.student_model)
            student_loss['distill_loss'] = loss
            student_loss['teacher_loss'] = teacher_loss['loss']
            student_loss['loss'] += student_loss['distill_loss']
            return student_loss
        else:
            return self.student_model(inputs)
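# The concrete loss object is instantiated from the slim config's
# `distill_loss` key through ppdet's component registry (`create`); subclasses
# override build_loss() when they need one loss instance per FPN level, as the
# FGD and CWD models below do.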
@register
class FGDDistillModel(DistillModel):
    """
    Build FGD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(FGDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
        assert self.arch in ['RetinaNet', 'PicoDet'
                             ], 'Unsupported arch: {}'.format(self.arch)
        self.is_inherit = True

    def build_loss(self, distill_cfg):
        assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
        assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
        loss_func = dict()
        name_list = distill_cfg.distill_loss_name
        for name in name_list:
            loss_func[name] = create(distill_cfg.distill_loss)
        return loss_func

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            loss_dict = {}
            for idx, k in enumerate(self.distill_loss):
                loss_dict[k] = self.distill_loss[k](s_neck_feats[idx],
                                                    t_neck_feats[idx], inputs)
            if self.arch == "RetinaNet":
                loss = self.student_model.head(s_neck_feats, inputs)
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    s_neck_feats, self.student_model.export_post_process)
                loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
                total_loss = paddle.add_n(list(loss_gfl.values()))
                loss = {}
                loss.update(loss_gfl)
                loss.update({'loss': total_loss})
            else:
                raise ValueError(f"Unsupported model {self.arch}")
            for k in loss_dict:
                loss['loss'] += loss_dict[k]
                loss[k] = loss_dict[k]
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    neck_feats, self.student_model.export_post_process)
                scale_factor = inputs['scale_factor']
                bboxes, bbox_num = self.student_model.head.post_process(
                    head_outs,
                    scale_factor,
                    export_nms=self.student_model.export_nms)
                return {'bbox': bboxes, 'bbox_num': bbox_num}
            else:
                raise ValueError(f"Unsupported model {self.arch}")
@register
class CWDDistillModel(DistillModel):
    """
    Build CWD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(CWDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
        assert self.arch in ['GFL', 'RetinaNet'], 'Unsupported arch: {}'.format(
            self.arch)

    def build_loss(self, distill_cfg):
        assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
        assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
        loss_func = dict()
        name_list = distill_cfg.distill_loss_name
        for name in name_list:
            loss_func[name] = create(distill_cfg.distill_loss)
        return loss_func

    def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
        loss = self.student_model.head(stu_fea_list, inputs)
        distill_loss = {}
        # NOTE: the original referenced an undefined `self.loss_dic`; the loss
        # dict built in build_loss() is stored as `self.distill_loss`.
        for idx, k in enumerate(self.distill_loss):
            distill_loss[k] = self.distill_loss[k](stu_fea_list[idx],
                                                   tea_fea_list[idx])
            loss['loss'] += distill_loss[k]
            loss[k] = distill_loss[k]
        return loss

    def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
        loss = {}
        head_outs = self.student_model.head(stu_fea_list)
        loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
        loss.update(loss_gfl)
        total_loss = paddle.add_n(list(loss.values()))
        loss.update({'loss': total_loss})

        feat_loss = {}
        loss_dict = {}
        s_cls_feat, t_cls_feat = [], []
        for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
            conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
            cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
            t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
            t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
            s_cls_feat.append(cls_score)
            t_cls_feat.append(t_cls_score)

        for idx, k in enumerate(self.distill_loss):
            loss_dict[k] = self.distill_loss[k](s_cls_feat[idx],
                                                t_cls_feat[idx])
            feat_loss[f"neck_f_{idx}"] = self.distill_loss[k](
                stu_fea_list[idx], tea_fea_list[idx])
        for k in feat_loss:
            loss['loss'] += feat_loss[k]
            loss[k] = feat_loss[k]
        for k in loss_dict:
            loss['loss'] += loss_dict[k]
            loss[k] = loss_dict[k]
        return loss

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            if self.arch == "RetinaNet":
                loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
                                               inputs)
            elif self.arch == "GFL":
                loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
            else:
                raise ValueError(f"unsupported arch {self.arch}")
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "GFL":
                bbox_pred, bbox_num = head_outs
                output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
                return output
            else:
                raise ValueError(f"unsupported arch {self.arch}")
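# For GFL, the channel-wise distillation loss is applied twice per FPN level:
# once on the raw neck features and once on the classification scores produced
# by the shared conv head (s_cls_feat vs t_cls_feat above).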
@register
class LDDistillModel(DistillModel):
    """
    Build LD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(LDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
        assert self.arch in ['GFL'], 'Unsupported arch: {}'.format(self.arch)

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            s_head_outs = self.student_model.head(s_neck_feats)
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)
                t_head_outs = self.teacher_model.head(t_neck_feats)

            soft_label_list = t_head_outs[0]
            soft_targets_list = t_head_outs[1]
            student_loss = self.student_model.head.get_loss(
                s_head_outs, inputs, soft_label_list, soft_targets_list)
            total_loss = paddle.add_n(list(student_loss.values()))
            student_loss['loss'] = total_loss
            return student_loss
        else:
            return self.student_model(inputs)
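# LD reuses the GFL head's loss entry point: the two teacher head outputs are
# passed in as soft labels / soft targets, so the student's distribution-based
# box representation is distilled alongside the standard detection loss.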
@register
class PPYOLOEDistillModel(DistillModel):
    """
    Build PPYOLOE distill model, only used in PPYOLOE
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(PPYOLOEDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
        assert self.arch in ['PPYOLOE'], 'Unsupported arch: {}'.format(
            self.arch)
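    # In forward() below, any label assignment cached on the teacher head is
    # handed to the student so both heads are supervised with identical
    # targets, then cleared from the teacher. The returned total loss is
    # alpha * (detection loss + logits distill loss + feature distill loss).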
    def forward(self, inputs, alpha=0.125):
        if self.training:
            if hasattr(self.teacher_model.yolo_head, "assigned_labels"):
                self.student_model.yolo_head.assigned_labels, \
                    self.student_model.yolo_head.assigned_bboxes, \
                    self.student_model.yolo_head.assigned_scores, \
                    self.student_model.yolo_head.mask_positive = \
                    self.teacher_model.yolo_head.assigned_labels, \
                    self.teacher_model.yolo_head.assigned_bboxes, \
                    self.teacher_model.yolo_head.assigned_scores, \
                    self.teacher_model.yolo_head.mask_positive
                delattr(self.teacher_model.yolo_head, "assigned_labels")
                delattr(self.teacher_model.yolo_head, "assigned_bboxes")
                delattr(self.teacher_model.yolo_head, "assigned_scores")
                delattr(self.teacher_model.yolo_head, "mask_positive")
            student_loss = self.student_model(inputs)
            with paddle.no_grad():
                teacher_loss = self.teacher_model(inputs)

            logits_loss, feat_loss = self.distill_loss(self.teacher_model,
                                                       self.student_model)
            det_total_loss = student_loss['loss']
            total_loss = alpha * (det_total_loss + logits_loss + feat_loss)
            student_loss['loss'] = total_loss
            student_loss['det_loss'] = det_total_loss
            student_loss['logits_loss'] = logits_loss
            student_loss['feat_loss'] = feat_loss
            return student_loss
        else:
            return self.student_model(inputs)