未验证 提交 b2e0acf1 编写于 作者: F Feng Ni 提交者: GitHub

Add PPYOLOE distillation (#7671)

* add ppyoloe distill with soft loss and feature loss

* refine distill codes

* fix configs and docs

* clean codes

* merge cam, fix export
上级 2b5fd266
# PPYOLOE+ Distillation(PPYOLOE+ 蒸馏)
PaddleDetection提供了对PPYOLOE+ 进行模型蒸馏的方案,结合了logits蒸馏和feature蒸馏。
## 模型库
## 快速开始
### 训练
```shell
# 单卡
python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
# 多卡
python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
```
- `-c`: 指定模型配置文件,也是student配置文件。
- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。
### 评估
```shell
python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams
```
- `-c`: 指定模型配置文件,也是student配置文件。
- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。
- `-o weights`: 指定压缩算法训好的模型路径。
### 测试
```shell
python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg
```
- `-c`: 指定模型配置文件。
- `--slim_config`: 指定压缩策略配置文件。
- `-o weights`: 指定压缩算法训好的模型路径。
- `--infer_img`: 指定测试图像路径。
_BASE_: [
'../ppyoloe_plus_crn_l_80e_coco.yml',
]
for_distill: True
architecture: PPYOLOE
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_l_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
_BASE_: [
'../ppyoloe_plus_crn_m_80e_coco.yml',
]
for_distill: True
architecture: PPYOLOE
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_m_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_m_obj365_pretrained.pdparams
depth_mult: 0.67
width_mult: 0.75
_BASE_: [
'../ppyoloe_plus_crn_s_80e_coco.yml',
]
for_distill: True
architecture: PPYOLOE
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_s_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
...@@ -38,6 +38,42 @@ CWD全称为[Channel-wise Knowledge Distillation for Dense Prediction*](https:// ...@@ -38,6 +38,42 @@ CWD全称为[Channel-wise Knowledge Distillation for Dense Prediction*](https://
|gfl_r50_fpn_1x| student | 41.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | |gfl_r50_fpn_1x| student | 41.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) |
|gfl_r50_fpn_2x + CWD| student | 44.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) | |gfl_r50_fpn_2x + CWD| student | 44.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) |
## PPYOLOE+模型蒸馏
## 快速开始
### 训练
```shell
# 单卡
python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
# 多卡
python3.7 -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_to_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_to_l.yml
```
- `-c`: 指定模型配置文件,也是student配置文件。
- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。
### 评估
```shell
python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams
```
- `-c`: 指定模型配置文件,也是student配置文件。
- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。
- `-o weights`: 指定压缩算法训好的模型路径。
### 测试
```shell
python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg
```
- `-c`: 指定模型配置文件。
- `--slim_config`: 指定压缩策略配置文件。
- `-o weights`: 指定压缩算法训好的模型路径。
- `--infer_img`: 指定测试图像路径。
## Citations ## Citations
``` ```
......
...@@ -6,10 +6,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_ ...@@ -6,10 +6,10 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_
slim: Distill slim: Distill
slim_method: CWD slim_method: CWD
distill_loss: ChannelWiseDivergence distill_loss: CWDFeatureLoss
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0'] distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']
ChannelWiseDivergence: CWDFeatureLoss:
student_channels: 80 student_channels: 80
teacher_channels: 80 teacher_channels: 80
tau: 1.0 tau: 1.0
......
# teacher and slim config
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
depth_mult: 1.0
width_mult: 1.0
architecture: PPYOLOE
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
for_distill: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams
find_unused_parameters: True
for_distill: True
slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss
DistillPPYOLOELoss: # L -> M
loss_weight: {'logits': 4.0, 'feat': 1.0}
logits_distill: True
logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
feat_distill: True
feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
feat_distill_place: 'neck_feats'
teacher_width_mult: 1.0 # L
student_width_mult: 0.75 # M
feat_out_channels: [768, 384, 192] # The actual channel will multiply width_mult
# teacher and slim config
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
depth_mult: 0.67
width_mult: 0.75
architecture: PPYOLOE
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
for_distill: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_m_80e_coco.pdparams
find_unused_parameters: True
for_distill: True
slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss
DistillPPYOLOELoss: # M -> S
loss_weight: {'logits': 4.0, 'feat': 1.0}
logits_distill: True
logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
feat_distill: True
feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
feat_distill_place: 'neck_feats'
teacher_width_mult: 0.75 # M
student_width_mult: 0.5 # S
feat_out_channels: [768, 384, 192] # The actual channel will multiply width_mult
# teacher and slim config
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml',
]
depth_mult: 1.33
width_mult: 1.25
architecture: PPYOLOE
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
for_distill: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams
find_unused_parameters: True
for_distill: True
slim: Distill
slim_method: PPYOLOEDistill
distill_loss: DistillPPYOLOELoss
DistillPPYOLOELoss: # X -> L
loss_weight: {'logits': 4.0, 'feat': 1.0}
logits_distill: True
logits_loss_weight: {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
feat_distill: True
feat_distiller: 'fgd' # ['cwd', 'fgd', 'pkd', 'mgd', 'mimic']
feat_distill_place: 'neck_feats'
teacher_width_mult: 1.25 # X
student_width_mult: 1.0 # L
feat_out_channels: [768, 384, 192] # The actual channel will multiply width_mult
...@@ -7,12 +7,11 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_c ...@@ -7,12 +7,11 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_c
slim: Distill slim: Distill
slim_method: CWD slim_method: CWD
distill_loss: ChannelWiseDivergence distill_loss: CWDFeatureLoss
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0'] distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']
ChannelWiseDivergence: CWDFeatureLoss:
student_channels: 80 student_channels: 80
teacher_channels: 80 teacher_channels: 80
name: cwdloss
tau: 1.0 tau: 1.0
weight: 5.0 weight: 5.0
...@@ -22,13 +22,14 @@ from ppdet.core.workspace import register, create ...@@ -22,13 +22,14 @@ from ppdet.core.workspace import register, create
from .meta_arch import BaseArch from .meta_arch import BaseArch
__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead'] __all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture # PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially when use distillation or aux head
# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py # PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py when not use distillation or aux head
@register @register
class PPYOLOE(BaseArch): class PPYOLOE(BaseArch):
__category__ = 'architecture' __category__ = 'architecture'
__shared__ = ['for_distill']
__inject__ = ['post_process'] __inject__ = ['post_process']
def __init__(self, def __init__(self,
...@@ -36,6 +37,8 @@ class PPYOLOE(BaseArch): ...@@ -36,6 +37,8 @@ class PPYOLOE(BaseArch):
neck='CustomCSPPAN', neck='CustomCSPPAN',
yolo_head='PPYOLOEHead', yolo_head='PPYOLOEHead',
post_process='BBoxPostProcess', post_process='BBoxPostProcess',
for_distill=False,
feat_distill_place='neck_feats',
for_mot=False): for_mot=False):
""" """
PPYOLOE network, see https://arxiv.org/abs/2203.16250 PPYOLOE network, see https://arxiv.org/abs/2203.16250
...@@ -54,6 +57,10 @@ class PPYOLOE(BaseArch): ...@@ -54,6 +57,10 @@ class PPYOLOE(BaseArch):
self.yolo_head = yolo_head self.yolo_head = yolo_head
self.post_process = post_process self.post_process = post_process
self.for_mot = for_mot self.for_mot = for_mot
self.for_distill = for_distill
self.feat_distill_place = feat_distill_place
if for_distill:
assert feat_distill_place in ['backbone_feats', 'neck_feats']
@classmethod @classmethod
def from_config(cls, cfg, *args, **kwargs): def from_config(cls, cfg, *args, **kwargs):
...@@ -80,17 +87,31 @@ class PPYOLOE(BaseArch): ...@@ -80,17 +87,31 @@ class PPYOLOE(BaseArch):
if self.training: if self.training:
yolo_losses = self.yolo_head(neck_feats, self.inputs) yolo_losses = self.yolo_head(neck_feats, self.inputs)
if self.for_distill:
if self.feat_distill_place == 'backbone_feats':
self.yolo_head.distill_pairs['backbone_feats'] = body_feats
elif self.feat_distill_place == 'neck_feats':
self.yolo_head.distill_pairs['neck_feats'] = neck_feats
else:
raise ValueError
return yolo_losses return yolo_losses
else: else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats) yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None: if self.post_process is not None:
bbox, bbox_num = self.post_process( bbox, bbox_num, before_nms_indexes = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors, yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor']) self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else: else:
bbox, bbox_num = self.yolo_head.post_process( bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor']) yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num} # data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
return output return output
...@@ -180,15 +201,21 @@ class PPYOLOEWithAuxHead(BaseArch): ...@@ -180,15 +201,21 @@ class PPYOLOEWithAuxHead(BaseArch):
aux_pred=[aux_cls_scores, aux_bbox_preds]) aux_pred=[aux_cls_scores, aux_bbox_preds])
return loss return loss
else: else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats) yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None: if self.post_process is not None:
bbox, bbox_num = self.post_process( bbox, bbox_num, before_nms_indexes = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors, yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor']) self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else: else:
bbox, bbox_num = self.yolo_head.post_process( bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor']) yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num} # data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
return output return output
......
...@@ -22,7 +22,7 @@ from ..post_process import JDEBBoxPostProcess ...@@ -22,7 +22,7 @@ from ..post_process import JDEBBoxPostProcess
__all__ = ['YOLOv3'] __all__ = ['YOLOv3']
# YOLOv3,PP-YOLO,PP-YOLOv2,PP-YOLOE,PP-YOLOE+ use the same architecture as YOLOv3 # YOLOv3,PP-YOLO,PP-YOLOv2,PP-YOLOE,PP-YOLOE+ use the same architecture as YOLOv3
# PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py # PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py, especially when use distillation or aux head
@register @register
......
...@@ -221,4 +221,4 @@ class ATSSAssigner(nn.Layer): ...@@ -221,4 +221,4 @@ class ATSSAssigner(nn.Layer):
paddle.zeros_like(gather_scores)) paddle.zeros_like(gather_scores))
assigned_scores *= gather_scores.unsqueeze(-1) assigned_scores *= gather_scores.unsqueeze(-1)
return assigned_labels, assigned_bboxes, assigned_scores return assigned_labels, assigned_bboxes, assigned_scores, mask_positive
...@@ -190,4 +190,4 @@ class TaskAlignedAssigner(nn.Layer): ...@@ -190,4 +190,4 @@ class TaskAlignedAssigner(nn.Layer):
alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1) alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
assigned_scores = assigned_scores * alignment_metrics assigned_scores = assigned_scores * alignment_metrics
return assigned_labels, assigned_bboxes, assigned_scores return assigned_labels, assigned_bboxes, assigned_scores, mask_positive
...@@ -651,7 +651,7 @@ class PicoHeadV2(GFLHead): ...@@ -651,7 +651,7 @@ class PicoHeadV2(GFLHead):
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner( assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
gt_labels, gt_labels,
...@@ -662,7 +662,7 @@ class PicoHeadV2(GFLHead): ...@@ -662,7 +662,7 @@ class PicoHeadV2(GFLHead):
pred_bboxes=pred_bboxes.detach() * stride_tensor_list) pred_bboxes=pred_bboxes.detach() * stride_tensor_list)
else: else:
assigned_labels, assigned_bboxes, assigned_scores = self.assigner( assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor_list, pred_bboxes.detach() * stride_tensor_list,
centers, centers,
......
...@@ -136,7 +136,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -136,7 +136,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
pad_gt_mask = gt_meta['pad_gt_mask'] pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.static_assigner( self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
...@@ -148,7 +148,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -148,7 +148,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
alpha_l = 0.25 alpha_l = 0.25
else: else:
if self.sm_use: if self.sm_use:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
...@@ -159,7 +159,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -159,7 +159,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
pad_gt_mask, pad_gt_mask,
bg_index=self.num_classes) bg_index=self.num_classes)
else: else:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
......
...@@ -53,7 +53,7 @@ class ESEAttn(nn.Layer): ...@@ -53,7 +53,7 @@ class ESEAttn(nn.Layer):
class PPYOLOEHead(nn.Layer): class PPYOLOEHead(nn.Layer):
__shared__ = [ __shared__ = [
'num_classes', 'eval_size', 'trt', 'exclude_nms', 'num_classes', 'eval_size', 'trt', 'exclude_nms',
'exclude_post_process', 'use_shared_conv' 'exclude_post_process', 'use_shared_conv', 'for_distill'
] ]
__inject__ = ['static_assigner', 'assigner', 'nms'] __inject__ = ['static_assigner', 'assigner', 'nms']
...@@ -81,7 +81,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -81,7 +81,8 @@ class PPYOLOEHead(nn.Layer):
attn_conv='convbn', attn_conv='convbn',
exclude_nms=False, exclude_nms=False,
exclude_post_process=False, exclude_post_process=False,
use_shared_conv=True): use_shared_conv=True,
for_distill=False):
super(PPYOLOEHead, self).__init__() super(PPYOLOEHead, self).__init__()
assert len(in_channels) > 0, "len(in_channels) should > 0" assert len(in_channels) > 0, "len(in_channels) should > 0"
self.in_channels = in_channels self.in_channels = in_channels
...@@ -110,6 +111,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -110,6 +111,7 @@ class PPYOLOEHead(nn.Layer):
self.exclude_nms = exclude_nms self.exclude_nms = exclude_nms
self.exclude_post_process = exclude_post_process self.exclude_post_process = exclude_post_process
self.use_shared_conv = use_shared_conv self.use_shared_conv = use_shared_conv
self.for_distill = for_distill
# stem # stem
self.stem_cls = nn.LayerList() self.stem_cls = nn.LayerList()
...@@ -135,6 +137,9 @@ class PPYOLOEHead(nn.Layer): ...@@ -135,6 +137,9 @@ class PPYOLOEHead(nn.Layer):
self.proj_conv.skip_quant = True self.proj_conv.skip_quant = True
self._init_weights() self._init_weights()
if self.for_distill:
self.distill_pairs = {}
@classmethod @classmethod
def from_config(cls, cfg, input_shape): def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], } return {'in_channels': [i.channels for i in input_shape], }
...@@ -321,6 +326,10 @@ class PPYOLOEHead(nn.Layer): ...@@ -321,6 +326,10 @@ class PPYOLOEHead(nn.Layer):
loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos, loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos,
self.reg_range[0]) * bbox_weight self.reg_range[0]) * bbox_weight
loss_dfl = loss_dfl.sum() / assigned_scores_sum loss_dfl = loss_dfl.sum() / assigned_scores_sum
if self.for_distill:
self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos
self.distill_pairs['pred_dist_pos'] = pred_dist_pos
self.distill_pairs['bbox_weight'] = bbox_weight
else: else:
loss_l1 = paddle.zeros([1]) loss_l1 = paddle.zeros([1])
loss_iou = paddle.zeros([1]) loss_iou = paddle.zeros([1])
...@@ -343,7 +352,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -343,7 +352,7 @@ class PPYOLOEHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask'] pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.static_assigner( self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
...@@ -356,7 +365,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -356,7 +365,7 @@ class PPYOLOEHead(nn.Layer):
else: else:
if self.sm_use: if self.sm_use:
# only used in smalldet of PPYOLOE-SOD model # only used in smalldet of PPYOLOE-SOD model
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
...@@ -368,18 +377,28 @@ class PPYOLOEHead(nn.Layer): ...@@ -368,18 +377,28 @@ class PPYOLOEHead(nn.Layer):
bg_index=self.num_classes) bg_index=self.num_classes)
else: else:
if aux_pred is None: if aux_pred is None:
assigned_labels, assigned_bboxes, assigned_scores = \ if not hasattr(self, "assigned_labels"):
self.assigner( assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
pred_scores.detach(), self.assigner(
pred_bboxes.detach() * stride_tensor, pred_scores.detach(),
anchor_points, pred_bboxes.detach() * stride_tensor,
num_anchors_list, anchor_points,
gt_labels, num_anchors_list,
gt_bboxes, gt_labels,
pad_gt_mask, gt_bboxes,
bg_index=self.num_classes) pad_gt_mask,
bg_index=self.num_classes)
self.assigned_labels = assigned_labels
self.assigned_bboxes = assigned_bboxes
self.assigned_scores = assigned_scores
self.mask_positive = mask_positive
else:
assigned_labels = self.assigned_labels
assigned_bboxes = self.assigned_bboxes
assigned_scores = self.assigned_scores
mask_positive = self.mask_positive
else: else:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
self.assigner( self.assigner(
pred_scores_aux.detach(), pred_scores_aux.detach(),
pred_bboxes_aux.detach() * stride_tensor, pred_bboxes_aux.detach() * stride_tensor,
...@@ -395,12 +414,14 @@ class PPYOLOEHead(nn.Layer): ...@@ -395,12 +414,14 @@ class PPYOLOEHead(nn.Layer):
assign_out_dict = self.get_loss_from_assign( assign_out_dict = self.get_loss_from_assign(
pred_scores, pred_distri, pred_bboxes, anchor_points_s, pred_scores, pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, alpha_l) assigned_labels, assigned_bboxes, assigned_scores, mask_positive,
alpha_l)
if aux_pred is not None: if aux_pred is not None:
assign_out_dict_aux = self.get_loss_from_assign( assign_out_dict_aux = self.get_loss_from_assign(
aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s, aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, alpha_l) assigned_labels, assigned_bboxes, assigned_scores,
mask_positive, alpha_l)
loss = {} loss = {}
for key in assign_out_dict.keys(): for key in assign_out_dict.keys():
loss[key] = assign_out_dict[key] + assign_out_dict_aux[key] loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
...@@ -411,7 +432,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -411,7 +432,7 @@ class PPYOLOEHead(nn.Layer):
def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes, def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
anchor_points_s, assigned_labels, assigned_bboxes, anchor_points_s, assigned_labels, assigned_bboxes,
assigned_scores, alpha_l): assigned_scores, mask_positive, alpha_l):
# cls loss # cls loss
if self.use_varifocal_loss: if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels, one_hot_label = F.one_hot(assigned_labels,
...@@ -428,6 +449,15 @@ class PPYOLOEHead(nn.Layer): ...@@ -428,6 +449,15 @@ class PPYOLOEHead(nn.Layer):
assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.) assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
loss_cls /= assigned_scores_sum loss_cls /= assigned_scores_sum
if self.for_distill:
self.distill_pairs['pred_cls_scores'] = pred_scores
self.distill_pairs['pos_num'] = assigned_scores_sum
self.distill_pairs['assigned_scores'] = assigned_scores
self.distill_pairs['mask_positive'] = mask_positive
one_hot_label = F.one_hot(assigned_labels,
self.num_classes + 1)[..., :-1]
self.distill_pairs['target_labels'] = one_hot_label
loss_l1, loss_iou, loss_dfl = \ loss_l1, loss_iou, loss_dfl = \
self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s, self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, assigned_labels, assigned_bboxes, assigned_scores,
...@@ -450,7 +480,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -450,7 +480,8 @@ class PPYOLOEHead(nn.Layer):
pred_bboxes *= stride_tensor pred_bboxes *= stride_tensor
if self.exclude_post_process: if self.exclude_post_process:
return paddle.concat( return paddle.concat(
[pred_bboxes, pred_scores.transpose([0, 2, 1])], axis=-1), None [pred_bboxes, pred_scores.transpose([0, 2, 1])],
axis=-1), None, None
else: else:
# scale bbox to origin # scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1) scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
...@@ -460,9 +491,10 @@ class PPYOLOEHead(nn.Layer): ...@@ -460,9 +491,10 @@ class PPYOLOEHead(nn.Layer):
pred_bboxes /= scale_factor pred_bboxes /= scale_factor
if self.exclude_nms: if self.exclude_nms:
# `exclude_nms=True` just use in benchmark # `exclude_nms=True` just use in benchmark
return pred_bboxes, pred_scores return pred_bboxes, pred_scores, None
else: else:
bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes, pred_scores) bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
pred_scores)
return bbox_pred, bbox_num, before_nms_indexes return bbox_pred, bbox_num, before_nms_indexes
......
...@@ -258,7 +258,7 @@ class PPYOLOERHead(nn.Layer): ...@@ -258,7 +258,7 @@ class PPYOLOERHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask'] pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.static_assigner( self.static_assigner(
anchor_points, anchor_points,
stride_tensor, stride_tensor,
...@@ -271,7 +271,7 @@ class PPYOLOERHead(nn.Layer): ...@@ -271,7 +271,7 @@ class PPYOLOERHead(nn.Layer):
pred_bboxes.detach() pred_bboxes.detach()
) )
else: else:
assigned_labels, assigned_bboxes, assigned_scores = \ assigned_labels, assigned_bboxes, assigned_scores, _ = \
self.assigner( self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach(), pred_bboxes.detach(),
......
...@@ -293,7 +293,7 @@ class TOODHead(nn.Layer): ...@@ -293,7 +293,7 @@ class TOODHead(nn.Layer):
pad_gt_mask = gt_meta['pad_gt_mask'] pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment # label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch: if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner( assigned_labels, assigned_bboxes, assigned_scores, _ = self.static_assigner(
anchors, anchors,
num_anchors_list, num_anchors_list,
gt_labels, gt_labels,
...@@ -302,7 +302,7 @@ class TOODHead(nn.Layer): ...@@ -302,7 +302,7 @@ class TOODHead(nn.Layer):
bg_index=self.num_classes) bg_index=self.num_classes)
alpha_l = 0.25 alpha_l = 0.25
else: else:
assigned_labels, assigned_bboxes, assigned_scores = self.assigner( assigned_labels, assigned_bboxes, assigned_scores, _ = self.assigner(
pred_scores.detach(), pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, pred_bboxes.detach() * stride_tensor,
bbox_center(anchors), bbox_center(anchors),
......
...@@ -12,16 +12,19 @@ ...@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import distill_loss
from . import distill_model
from . import ofa
from . import prune from . import prune
from . import quant from . import quant
from . import distill
from . import unstructured_prune from . import unstructured_prune
from .distill_loss import *
from .distill_model import *
from .ofa import *
from .prune import * from .prune import *
from .quant import * from .quant import *
from .distill import *
from .unstructured_prune import * from .unstructured_prune import *
from .ofa import *
import yaml import yaml
from ppdet.core.workspace import load_config from ppdet.core.workspace import load_config
...@@ -45,7 +48,11 @@ def build_slim_model(cfg, slim_cfg, mode='train'): ...@@ -45,7 +48,11 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
elif "slim_method" in slim_load_cfg and slim_load_cfg[ elif "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "CWD": 'slim_method'] == "CWD":
model = CWDDistillModel(cfg, slim_cfg) model = CWDDistillModel(cfg, slim_cfg)
elif "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "PPYOLOEDistill":
model = PPYOLOEDistillModel(cfg, slim_cfg)
else: else:
# common distillation model
model = DistillModel(cfg, slim_cfg) model = DistillModel(cfg, slim_cfg)
cfg['model'] = model cfg['model'] = model
cfg['slim_type'] = cfg.slim cfg['slim_type'] = cfg.slim
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, create, load_config
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'DistillModel',
'FGDDistillModel',
'CWDDistillModel',
'LDDistillModel',
'PPYOLOEDistillModel',
]
@register
class DistillModel(nn.Layer):
"""
Build common distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(DistillModel, self).__init__()
self.arch = cfg.architecture
self.stu_cfg = cfg
self.student_model = create(self.stu_cfg.architecture)
if 'pretrain_weights' in self.stu_cfg and self.stu_cfg.pretrain_weights:
stu_pretrain = self.stu_cfg.pretrain_weights
else:
stu_pretrain = None
slim_cfg = load_config(slim_cfg)
self.tea_cfg = slim_cfg
self.teacher_model = create(self.tea_cfg.architecture)
if 'pretrain_weights' in self.tea_cfg and self.tea_cfg.pretrain_weights:
tea_pretrain = self.tea_cfg.pretrain_weights
else:
tea_pretrain = None
self.distill_cfg = slim_cfg
# load pretrain weights
self.is_inherit = False
if stu_pretrain:
if self.is_inherit and tea_pretrain:
load_pretrain_weight(self.student_model, tea_pretrain)
logger.debug(
"Inheriting! loading teacher weights to student model!")
load_pretrain_weight(self.student_model, stu_pretrain)
logger.info("Student model has loaded pretrain weights!")
if tea_pretrain:
load_pretrain_weight(self.teacher_model, tea_pretrain)
logger.info("Teacher model has loaded pretrain weights!")
self.teacher_model.eval()
for param in self.teacher_model.parameters():
param.trainable = False
self.distill_loss = self.build_loss(self.distill_cfg)
def build_loss(self, distill_cfg):
if 'distill_loss' in distill_cfg and distill_cfg.distill_loss:
return create(distill_cfg.distill_loss)
else:
return None
def parameters(self):
return self.student_model.parameters()
def forward(self, inputs):
if self.training:
student_loss = self.student_model(inputs)
with paddle.no_grad():
teacher_loss = self.teacher_model(inputs)
loss = self.distill_loss(self.teacher_model, self.student_model)
student_loss['distill_loss'] = loss
student_loss['teacher_loss'] = teacher_loss['loss']
student_loss['loss'] += student_loss['distill_loss']
return student_loss
else:
return self.student_model(inputs)
@register
class FGDDistillModel(DistillModel):
"""
Build FGD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(FGDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['RetinaNet', 'PicoDet'
], 'Unsupported arch: {}'.format(self.arch)
self.is_inherit = True
def build_loss(self, distill_cfg):
assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
loss_func = dict()
name_list = distill_cfg.distill_loss_name
for name in name_list:
loss_func[name] = create(distill_cfg.distill_loss)
return loss_func
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
loss_dict = {}
for idx, k in enumerate(self.distill_loss):
loss_dict[k] = self.distill_loss[k](s_neck_feats[idx],
t_neck_feats[idx], inputs)
if self.arch == "RetinaNet":
loss = self.student_model.head(s_neck_feats, inputs)
elif self.arch == "PicoDet":
head_outs = self.student_model.head(
s_neck_feats, self.student_model.export_post_process)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
total_loss = paddle.add_n(list(loss_gfl.values()))
loss = {}
loss.update(loss_gfl)
loss.update({'loss': total_loss})
else:
raise ValueError(f"Unsupported model {self.arch}")
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
return loss
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "PicoDet":
head_outs = self.student_model.head(
neck_feats, self.student_model.export_post_process)
scale_factor = inputs['scale_factor']
bboxes, bbox_num = self.student_model.head.post_process(
head_outs,
scale_factor,
export_nms=self.student_model.export_nms)
return {'bbox': bboxes, 'bbox_num': bbox_num}
else:
raise ValueError(f"Unsupported model {self.arch}")
@register
class CWDDistillModel(DistillModel):
"""
Build CWD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(CWDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['GFL', 'RetinaNet'], 'Unsupported arch: {}'.format(
self.arch)
def build_loss(self, distill_cfg):
assert 'distill_loss_name' in distill_cfg and distill_cfg.distill_loss_name
assert 'distill_loss' in distill_cfg and distill_cfg.distill_loss
loss_func = dict()
name_list = distill_cfg.distill_loss_name
for name in name_list:
loss_func[name] = create(distill_cfg.distill_loss)
return loss_func
def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
loss = self.student_model.head(stu_fea_list, inputs)
distill_loss = {}
for idx, k in enumerate(self.loss_dic):
distill_loss[k] = self.loss_dic[k](stu_fea_list[idx],
tea_fea_list[idx])
loss['loss'] += distill_loss[k]
loss[k] = distill_loss[k]
return loss
def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
loss = {}
head_outs = self.student_model.head(stu_fea_list)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
loss.update(loss_gfl)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
feat_loss = {}
loss_dict = {}
s_cls_feat, t_cls_feat = [], []
for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
s_cls_feat.append(cls_score)
t_cls_feat.append(t_cls_score)
for idx, k in enumerate(self.loss_dic):
loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx])
feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx],
tea_fea_list[idx])
for k in feat_loss:
loss['loss'] += feat_loss[k]
loss[k] = feat_loss[k]
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
return loss
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
if self.arch == "RetinaNet":
loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
inputs)
elif self.arch == "GFL":
loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
else:
raise ValueError(f"unsupported arch {self.arch}")
return loss
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "GFL":
bbox_pred, bbox_num = head_outs
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
else:
raise ValueError(f"unsupported arch {self.arch}")
@register
class LDDistillModel(DistillModel):
"""
Build LD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(LDDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['GFL'], 'Unsupported arch: {}'.format(self.arch)
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
s_head_outs = self.student_model.head(s_neck_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
t_head_outs = self.teacher_model.head(t_neck_feats)
soft_label_list = t_head_outs[0]
soft_targets_list = t_head_outs[1]
student_loss = self.student_model.head.get_loss(
s_head_outs, inputs, soft_label_list, soft_targets_list)
total_loss = paddle.add_n(list(student_loss.values()))
student_loss['loss'] = total_loss
return student_loss
else:
return self.student_model(inputs)
@register
class PPYOLOEDistillModel(DistillModel):
"""
Build PPYOLOE distill model, only used in PPYOLOE
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(PPYOLOEDistillModel, self).__init__(cfg=cfg, slim_cfg=slim_cfg)
assert self.arch in ['PPYOLOE'], 'Unsupported arch: {}'.format(
self.arch)
def forward(self, inputs, alpha=0.125):
if self.training:
if hasattr(self.teacher_model.yolo_head, "assigned_labels"):
self.student_model.yolo_head.assigned_labels, self.student_model.yolo_head.assigned_bboxes, self.student_model.yolo_head.assigned_scores, self.student_model.yolo_head.mask_positive = \
self.teacher_model.yolo_head.assigned_labels, self.teacher_model.yolo_head.assigned_bboxes, self.teacher_model.yolo_head.assigned_scores, self.teacher_model.yolo_head.mask_positive
delattr(self.teacher_model.yolo_head, "assigned_labels")
delattr(self.teacher_model.yolo_head, "assigned_bboxes")
delattr(self.teacher_model.yolo_head, "assigned_scores")
delattr(self.teacher_model.yolo_head, "mask_positive")
student_loss = self.student_model(inputs)
with paddle.no_grad():
teacher_loss = self.teacher_model(inputs)
logits_loss, feat_loss = self.distill_loss(self.teacher_model,
self.student_model)
det_total_loss = student_loss['loss']
total_loss = alpha * (det_total_loss + logits_loss + feat_loss)
student_loss['loss'] = total_loss
student_loss['det_loss'] = det_total_loss
student_loss['logits_loss'] = logits_loss
student_loss['feat_loss'] = feat_loss
return student_loss
else:
return self.student_model(inputs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册