Unverified commit b2e0acf1, authored by Feng Ni, committed by GitHub

Add PPYOLOE distillation (#7671)

* add ppyoloe distill with soft loss and feature loss

* refine distill codes

* fix configs and docs

* clean codes

* merge cam, fix export
Parent 2b5fd266
# PPYOLOE+ Distillation

PaddleDetection provides a model distillation scheme for PPYOLOE+ that combines logits distillation and feature distillation.
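The distilled training objective is the student's own detection loss plus the weighted logits- and feature-distillation terms. A minimal sketch of that combination, using the `loss_weight: {'logits': 4.0, 'feat': 1.0}` values from the slim configs introduced below (the helper function is illustrative, not a PaddleDetection API):

```python
# Minimal sketch: combine the student's detection loss with the two distillation terms.
# The weights mirror DistillPPYOLOELoss's loss_weight: {'logits': 4.0, 'feat': 1.0}.
def combined_distill_loss(det_loss, logits_kd_loss, feat_kd_loss,
                          logits_weight=4.0, feat_weight=1.0):
    return det_loss + logits_weight * logits_kd_loss + feat_weight * feat_kd_loss
```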
## Model Zoo

## Getting Started

### Training
```shell
# single GPU
python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
# multiple GPUs
python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
```
- `-c`: specifies the model config file, which is also the student config file.
- `--slim_config`: specifies the compression strategy config file, which is also the teacher config file.

### Evaluation
```shell
python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams
```
- `-c`: specifies the model config file, which is also the student config file.
- `--slim_config`: specifies the compression strategy config file, which is also the teacher config file.
- `-o weights`: specifies the path of the model weights trained with the compression algorithm.

### Inference
```shell
python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg
```
- `-c`: specifies the model config file.
- `--slim_config`: specifies the compression strategy config file.
- `-o weights`: specifies the path of the model weights trained with the compression algorithm.
- `--infer_img`: specifies the path of the test image.
_BASE_: [
  '../ppyoloe_plus_crn_l_80e_coco.yml',
]
for_distill: True

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_l_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
_BASE_: [
  '../ppyoloe_plus_crn_m_80e_coco.yml',
]
for_distill: True

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_m_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_m_obj365_pretrained.pdparams
depth_mult: 0.67
width_mult: 0.75
_BASE_: [
  '../ppyoloe_plus_crn_s_80e_coco.yml',
]
for_distill: True

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_s_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
......@@ -38,6 +38,42 @@ CWD stands for [Channel-wise Knowledge Distillation for Dense Prediction*](https://
|gfl_r50_fpn_1x| student | 41.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) |
|gfl_r50_fpn_2x + CWD| student | 44.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) |
## PPYOLOE+ Model Distillation

## Getting Started

### Training
```shell
# single GPU
python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
# multiple GPUs
python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
```
- `-c`: specifies the model config file, which is also the student config file.
- `--slim_config`: specifies the compression strategy config file, which is also the teacher config file.

### Evaluation
```shell
python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams
```
- `-c`: specifies the model config file, which is also the student config file.
- `--slim_config`: specifies the compression strategy config file, which is also the teacher config file.
- `-o weights`: specifies the path of the model weights trained with the compression algorithm.

### Inference
```shell
python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg
```
- `-c`: specifies the model config file.
- `--slim_config`: specifies the compression strategy config file.
- `-o weights`: specifies the path of the model weights trained with the compression algorithm.
- `--infer_img`: specifies the path of the test image.
## Citations
......@@ -6,10 +6,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_
slim: Distill
slim_method: CWD
distill_loss: ChannelWiseDivergence
distill_loss: CWDFeatureLoss
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']
ChannelWiseDivergence:
CWDFeatureLoss:
student_channels: 80
teacher_channels: 80
tau: 1.0
......
# teacher and slim config
_BASE_: [
  '../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
depth_mult: 1.0
width_mult: 1.0

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~
  for_distill: True

pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams
find_unused_parameters: True
for_distill: True

slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss

DistillPPYOLOELoss: # L -> M
  loss_weight: {'logits': 4.0, 'feat': 1.0}
  logits_distill: True
  logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
  feat_distill: True
  feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
  feat_distill_place: 'neck_feats'
  teacher_width_mult: 1.0 # L
  student_width_mult: 0.75 # M
  feat_out_channels: [768, 384, 192] # actual channels are these values multiplied by width_mult
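`feat_out_channels` lists the neck output channels at `width_mult = 1.0`; the channels actually used to build the feature-distillation modules are these values scaled by each side's `width_mult`, as in this illustrative calculation (plain Python, not part of the config):

```python
# Illustrative: actual teacher/student neck channels for the L -> M pair.
feat_out_channels = [768, 384, 192]   # channels at width_mult = 1.0
teacher_width_mult = 1.0              # L teacher
student_width_mult = 0.75             # M student

teacher_channels = [int(c * teacher_width_mult) for c in feat_out_channels]
student_channels = [int(c * student_width_mult) for c in feat_out_channels]
print(teacher_channels)  # [768, 384, 192]
print(student_channels)  # [576, 288, 144]
```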
# teacher and slim config
_BASE_: [
  '../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
depth_mult: 0.67
width_mult: 0.75

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~
  for_distill: True

pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_m_80e_coco.pdparams
find_unused_parameters: True
for_distill: True

slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss

DistillPPYOLOELoss: # M -> S
  loss_weight: {'logits': 4.0, 'feat': 1.0}
  logits_distill: True
  logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
  feat_distill: True
  feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
  feat_distill_place: 'neck_feats'
  teacher_width_mult: 0.75 # M
  student_width_mult: 0.5 # S
  feat_out_channels: [768, 384, 192] # actual channels are these values multiplied by width_mult
# teacher and slim config
_BASE_: [
  '../../ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml',
]
depth_mult: 1.33
width_mult: 1.25

architecture: PPYOLOE
PPYOLOE:
  backbone: CSPResNet
  neck: CustomCSPPAN
  yolo_head: PPYOLOEHead
  post_process: ~
  for_distill: True

pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams
find_unused_parameters: True
for_distill: True

slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss

DistillPPYOLOELoss: # X -> L
  loss_weight: {'logits': 4.0, 'feat': 1.0}
  logits_distill: True
  logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
  feat_distill: True
  feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
  feat_distill_place: 'neck_feats'
  teacher_width_mult: 1.25 # X
  student_width_mult: 1.0 # L
  feat_out_channels: [768, 384, 192] # actual channels are these values multiplied by width_mult
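`feat_distiller` selects which feature-distillation loss is instantiated for each neck level. The mapping below is only a summary of the branch in `DistillPPYOLOELoss.__init__` shown later in this diff:

```python
# Summary of the feat_distiller options handled by DistillPPYOLOELoss.__init__.
FEAT_DISTILLERS = {
    'cwd': 'CWDFeatureLoss',      # channel-wise divergence on softmaxed feature maps
    'fgd': 'FGDFeatureLoss',      # focal and global distillation (default in these configs)
    'pkd': 'PKDFeatureLoss',      # MSE on normalized (Pearson-style) features
    'mgd': 'MGDFeatureLoss',      # masked generative distillation, mse or ssim loss
    'mimic': 'MimicFeatureLoss',  # plain MSE feature mimicking
}
```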
......@@ -7,12 +7,11 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_c
slim: Distill
slim_method: CWD
distill_loss: ChannelWiseDivergence
distill_loss: CWDFeatureLoss
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']
ChannelWiseDivergence:
CWDFeatureLoss:
student_channels: 80
teacher_channels: 80
name: cwdloss
tau: 1.0
weight: 5.0
......@@ -22,13 +22,14 @@ from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture
# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially when using distillation or an aux head
# PP-YOLOE and PP-YOLOE+ can also use the same YOLOv3 architecture in yolo.py when not using distillation or an aux head
@register
class PPYOLOE(BaseArch):
__category__ = 'architecture'
__shared__ = ['for_distill']
__inject__ = ['post_process']
def __init__(self,
......@@ -36,6 +37,8 @@ class PPYOLOE(BaseArch):
neck='CustomCSPPAN',
yolo_head='PPYOLOEHead',
post_process='BBoxPostProcess',
for_distill=False,
feat_distill_place='neck_feats',
for_mot=False):
"""
PPYOLOE network, see https://arxiv.org/abs/2203.16250
......@@ -54,6 +57,10 @@ class PPYOLOE(BaseArch):
self.yolo_head = yolo_head
self.post_process = post_process
self.for_mot = for_mot
self.for_distill = for_distill
self.feat_distill_place = feat_distill_place
if for_distill:
assert feat_distill_place in ['backbone_feats', 'neck_feats']
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -80,17 +87,31 @@ class PPYOLOE(BaseArch):
if self.training:
yolo_losses = self.yolo_head(neck_feats, self.inputs)
if self.for_distill:
if self.feat_distill_place == 'backbone_feats':
self.yolo_head.distill_pairs['backbone_feats'] = body_feats
elif self.feat_distill_place == 'neck_feats':
self.yolo_head.distill_pairs['neck_feats'] = neck_feats
else:
raise ValueError
return yolo_losses
else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None:
bbox, bbox_num = self.post_process(
bbox, bbox_num, before_nms_indexes = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else:
bbox, bbox_num = self.yolo_head.post_process(
bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}
# data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
return output
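# The eval/infer output now carries CAM bookkeeping alongside the boxes, roughly
#   {'bbox': Tensor, 'bbox_num': Tensor, 'cam_data': {'scores': ..., 'before_nms_indexes': ...}}
# where 'scores' are the raw head class scores and 'before_nms_indexes' are the
# pre-NMS indexes of the boxes kept by NMS / post-processing.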
......@@ -180,15 +201,21 @@ class PPYOLOEWithAuxHead(BaseArch):
aux_pred=[aux_cls_scores, aux_bbox_preds])
return loss
else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None:
bbox, bbox_num = self.post_process(
bbox, bbox_num, before_nms_indexes = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else:
bbox, bbox_num = self.yolo_head.post_process(
bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}
# data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
return output
......
......@@ -22,7 +22,7 @@ from ..post_process import JDEBBoxPostProcess
__all__ = ['YOLOv3']
# YOLOv3,PP-YOLO,PP-YOLOv2,PP-YOLOE,PP-YOLOE+ use the same architecture as YOLOv3
# PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py
# PP-YOLOE and PP-YOLOE+ are recommended to use the PPYOLOE architecture in ppyoloe.py, especially when using distillation or an aux head
@register
......
......@@ -221,4 +221,4 @@ class ATSSAssigner(nn.Layer):
paddle.zeros_like(gather_scores))
assigned_scores *= gather_scores.unsqueeze(-1)
return assigned_labels, assigned_bboxes, assigned_scores
return assigned_labels, assigned_bboxes, assigned_scores, mask_positive
......@@ -190,4 +190,4 @@ class TaskAlignedAssigner(nn.Layer):
alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
assigned_scores = assigned_scores * alignment_metrics
return assigned_labels, assigned_bboxes, assigned_scores
return assigned_labels, assigned_bboxes, assigned_scores, mask_positive
......@@ -651,7 +651,7 @@ class PicoHeadV2(GFLHead):
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner(
anchors,
num_anchors_list,
gt_labels,
......@@ -662,7 +662,7 @@ class PicoHeadV2(GFLHead):
pred_bboxes=pred_bboxes.detach() * stride_tensor_list)
else:
assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor_list,
centers,
......
......@@ -136,7 +136,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.static_assigner(
anchors,
num_anchors_list,
......@@ -148,7 +148,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
alpha_l = 0.25
else:
if self.sm_use:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
......@@ -159,7 +159,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
pad_gt_mask,
bg_index=self.num_classes)
else:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
......
......@@ -53,7 +53,7 @@ class ESEAttn(nn.Layer):
class PPYOLOEHead(nn.Layer):
__shared__ = [
'num_classes', 'eval_size', 'trt', 'exclude_nms',
'exclude_post_process', 'use_shared_conv'
'exclude_post_process', 'use_shared_conv', 'for_distill'
]
__inject__ = ['static_assigner', 'assigner', 'nms']
......@@ -81,7 +81,8 @@ class PPYOLOEHead(nn.Layer):
attn_conv='convbn',
exclude_nms=False,
exclude_post_process=False,
use_shared_conv=True):
use_shared_conv=True,
for_distill=False):
super(PPYOLOEHead, self).__init__()
assert len(in_channels) > 0, "len(in_channels) should > 0"
self.in_channels = in_channels
......@@ -110,6 +111,7 @@ class PPYOLOEHead(nn.Layer):
self.exclude_nms = exclude_nms
self.exclude_post_process = exclude_post_process
self.use_shared_conv = use_shared_conv
self.for_distill = for_distill
# stem
self.stem_cls = nn.LayerList()
......@@ -135,6 +137,9 @@ class PPYOLOEHead(nn.Layer):
self.proj_conv.skip_quant = True
self._init_weights()
if self.for_distill:
self.distill_pairs = {}
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
......@@ -321,6 +326,10 @@ class PPYOLOEHead(nn.Layer):
loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos,
self.reg_range[0]) * bbox_weight
loss_dfl = loss_dfl.sum() / assigned_scores_sum
if self.for_distill:
self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos
self.distill_pairs['pred_dist_pos'] = pred_dist_pos
self.distill_pairs['bbox_weight'] = bbox_weight
else:
loss_l1 = paddle.zeros([1])
loss_iou = paddle.zeros([1])
......@@ -343,7 +352,7 @@ class PPYOLOEHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.static_assigner(
anchors,
num_anchors_list,
......@@ -356,7 +365,7 @@ class PPYOLOEHead(nn.Layer):
else:
if self.sm_use:
# only used in smalldet of PPYOLOE-SOD model
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
......@@ -368,18 +377,28 @@ class PPYOLOEHead(nn.Layer):
bg_index=self.num_classes)
else:
if aux_pred is None:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
if not hasattr(self, "assigned_labels"):
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
self.assigned_labels = assigned_labels
self.assigned_bboxes = assigned_bboxes
self.assigned_scores = assigned_scores
self.mask_positive = mask_positive
else:
assigned_labels = self.assigned_labels
assigned_bboxes = self.assigned_bboxes
assigned_scores = self.assigned_scores
mask_positive = self.mask_positive
else:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.assigner(
pred_scores_aux.detach(),
pred_bboxes_aux.detach() * stride_tensor,
......@@ -395,12 +414,14 @@ class PPYOLOEHead(nn.Layer):
assign_out_dict = self.get_loss_from_assign(
pred_scores, pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
assigned_labels, assigned_bboxes, assigned_scores, mask_positive,
alpha_l)
if aux_pred is not None:
assign_out_dict_aux = self.get_loss_from_assign(
aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
assigned_labels, assigned_bboxes, assigned_scores,
mask_positive, alpha_l)
loss = {}
for key in assign_out_dict.keys():
loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
......@@ -411,7 +432,7 @@ class PPYOLOEHead(nn.Layer):
def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
anchor_points_s, assigned_labels, assigned_bboxes,
assigned_scores, alpha_l):
assigned_scores, mask_positive, alpha_l):
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels,
......@@ -428,6 +449,15 @@ class PPYOLOEHead(nn.Layer):
assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
loss_cls /= assigned_scores_sum
if self.for_distill:
self.distill_pairs['pred_cls_scores'] = pred_scores
self.distill_pairs['pos_num'] = assigned_scores_sum
self.distill_pairs['assigned_scores'] = assigned_scores
self.distill_pairs['mask_positive'] = mask_positive
one_hot_label = F.one_hot(assigned_labels,
self.num_classes + 1)[..., :-1]
self.distill_pairs['target_labels'] = one_hot_label
loss_l1, loss_iou, loss_dfl = \
self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores,
......@@ -450,7 +480,8 @@ class PPYOLOEHead(nn.Layer):
pred_bboxes *= stride_tensor
if self.exclude_post_process:
return paddle.concat(
[pred_bboxes, pred_scores.transpose([0, 2, 1])], axis=-1), None
[pred_bboxes, pred_scores.transpose([0, 2, 1])],
axis=-1), None, None
else:
# scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
......@@ -460,9 +491,10 @@ class PPYOLOEHead(nn.Layer):
pred_bboxes /= scale_factor
if self.exclude_nms:
# `exclude_nms=True` just use in benchmark
return pred_bboxes, pred_scores
return pred_bboxes, pred_scores, None
else:
bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes, pred_scores)
bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
pred_scores)
return bbox_pred, bbox_num, before_nms_indexes
......
......@@ -258,7 +258,7 @@ class PPYOLOERHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.static_assigner(
anchor_points,
stride_tensor,
......@@ -271,7 +271,7 @@ class PPYOLOERHead(nn.Layer):
pred_bboxes.detach()
)
else:
assigned_labels, assigned_bboxes, assigned_scores = \
assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach(),
......
......@@ -293,7 +293,7 @@ class TOODHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner(
anchors,
num_anchors_list,
gt_labels,
......@@ -302,7 +302,7 @@ class TOODHead(nn.Layer):
bg_index=self.num_classes)
alpha_l = 0.25
else:
assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
bbox_center(anchors),
......
......@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import distill_loss
from . import distill_model
from . import ofa
from . import prune
from . import quant
from . import distill
from . import unstructured_prune
from .distill_loss import *
from .distill_model import *
from .ofa import *
from .prune import *
from .quant import *
from .distill import *
from .unstructured_prune import *
from .ofa import *
import yaml
from ppdet.core.workspace import load_config
......@@ -45,7 +48,11 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
elif "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "CWD":
model = CWDDistillModel(cfg, slim_cfg)
elif "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "PPYOLOEDistill":
model = PPYOLOEDistillModel(cfg, slim_cfg)
else:
# common distillation model
model = DistillModel(cfg, slim_cfg)
cfg['model'] = model
cfg['slim_type'] = cfg.slim
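# The PPYOLOE+ slim configs above set slim: Distill, slim_method: PPYOLOEDistill and
# distill_loss: DistillPPYOLOELoss, which routes them to the PPYOLOEDistillModel branch here.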
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,302 +16,398 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, create, load_config
from ppdet.core.workspace import register, create
from ppdet.modeling import ops
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.modeling.losses.iou_loss import GIoULoss
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'DistillYOLOv3Loss',
'KnowledgeDistillationKLDivLoss',
'DistillPPYOLOELoss',
'FGDFeatureLoss',
'CWDFeatureLoss',
'PKDFeatureLoss',
'MGDFeatureLoss',
]
class DistillModel(nn.Layer):
def __init__(self, cfg, slim_cfg):
super(DistillModel, self).__init__()
self.student_model = create(cfg.architecture)
logger.debug('Load student model pretrain_weights:{}'.format(
cfg.pretrain_weights))
load_pretrain_weight(self.student_model, cfg.pretrain_weights)
def parameter_init(mode="kaiming", value=0.):
if mode == "kaiming":
weight_attr = paddle.nn.initializer.KaimingUniform()
elif mode == "constant":
weight_attr = paddle.nn.initializer.Constant(value=value)
else:
weight_attr = paddle.nn.initializer.KaimingUniform()
slim_cfg = load_config(slim_cfg)
weight_init = ParamAttr(initializer=weight_attr)
return weight_init
self.teacher_model = create(slim_cfg.architecture)
self.distill_loss = create(slim_cfg.distill_loss)
logger.debug('Load teacher model pretrain_weights:{}'.format(
slim_cfg.pretrain_weights))
load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)
for param in self.teacher_model.parameters():
param.trainable = False
def feature_norm(feat):
# Normalize the feature maps to have zero mean and unit variance.
assert len(feat.shape) == 4
N, C, H, W = feat.shape
feat = feat.transpose([1, 0, 2, 3]).reshape([C, -1])
mean = feat.mean(axis=-1, keepdim=True)
std = feat.std(axis=-1, keepdim=True)
feat = (feat - mean) / (std + 1e-6)
return feat.reshape([C, N, H, W]).transpose([1, 0, 2, 3])
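# Example (illustration only): for feat of shape [N, C, H, W], feature_norm(feat)
# returns a tensor of the same shape in which each of the C channels has ~zero mean
# and ~unit std over its N*H*W elements, e.g.
#   out = feature_norm(paddle.randn([2, 8, 16, 16]))
#   out.transpose([1, 0, 2, 3]).reshape([8, -1]).mean(-1)  # ~0 per channel
#   out.transpose([1, 0, 2, 3]).reshape([8, -1]).std(-1)   # ~1 per channel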
def parameters(self):
return self.student_model.parameters()
def forward(self, inputs):
if self.training:
teacher_loss = self.teacher_model(inputs)
student_loss = self.student_model(inputs)
loss = self.distill_loss(self.teacher_model, self.student_model)
student_loss['distill_loss'] = loss
student_loss['teacher_loss'] = teacher_loss['loss']
student_loss['loss'] += student_loss['distill_loss']
return student_loss
else:
return self.student_model(inputs)
@register
class DistillYOLOv3Loss(nn.Layer):
def __init__(self, weight=1000):
super(DistillYOLOv3Loss, self).__init__()
self.loss_weight = weight
def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
loss_w = paddle.abs(sw - tw)
loss_h = paddle.abs(sh - th)
loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
return weighted_loss
def obj_weighted_cls(self, scls, tcls, tobj):
loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
return weighted_loss
def obj_loss(self, sobj, tobj):
obj_mask = paddle.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
loss = paddle.mean(
ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
return loss
def forward(self, teacher_model, student_model):
teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
student_distill_pairs = student_model.yolo_head.loss.distill_pairs
distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
distill_reg_loss.append(
self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[
3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4]))
distill_cls_loss.append(
self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
distill_reg_loss = paddle.add_n(distill_reg_loss)
distill_cls_loss = paddle.add_n(distill_cls_loss)
distill_obj_loss = paddle.add_n(distill_obj_loss)
loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
) * self.loss_weight
return loss
@register
class KnowledgeDistillationKLDivLoss(nn.Layer):
"""Loss function for knowledge distilling using KL divergence.
class FGDDistillModel(nn.Layer):
"""
Build FGD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
loss_weight (float): Loss weight of current loss.
T (int): Temperature for distillation.
"""
def __init__(self, cfg, slim_cfg):
super(FGDDistillModel, self).__init__()
self.is_inherit = True
# build student model before load slim config
self.student_model = create(cfg.architecture)
self.arch = cfg.architecture
stu_pretrain = cfg['pretrain_weights']
slim_cfg = load_config(slim_cfg)
self.teacher_cfg = slim_cfg
self.loss_cfg = slim_cfg
tea_pretrain = cfg['pretrain_weights']
self.teacher_model = create(self.teacher_cfg.architecture)
self.teacher_model.eval()
for param in self.teacher_model.parameters():
param.trainable = False
if 'pretrain_weights' in cfg and stu_pretrain:
if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
load_pretrain_weight(self.student_model,
self.teacher_cfg.pretrain_weights)
logger.debug(
"Inheriting! loading teacher weights to student model!")
load_pretrain_weight(self.student_model, stu_pretrain)
if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
load_pretrain_weight(self.teacher_model,
self.teacher_cfg.pretrain_weights)
self.fgd_loss_dic = self.build_loss(
self.loss_cfg.distill_loss,
name_list=self.loss_cfg['distill_loss_name'])
def build_loss(self,
cfg,
name_list=[
'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
'neck_f_0'
]):
loss_func = dict()
for idx, k in enumerate(name_list):
loss_func[k] = create(cfg)
return loss_func
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
loss_dict = {}
for idx, k in enumerate(self.fgd_loss_dic):
loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
t_neck_feats[idx], inputs)
if self.arch == "RetinaNet":
loss = self.student_model.head(s_neck_feats, inputs)
elif self.arch == "PicoDet":
head_outs = self.student_model.head(
s_neck_feats, self.student_model.export_post_process)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
total_loss = paddle.add_n(list(loss_gfl.values()))
loss = {}
loss.update(loss_gfl)
loss.update({'loss': total_loss})
else:
raise ValueError(f"Unsupported model {self.arch}")
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
return loss
def __init__(self, reduction='mean', loss_weight=1.0, T=10):
super(KnowledgeDistillationKLDivLoss, self).__init__()
assert reduction in ('none', 'mean', 'sum')
assert T >= 1
self.reduction = reduction
self.loss_weight = loss_weight
self.T = T
def knowledge_distillation_kl_div_loss(self,
pred,
soft_label,
T,
detach_target=True):
r"""Loss function for knowledge distilling using KL divergence.
Args:
pred (Tensor): Predicted logits with shape (N, n + 1).
soft_label (Tensor): Target logits with shape (N, N + 1).
T (int): Temperature for distillation.
detach_target (bool): Remove soft_label from automatic differentiation
"""
assert pred.shape == soft_label.shape
target = F.softmax(soft_label / T, axis=1)
if detach_target:
target = target.detach()
kd_loss = F.kl_div(
F.log_softmax(
pred / T, axis=1), target, reduction='none').mean(1) * (T * T)
return kd_loss
def forward(self,
pred,
soft_label,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (Tensor): Predicted logits with shape (N, n + 1).
soft_label (Tensor): Target logits with shape (N, N + 1).
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (reduction_override
if reduction_override else self.reduction)
loss_kd_out = self.knowledge_distillation_kl_div_loss(
pred, soft_label, T=self.T)
if weight is not None:
loss_kd_out = weight * loss_kd_out
if avg_factor is None:
if reduction == 'none':
loss = loss_kd_out
elif reduction == 'mean':
loss = loss_kd_out.mean()
elif reduction == 'sum':
loss = loss_kd_out.sum()
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "PicoDet":
head_outs = self.student_model.head(
neck_feats, self.student_model.export_post_process)
scale_factor = inputs['scale_factor']
bboxes, bbox_num = self.student_model.head.post_process(
head_outs,
scale_factor,
export_nms=self.student_model.export_nms)
return {'bbox': bboxes, 'bbox_num': bbox_num}
else:
raise ValueError(f"Unsupported model {self.arch}")
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss_kd_out.sum() / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError(
'avg_factor can not be used with reduction="sum"')
loss_kd = self.loss_weight * loss
return loss_kd
class CWDDistillModel(nn.Layer):
"""
Build CWD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(CWDDistillModel, self).__init__()
self.is_inherit = False
# build student model before load slim config
self.student_model = create(cfg.architecture)
self.arch = cfg.architecture
if self.arch not in ['GFL', 'RetinaNet']:
raise ValueError(
f"The arch can only be one of ['GFL', 'RetinaNet'], but received {self.arch}"
)
stu_pretrain = cfg['pretrain_weights']
slim_cfg = load_config(slim_cfg)
self.teacher_cfg = slim_cfg
self.loss_cfg = slim_cfg
tea_pretrain = cfg['pretrain_weights']
self.teacher_model = create(self.teacher_cfg.architecture)
self.teacher_model.eval()
for param in self.teacher_model.parameters():
param.trainable = False
if 'pretrain_weights' in cfg and stu_pretrain:
if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
load_pretrain_weight(self.student_model,
self.teacher_cfg.pretrain_weights)
logger.debug(
"Inheriting! loading teacher weights to student model!")
load_pretrain_weight(self.student_model, stu_pretrain)
if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
load_pretrain_weight(self.teacher_model,
self.teacher_cfg.pretrain_weights)
self.loss_dic = self.build_loss(
self.loss_cfg.distill_loss,
name_list=self.loss_cfg['distill_loss_name'])
def build_loss(self,
cfg,
name_list=[
'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
'neck_f_0'
]):
loss_func = dict()
for idx, k in enumerate(name_list):
loss_func[k] = create(cfg)
return loss_func
def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
loss = self.student_model.head(stu_fea_list, inputs)
distill_loss = {}
# cwd kd loss
for idx, k in enumerate(self.loss_dic):
distill_loss[k] = self.loss_dic[k](stu_fea_list[idx],
tea_fea_list[idx])
loss['loss'] += distill_loss[k]
loss[k] = distill_loss[k]
@register
class DistillPPYOLOELoss(nn.Layer):
def __init__(
self,
loss_weight={'logits': 4.0,
'feat': 1.0},
logits_distill=True,
logits_loss_weight={'class': 1.0,
'iou': 2.5,
'dfl': 0.5},
feat_distill=True,
feat_distiller='fgd',
feat_distill_place='neck_feats',
teacher_width_mult=1.0, # L
student_width_mult=0.75, # M
feat_out_channels=[768, 384, 192]):
super(DistillPPYOLOELoss, self).__init__()
self.loss_weight_logits = loss_weight['logits']
self.loss_weight_feat = loss_weight['feat']
self.logits_distill = logits_distill
self.feat_distill = feat_distill
if logits_distill and self.loss_weight_logits > 0:
self.bbox_loss_weight = logits_loss_weight['iou']
self.dfl_loss_weight = logits_loss_weight['dfl']
self.qfl_loss_weight = logits_loss_weight['class']
self.loss_bbox = GIoULoss()
if feat_distill and self.loss_weight_feat > 0:
assert feat_distiller in ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
assert feat_distill_place in ['backbone_feats', 'neck_feats']
self.feat_distill_place = feat_distill_place
self.t_channel_list = [
int(c * teacher_width_mult) for c in feat_out_channels
]
self.s_channel_list = [
int(c * student_width_mult) for c in feat_out_channels
]
self.distill_feat_loss_modules = []
for i in range(len(feat_out_channels)):
if feat_distiller == 'cwd':
feat_loss_module = CWDFeatureLoss(
student_channels=self.s_channel_list[i],
teacher_channels=self.t_channel_list[i],
normalize=True)
elif feat_distiller == 'fgd':
feat_loss_module = FGDFeatureLoss(
student_channels=self.s_channel_list[i],
teacher_channels=self.t_channel_list[i],
normalize=True,
alpha_fgd=0.00001,
beta_fgd=0.000005,
gamma_fgd=0.00001,
lambda_fgd=0.00000005)
elif feat_distiller == 'pkd':
feat_loss_module = PKDFeatureLoss(
student_channels=self.s_channel_list[i],
teacher_channels=self.t_channel_list[i],
normalize=True,
resize_stu=True)
elif feat_distiller == 'mgd':
feat_loss_module = MGDFeatureLoss(
student_channels=self.s_channel_list[i],
teacher_channels=self.t_channel_list[i],
normalize=True,
loss_func='ssim')
elif feat_distiller == 'mimic':
feat_loss_module = MimicFeatureLoss(
student_channels=self.s_channel_list[i],
teacher_channels=self.t_channel_list[i],
normalize=True)
else:
raise ValueError
self.distill_feat_loss_modules.append(feat_loss_module)
def quality_focal_loss(self,
pred_logits,
soft_target_logits,
beta=2.0,
use_sigmoid=False,
num_total_pos=None):
if use_sigmoid:
func = F.binary_cross_entropy_with_logits
soft_target = F.sigmoid(soft_target_logits)
pred_sigmoid = F.sigmoid(pred_logits)
preds = pred_logits
else:
func = F.binary_cross_entropy
soft_target = soft_target_logits
pred_sigmoid = pred_logits
preds = pred_sigmoid
scale_factor = pred_sigmoid - soft_target
loss = func(
preds, soft_target, reduction='none') * scale_factor.abs().pow(beta)
loss = loss.sum(1)
if num_total_pos is not None:
loss = loss.sum() / num_total_pos
else:
loss = loss.mean()
return loss
def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
loss = {}
head_outs = self.student_model.head(stu_fea_list)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
loss.update(loss_gfl)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
# cwd kd loss
feat_loss = {}
loss_dict = {}
s_cls_feat, t_cls_feat = [], []
for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
s_cls_feat.append(cls_score)
t_cls_feat.append(t_cls_score)
for idx, k in enumerate(self.loss_dic):
loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx])
feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx],
tea_fea_list[idx])
for k in feat_loss:
loss['loss'] += feat_loss[k]
loss[k] = feat_loss[k]
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
def bbox_loss(self, s_bbox, t_bbox, weight_targets=None):
# [x,y,w,h]
if weight_targets is not None:
loss = paddle.sum(self.loss_bbox(s_bbox, t_bbox) * weight_targets)
avg_factor = weight_targets.sum()
loss = loss / avg_factor
else:
loss = paddle.mean(self.loss_bbox(s_bbox, t_bbox))
return loss
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
def distribution_focal_loss(self,
pred_corners,
target_corners,
weight_targets=None):
target_corners_label = F.softmax(target_corners, axis=-1)
loss_dfl = F.cross_entropy(
pred_corners,
target_corners_label,
soft_label=True,
reduction='none')
loss_dfl = loss_dfl.sum(1)
if weight_targets is not None:
loss_dfl = loss_dfl * (weight_targets.expand([-1, 4]).reshape([-1]))
loss_dfl = loss_dfl.sum(-1) / weight_targets.sum()
else:
loss_dfl = loss_dfl.mean(-1)
return loss_dfl / 4.0 # 4 direction
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
def forward(self, teacher_model, student_model):
teacher_distill_pairs = teacher_model.yolo_head.distill_pairs
student_distill_pairs = student_model.yolo_head.distill_pairs
if self.logits_distill and self.loss_weight_logits > 0:
distill_bbox_loss, distill_dfl_loss, distill_cls_loss = [], [], []
if self.arch == "RetinaNet":
loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
inputs)
elif self.arch == "GFL":
loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
else:
raise ValueError(f"unsupported arch {self.arch}")
return loss
distill_cls_loss.append(
self.quality_focal_loss(
student_distill_pairs['pred_cls_scores'].reshape(
(-1, student_distill_pairs['pred_cls_scores'].shape[-1]
)),
teacher_distill_pairs['pred_cls_scores'].detach().reshape(
(-1, teacher_distill_pairs['pred_cls_scores'].shape[-1]
)),
num_total_pos=student_distill_pairs['pos_num'],
use_sigmoid=False))
distill_bbox_loss.append(
self.bbox_loss(student_distill_pairs['pred_bboxes_pos'],
teacher_distill_pairs['pred_bboxes_pos'].detach(),
weight_targets=student_distill_pairs['bbox_weight']
) if 'pred_bboxes_pos' in student_distill_pairs and \
'pred_bboxes_pos' in teacher_distill_pairs and \
'bbox_weight' in student_distill_pairs
else paddle.zeros([1]))
distill_dfl_loss.append(
self.distribution_focal_loss(
student_distill_pairs['pred_dist_pos'].reshape((-1, student_distill_pairs['pred_dist_pos'].shape[-1])),
teacher_distill_pairs['pred_dist_pos'].detach().reshape((-1, teacher_distill_pairs['pred_dist_pos'].shape[-1])), \
weight_targets=student_distill_pairs['bbox_weight']
) if 'pred_dist_pos' in student_distill_pairs and \
'pred_dist_pos' in teacher_distill_pairs and \
'bbox_weight' in student_distill_pairs
else paddle.zeros([1]))
distill_cls_loss = paddle.add_n(distill_cls_loss)
distill_bbox_loss = paddle.add_n(distill_bbox_loss)
distill_dfl_loss = paddle.add_n(distill_dfl_loss)
logits_loss = distill_bbox_loss * self.bbox_loss_weight + distill_cls_loss * self.qfl_loss_weight + distill_dfl_loss * self.dfl_loss_weight
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "GFL":
bbox_pred, bbox_num = head_outs
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
else:
raise ValueError(f"unsupported arch {self.arch}")
logits_loss = paddle.zeros([1])
if self.feat_distill and self.loss_weight_feat > 0:
feat_loss_list = []
inputs = student_model.inputs
assert 'gt_bbox' in inputs
assert self.feat_distill_place in student_distill_pairs
assert self.feat_distill_place in teacher_distill_pairs
stu_feats = student_distill_pairs[self.feat_distill_place]
tea_feats = teacher_distill_pairs[self.feat_distill_place]
for i, loss_module in enumerate(self.distill_feat_loss_modules):
feat_loss_list.append(
loss_module(stu_feats[i], tea_feats[i], inputs))
feat_loss = paddle.add_n(feat_loss_list)
else:
feat_loss = paddle.zeros([1])
student_model.yolo_head.distill_pairs.clear()
teacher_model.yolo_head.distill_pairs.clear()
return logits_loss * self.loss_weight_logits, feat_loss * self.loss_weight_feat
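# Note: unlike DistillYOLOv3Loss above, this loss returns two already-weighted terms
# (logits distillation, feature distillation); the PPYOLOEDistillModel wrapper in
# distill_model.py (not shown in this diff) adds both to the student's detection loss.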
@register
class ChannelWiseDivergence(nn.Layer):
def __init__(self, student_channels, teacher_channels, tau=1.0, weight=1.0):
super(ChannelWiseDivergence, self).__init__()
class CWDFeatureLoss(nn.Layer):
def __init__(self,
student_channels,
teacher_channels,
normalize=False,
tau=1.0,
weight=1.0):
super(CWDFeatureLoss, self).__init__()
self.normalize = normalize
self.tau = tau
self.loss_weight = weight
......@@ -325,20 +421,23 @@ class ChannelWiseDivergence(nn.Layer):
else:
self.align = None
def distill_softmax(self, x, t):
def distill_softmax(self, x, tau):
_, _, w, h = paddle.shape(x)
x = paddle.reshape(x, [-1, w * h])
x /= t
x /= tau
return F.softmax(x, axis=1)
def forward(self, preds_s, preds_t):
assert preds_s.shape[-2:] == preds_t.shape[
-2:], 'the output dim of teacher and student differ'
N, C, W, H = preds_s.shape
def forward(self, preds_s, preds_t, inputs):
assert preds_s.shape[-2:] == preds_t.shape[-2:]
N, C, H, W = preds_s.shape
eps = 1e-5
if self.align is not None:
preds_s = self.align(preds_s)
if self.normalize:
preds_s = feature_norm(preds_s)
preds_t = feature_norm(preds_t)
softmax_pred_s = self.distill_softmax(preds_s, self.tau)
softmax_pred_t = self.distill_softmax(preds_t, self.tau)
......@@ -347,73 +446,16 @@ class ChannelWiseDivergence(nn.Layer):
return self.loss_weight * loss / (C * N)
@register
class DistillYOLOv3Loss(nn.Layer):
def __init__(self, weight=1000):
super(DistillYOLOv3Loss, self).__init__()
self.weight = weight
def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
loss_w = paddle.abs(sw - tw)
loss_h = paddle.abs(sh - th)
loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
return weighted_loss
def obj_weighted_cls(self, scls, tcls, tobj):
loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
return weighted_loss
def obj_loss(self, sobj, tobj):
obj_mask = paddle.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
loss = paddle.mean(
ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
return loss
def forward(self, teacher_model, student_model):
teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
student_distill_pairs = student_model.yolo_head.loss.distill_pairs
distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
distill_reg_loss.append(
self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[
3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4]))
distill_cls_loss.append(
self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
distill_reg_loss = paddle.add_n(distill_reg_loss)
distill_cls_loss = paddle.add_n(distill_cls_loss)
distill_obj_loss = paddle.add_n(distill_obj_loss)
loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
) * self.weight
return loss
def parameter_init(mode="kaiming", value=0.):
if mode == "kaiming":
weight_attr = paddle.nn.initializer.KaimingUniform()
elif mode == "constant":
weight_attr = paddle.nn.initializer.Constant(value=value)
else:
weight_attr = paddle.nn.initializer.KaimingUniform()
weight_init = ParamAttr(initializer=weight_attr)
return weight_init
@register
class FGDFeatureLoss(nn.Layer):
"""
Focal and Global Knowledge Distillation for Detectors
The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
Paddle version of `Focal and Global Knowledge Distillation for Detectors`
Args:
student_channels(int): The number of channels in the student's FPN feature map. Default to 256.
teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256.
student_channels (int): The number of channels in the student's FPN feature map. Default to 256.
teacher_channels (int): The number of channels in the teacher's FPN feature map. Default to 256.
normalize (bool): Whether to normalize the feature maps.
temp (float, optional): The temperature coefficient. Defaults to 0.5.
alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001
beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005
......@@ -422,20 +464,23 @@ class FGDFeatureLoss(nn.Layer):
"""
def __init__(self,
student_channels=256,
teacher_channels=256,
student_channels,
teacher_channels,
normalize=False,
loss_weight=1.0,
temp=0.5,
alpha_fgd=0.001,
beta_fgd=0.0005,
gamma_fgd=0.001,
lambda_fgd=0.000005):
super(FGDFeatureLoss, self).__init__()
self.normalize = normalize
self.loss_weight = loss_weight
self.temp = temp
self.alpha_fgd = alpha_fgd
self.beta_fgd = beta_fgd
self.gamma_fgd = gamma_fgd
self.lambda_fgd = lambda_fgd
kaiming_init = parameter_init("kaiming")
zeros_init = parameter_init("constant", 0.0)
......@@ -486,7 +531,6 @@ class FGDFeatureLoss(nn.Layer):
def spatial_channel_attention(self, x, t=0.5):
shape = paddle.shape(x)
N, C, H, W = shape
_f = paddle.abs(x)
spatial_map = paddle.reshape(
paddle.mean(
......@@ -515,7 +559,6 @@ class FGDFeatureLoss(nn.Layer):
context_mask = context_mask.unsqueeze(-1)
context = paddle.matmul(x_copy, context_mask)
context = paddle.reshape(context, [batch, channel, 1, 1])
return context
def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
......@@ -525,44 +568,35 @@ class FGDFeatureLoss(nn.Layer):
mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
stu_spatial_att, tea_spatial_att)
return mask_loss
def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg,
def feature_loss(self, stu_feature, tea_feature, mask_fg, mask_bg,
tea_channel_att, tea_spatial_att):
Mask_fg = Mask_fg.unsqueeze(axis=1)
Mask_bg = Mask_bg.unsqueeze(axis=1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
mask_fg = mask_fg.unsqueeze(axis=1)
mask_bg = mask_bg.unsqueeze(axis=1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1).unsqueeze(axis=-1)
tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)
fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))
fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_fg))
bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(mask_bg))
fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))
fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)
fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_fg))
bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(mask_bg))
fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(mask_fg)
bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(mask_bg)
return fg_loss, bg_loss
def relation_loss(self, stu_feature, tea_feature):
context_s = self.spatial_pool(stu_feature, "student")
context_t = self.spatial_pool(tea_feature, "teacher")
out_s = stu_feature + self.stu_conv_block(context_s)
out_t = tea_feature + self.tea_conv_block(context_t)
rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
return rela_loss
def mask_value(self, mask, xl, xr, yl, yr, value):
......@@ -570,21 +604,12 @@ class FGDFeatureLoss(nn.Layer):
return mask
def forward(self, stu_feature, tea_feature, inputs):
"""Forward function.
Args:
stu_feature(Tensor): Bs*C*H*W, student's feature map
tea_feature(Tensor): Bs*C*H*W, teacher's feature map
inputs: The inputs with gt bbox and input shape info.
"""
assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \
f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.'
assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs."
assert stu_feature.shape[-2:] == stu_feature.shape[-2:]
assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys()
gt_bboxes = inputs['gt_bbox']
ins_shape = [
inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
]
index_gt = []
for i in range(len(gt_bboxes)):
if gt_bboxes[i].size > 2:
......@@ -592,35 +617,41 @@ class FGDFeatureLoss(nn.Layer):
# only distill feature with labeled GTbox
if len(index_gt) != len(gt_bboxes):
index_gt_t = paddle.to_tensor(index_gt)
preds_S = paddle.index_select(preds_S, index_gt_t)
preds_T = paddle.index_select(preds_T, index_gt_t)
stu_feature = paddle.index_select(stu_feature, index_gt_t)
tea_feature = paddle.index_select(tea_feature, index_gt_t)
ins_shape = [ins_shape[c] for c in index_gt]
gt_bboxes = [gt_bboxes[c] for c in index_gt]
assert len(gt_bboxes) == preds_T.shape[
0], f"The number of selected GT box [{len(gt_bboxes)}] should be same with first dim of input tensor [{preds_T.shape[0]}]."
assert len(gt_bboxes) == tea_feature.shape[0]
if self.align is not None:
stu_feature = self.align(stu_feature)
N, C, H, W = stu_feature.shape
if self.normalize:
stu_feature = feature_norm(stu_feature)
tea_feature = feature_norm(tea_feature)
tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
tea_feature, self.temp)
stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
stu_feature, self.temp)
Mask_fg = paddle.zeros(tea_spatial_att.shape)
Mask_bg = paddle.ones_like(tea_spatial_att)
mask_fg = paddle.zeros(tea_spatial_att.shape)
mask_bg = paddle.ones_like(tea_spatial_att)
one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
Mask_fg.stop_gradient = True
Mask_bg.stop_gradient = True
mask_fg.stop_gradient = True
mask_bg.stop_gradient = True
one_tmp.stop_gradient = True
zero_tmp.stop_gradient = True
wmin, wmax, hmin, hmax, area = [], [], [], [], []
wmin, wmax, hmin, hmax = [], [], [], []
if gt_bboxes.shape[1] == 0:
loss = self.relation_loss(stu_feature, tea_feature)
return self.lambda_fgd * loss
N, _, H, W = stu_feature.shape
for i in range(N):
tmp_box = paddle.ones_like(gt_bboxes[i])
tmp_box.stop_gradient = True
......@@ -633,7 +664,6 @@ class FGDFeatureLoss(nn.Layer):
ones = paddle.ones_like(tmp_box[:, 2], dtype="int32")
zero.stop_gradient = True
ones.stop_gradient = True
wmin.append(
paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
......@@ -646,164 +676,217 @@ class FGDFeatureLoss(nn.Layer):
wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))
for j in range(len(gt_bboxes[i])):
Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
hmax[i][j] + 1, wmin[i][j],
wmax[i][j] + 1, area_recip[0][j])
if gt_bboxes[i][j].sum() > 0:
mask_fg[i] = self.mask_value(
mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j],
wmax[i][j] + 1, area_recip[0][j])
Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
mask_bg[i] = paddle.where(mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
if paddle.sum(Mask_bg[i]):
Mask_bg[i] /= paddle.sum(Mask_bg[i])
if paddle.sum(mask_bg[i]):
mask_bg[i] /= paddle.sum(mask_bg[i])
fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg,
Mask_bg, tea_channel_att,
fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, mask_fg,
mask_bg, tea_channel_att,
tea_spatial_att)
mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
stu_spatial_att, tea_spatial_att)
rela_loss = self.relation_loss(stu_feature, tea_feature)
loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
return loss
class LDDistillModel(nn.Layer):
def __init__(self, cfg, slim_cfg):
super(LDDistillModel, self).__init__()
self.student_model = create(cfg.architecture)
logger.debug('Load student model pretrain_weights:{}'.format(
cfg.pretrain_weights))
load_pretrain_weight(self.student_model, cfg.pretrain_weights)
slim_cfg = load_config(slim_cfg) #rewrite student cfg
self.teacher_model = create(slim_cfg.architecture)
logger.debug('Load teacher model pretrain_weights:{}'.format(
slim_cfg.pretrain_weights))
load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)
for param in self.teacher_model.parameters():
param.trainable = False
def parameters(self):
return self.student_model.parameters()
def forward(self, inputs):
if self.training:
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
t_head_outs = self.teacher_model.head(t_neck_feats)
#student_loss = self.student_model(inputs)
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
s_head_outs = self.student_model.head(s_neck_feats)
soft_label_list = t_head_outs[0]
soft_targets_list = t_head_outs[1]
student_loss = self.student_model.head.get_loss(
s_head_outs, inputs, soft_label_list, soft_targets_list)
total_loss = paddle.add_n(list(student_loss.values()))
student_loss['loss'] = total_loss
return student_loss
else:
return self.student_model(inputs)
@register
class KnowledgeDistillationKLDivLoss(nn.Layer):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
        T (int): Temperature for distillation.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
        super(KnowledgeDistillationKLDivLoss, self).__init__()
        assert reduction in ('none', 'mean', 'sum')
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def knowledge_distillation_kl_div_loss(self,
                                           pred,
                                           soft_label,
                                           T,
                                           detach_target=True):
        r"""Loss function for knowledge distilling using KL divergence.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, N + 1).
            T (int): Temperature for distillation.
            detach_target (bool): Remove soft_label from automatic differentiation

        Returns:
            torch.Tensor: Loss tensor with shape (N,).
        """
        assert pred.shape == soft_label.shape
        target = F.softmax(soft_label / T, axis=1)
        if detach_target:
            target = target.detach()

        kd_loss = F.kl_div(
            F.log_softmax(
                pred / T, axis=1), target, reduction='none').mean(1) * (T * T)
        return kd_loss

    def forward(self,
                pred,
                soft_label,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, N + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)

        loss_kd_out = self.knowledge_distillation_kl_div_loss(
            pred, soft_label, T=self.T)

        if weight is not None:
            loss_kd_out = weight * loss_kd_out

        if avg_factor is None:
            if reduction == 'none':
                loss = loss_kd_out
            elif reduction == 'mean':
                loss = loss_kd_out.mean()
            elif reduction == 'sum':
                loss = loss_kd_out.sum()
        else:
            # if reduction is mean, then average the loss by avg_factor
            if reduction == 'mean':
                loss = loss_kd_out.sum() / avg_factor
            # if reduction is 'none', then do nothing, otherwise raise an error
            elif reduction != 'none':
                raise ValueError(
                    'avg_factor can not be used with reduction="sum"')

        loss_kd = self.loss_weight * loss
        return loss_kd


@register
class PKDFeatureLoss(nn.Layer):
    """
    PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient.

    Args:
        loss_weight (float): Weight of loss. Defaults to 1.0.
        resize_stu (bool): If True, we'll down/up sample the features of the
            student model to the spatial size of those of the teacher model if
            their spatial sizes are different. And vice versa. Defaults to
            True.
    """

    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 normalize=True,
                 loss_weight=1.0,
                 resize_stu=True):
        super(PKDFeatureLoss, self).__init__()
        self.normalize = normalize
        self.loss_weight = loss_weight
        self.resize_stu = resize_stu

    def forward(self, stu_feature, tea_feature, inputs):
        size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:]
        if size_s[0] != size_t[0]:
            if self.resize_stu:
                stu_feature = F.interpolate(
                    stu_feature, size_t, mode='bilinear')
            else:
                tea_feature = F.interpolate(
                    tea_feature, size_s, mode='bilinear')
        assert stu_feature.shape == tea_feature.shape

        if self.normalize:
            stu_feature = feature_norm(stu_feature)
            tea_feature = feature_norm(tea_feature)

        loss = F.mse_loss(stu_feature, tea_feature) / 2
        return loss * self.loss_weight


@register
class MimicFeatureLoss(nn.Layer):
    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 normalize=True,
                 loss_weight=1.0):
        super(MimicFeatureLoss, self).__init__()
        self.normalize = normalize
        self.loss_weight = loss_weight
        self.mse_loss = nn.MSELoss()

        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

    def forward(self, stu_feature, tea_feature, inputs):
        if self.align is not None:
            stu_feature = self.align(stu_feature)

        if self.normalize:
            stu_feature = feature_norm(stu_feature)
            tea_feature = feature_norm(tea_feature)

        loss = self.mse_loss(stu_feature, tea_feature)
        return loss * self.loss_weight


@register
class MGDFeatureLoss(nn.Layer):
    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 normalize=True,
                 loss_weight=1.0,
                 loss_func='mse'):
        super(MGDFeatureLoss, self).__init__()
        self.normalize = normalize
        self.loss_weight = loss_weight
        assert loss_func in ['mse', 'ssim']
        self.loss_func = loss_func
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.ssim_loss = SSIM(11)

        kaiming_init = parameter_init("kaiming")
        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=kaiming_init,
                bias_attr=False)
        else:
            self.align = None

        self.generation = nn.Sequential(
            nn.Conv2D(
                teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels, teacher_channels, kernel_size=3, padding=1))
def forward(self, stu_feature, tea_feature, inputs):
N = stu_feature.shape[0]
if self.align is not None:
stu_feature = self.align(stu_feature)
stu_feature = self.generation(stu_feature)
if self.normalize:
stu_feature = feature_norm(stu_feature)
tea_feature = feature_norm(tea_feature)
if self.loss_func == 'mse':
loss = self.mse_loss(stu_feature, tea_feature) / N
elif self.loss_func == 'ssim':
ssim_loss = self.ssim_loss(stu_feature, tea_feature)
loss = paddle.clip((1 - ssim_loss) / 2, 0, 1)
else:
raise ValueError
return loss * self.loss_weight
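A hedged usage sketch of the feature-mimicking loss above on random neck features. The import path and the toy shapes are assumptions, and `normalize=False` is passed so the example does not depend on the module-level `feature_norm` helper:

```python
import paddle
# assumed import path for the loss classes defined in this file
from ppdet.slim.distill_loss import MimicFeatureLoss

stu_feat = paddle.rand([2, 128, 32, 32])  # student neck feature, 128 channels
tea_feat = paddle.rand([2, 256, 32, 32])  # teacher neck feature, 256 channels

# a 1x1 align conv is created automatically because the channel counts differ
loss_fn = MimicFeatureLoss(
    student_channels=128, teacher_channels=256, normalize=False)
loss = loss_fn(stu_feat, tea_feat, inputs=None)  # 'inputs' is unused by this loss
print(float(loss))
```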
class SSIM(nn.Layer):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = self.create_window(window_size, self.channel)
def gaussian(self, window_size, sigma):
gauss = paddle.to_tensor([
math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
for x in range(window_size)
])
return gauss / gauss.sum()
def create_window(self, window_size, channel):
_1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
window = _2D_window.expand([channel, 1, window_size, window_size])
return window
def _ssim(self, img1, img2, window, window_size, channel,
size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(
img1 * img1, window, padding=window_size // 2,
groups=channel) - mu1_sq
sigma2_sq = F.conv2d(
img2 * img2, window, padding=window_size // 2,
groups=channel) - mu2_sq
sigma12 = F.conv2d(
img1 * img2, window, padding=window_size // 2,
groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
1e-12 + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean([1, 2, 3])
def forward(self, img1, img2):
channel = img1.shape[1]
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = self.create_window(self.window_size, channel)
self.window = window
self.channel = channel
return self._ssim(img1, img2, window, self.window_size, channel,
self.size_average)
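A small sketch (illustrative only) of how the `SSIM` module above behaves: two nearly identical feature maps score close to 1, which is why `MGDFeatureLoss` turns it into a loss via `(1 - ssim) / 2`:

```python
import paddle

ssim = SSIM(window_size=11)
a = paddle.rand([1, 64, 32, 32])
b = a + 0.01 * paddle.randn([1, 64, 32, 32])  # a slightly perturbed copy

print(float(ssim(a, a)))  # identical maps -> 1.0
print(float(ssim(a, b)))  # near-identical maps -> close to 1.0
```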
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, create, load_config
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'DistillModel',
'FGDDistillModel',
'CWDDistillModel',
'LDDistillModel',
'PPYOLOEDistillModel',
]
@register
class DistillModel(nn.Layer):
"""
Build common distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(DistillModel, self).__init__()
self.arch = cfg.architecture
self.stu_cfg = cfg
self.student_model = create(self.stu_cfg.architecture)
if 'pretrain_weights' in self.stu_cfg and self.stu_cfg.pretrain_weights:
stu_pretrain = self.stu_cfg.pretrain_weights
else:
stu_pretrain = None
slim_cfg = load_config(slim_cfg)
self.tea_cfg = slim_cfg
self.teacher_model = create(self.tea_cfg.architecture)
if 'pretrain_weights' in self.tea_cfg and self.tea_cfg.pretrain_weights:
tea_pretrain = self.tea_cfg.pretrain_weights
else:
tea_pretrain = None
self.distill_cfg = slim_cfg
# load pretrain weights
self.is_inherit = False
if stu_pretrain:
if self.is_inherit and tea_pretrain:
load_pretrain_weight(self.student_model, tea_pretrain)
logger.debug(
"Inheriting! loading teacher weights to student model!")
load_pretrain_weight(self.student_model, stu_pretrain)
logger.info("Student model has loaded pretrain weights!")
if tea_pretrain:
load_pretrain_weight(self.teacher_model, tea_pretrain)
logger.info("Teacher model has loaded pretrain weights!")
self.teacher_model.eval()
for param in self.teacher_model.parameters():
param.trainable = False
self.distill_loss = self.build_loss(self.distill_cfg)
def build_loss(self, distill_cfg):
if 'distill_loss' in distill_cfg and distill_cfg.distill_loss:
return create(distill_cfg.distill_loss)
else:
return None
def parameters(self):
return self.student_model.parameters()
def forward(self, inputs):
if self.training:
student_loss = self.student_model(inputs)
with paddle.no_grad():
teacher_loss = self.teacher_model(inputs)
loss = self.distill_loss(self.teacher_model, self.student_model)
student_loss['distill_loss'] = loss
student_loss['teacher_loss'] = teacher_loss['loss']
student_loss['loss'] += student_loss['distill_loss']
return student_loss
else:
return self.student_model(inputs)
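A hedged construction sketch for the wrapper above (the YAML paths below are placeholders, not files from this commit): `cfg` is the loaded student config object, while `slim_cfg` is the path of the teacher/distill config, which `DistillModel` loads itself via `load_config`:

```python
from ppdet.core.workspace import load_config

stu_cfg = load_config('path/to/student_config.yml')            # placeholder path
model = DistillModel(cfg=stu_cfg,
                     slim_cfg='path/to/slim_distill_config.yml')  # placeholder path
model.train()
# losses = model(batch)  # student loss dict, extended with 'distill_loss' and 'teacher_loss'
```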
@register
class FGDDistillModel(DistillModel):
"""
Build FGD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(FGDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['RetinaNet', 'PicoDet'
], 'Unsupported arch: {}'.format(self.arch)
self.is_inherit = True
def build_loss(self, distill_cfg):
assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
loss_func = dict()
name_list = distill_cfg.distill_loss_name
for name in name_list:
loss_func[name] = create(distill_cfg.distill_loss)
return loss_func
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
loss_dict = {}
for idx, k in enumerate(self.distill_loss):
loss_dict[k] = self.distill_loss[k](s_neck_feats[idx],
t_neck_feats[idx], inputs)
if self.arch == "RetinaNet":
loss = self.student_model.head(s_neck_feats, inputs)
elif self.arch == "PicoDet":
head_outs = self.student_model.head(
s_neck_feats, self.student_model.export_post_process)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
total_loss = paddle.add_n(list(loss_gfl.values()))
loss = {}
loss.update(loss_gfl)
loss.update({'loss': total_loss})
else:
raise ValueError(f"Unsupported model {self.arch}")
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
return loss
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "PicoDet":
head_outs = self.student_model.head(
neck_feats, self.student_model.export_post_process)
scale_factor = inputs['scale_factor']
bboxes, bbox_num = self.student_model.head.post_process(
head_outs,
scale_factor,
export_nms=self.student_model.export_nms)
return {'bbox': bboxes, 'bbox_num': bbox_num}
else:
raise ValueError(f"Unsupported model {self.arch}")
@register
class CWDDistillModel(DistillModel):
"""
Build CWD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(CWDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['GFL', 'RetinaNet'], 'Unsupported arch: {}'.format(
self.arch)
def build_loss(self, distill_cfg):
assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
loss_func = dict()
name_list = distill_cfg.distill_loss_name
for name in name_list:
loss_func[name] = create(distill_cfg.distill_loss)
return loss_func
def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
loss = self.student_model.head(stu_fea_list, inputs)
distill_loss = {}
        for idx, k in enumerate(self.distill_loss):
            distill_loss[k] = self.distill_loss[k](stu_fea_list[idx],
                                                   tea_fea_list[idx])
loss['loss'] += distill_loss[k]
loss[k] = distill_loss[k]
return loss
def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
loss = {}
head_outs = self.student_model.head(stu_fea_list)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
loss.update(loss_gfl)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
feat_loss = {}
loss_dict = {}
s_cls_feat, t_cls_feat = [], []
for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
s_cls_feat.append(cls_score)
t_cls_feat.append(t_cls_score)
        for idx, k in enumerate(self.distill_loss):
            loss_dict[k] = self.distill_loss[k](s_cls_feat[idx],
                                                t_cls_feat[idx])
            feat_loss[f"neck_f_{idx}"] = self.distill_loss[k](stu_fea_list[idx],
                                                              tea_fea_list[idx])
for k in feat_loss:
loss['loss'] += feat_loss[k]
loss[k] = feat_loss[k]
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
return loss
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
if self.arch == "RetinaNet":
loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
inputs)
elif self.arch == "GFL":
loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
else:
raise ValueError(f"unsupported arch {self.arch}")
return loss
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "GFL":
bbox_pred, bbox_num = head_outs
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
else:
raise ValueError(f"unsupported arch {self.arch}")
@register
class LDDistillModel(DistillModel):
"""
Build LD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(LDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['GFL'], 'Unsupported arch: {}'.format(self.arch)
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
s_head_outs = self.student_model.head(s_neck_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
t_head_outs = self.teacher_model.head(t_neck_feats)
soft_label_list = t_head_outs[0]
soft_targets_list = t_head_outs[1]
student_loss = self.student_model.head.get_loss(
s_head_outs, inputs, soft_label_list, soft_targets_list)
total_loss = paddle.add_n(list(student_loss.values()))
student_loss['loss'] = total_loss
return student_loss
else:
return self.student_model(inputs)
@register
class PPYOLOEDistillModel(DistillModel):
"""
Build PPYOLOE distill model, only used in PPYOLOE
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(PPYOLOEDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['PPYOLOE'], 'Unsupported arch: {}'.format(
self.arch)
def forward(self, inputs, alpha=0.125):
if self.training:
if hasattr(self.teacher_model.yolo_head, "assigned_labels"):
                self.student_model.yolo_head.assigned_labels = self.teacher_model.yolo_head.assigned_labels
                self.student_model.yolo_head.assigned_bboxes = self.teacher_model.yolo_head.assigned_bboxes
                self.student_model.yolo_head.assigned_scores = self.teacher_model.yolo_head.assigned_scores
                self.student_model.yolo_head.mask_positive = self.teacher_model.yolo_head.mask_positive
delattr(self.teacher_model.yolo_head, "assigned_labels")
delattr(self.teacher_model.yolo_head, "assigned_bboxes")
delattr(self.teacher_model.yolo_head, "assigned_scores")
delattr(self.teacher_model.yolo_head, "mask_positive")
student_loss = self.student_model(inputs)
with paddle.no_grad():
teacher_loss = self.teacher_model(inputs)
logits_loss, feat_loss = self.distill_loss(self.teacher_model,
self.student_model)
det_total_loss = student_loss['loss']
total_loss = alpha * (det_total_loss + logits_loss + feat_loss)
student_loss['loss'] = total_loss
student_loss['det_loss'] = det_total_loss
student_loss['logits_loss'] = logits_loss
student_loss['feat_loss'] = feat_loss
return student_loss
else:
return self.student_model(inputs)
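An illustrative breakdown (made-up values) of the total loss computed in `PPYOLOEDistillModel.forward`: the detection loss, the logits-distill loss and the feature-distill loss are summed and scaled by `alpha` (0.125 by default):

```python
import paddle

det_total_loss = paddle.to_tensor(2.4)  # student detection loss (made up)
logits_loss = paddle.to_tensor(1.1)     # soft-label distill loss (made up)
feat_loss = paddle.to_tensor(0.6)       # feature distill loss (made up)

alpha = 0.125
total_loss = alpha * (det_total_loss + logits_loss + feat_loss)
print(float(total_loss))  # 0.5125
```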