Unverified · Commit c732ee15 · Author: Double_V · Committer: GitHub

[feature] add CWD distill model (#7160)

* add CWD distill code

* fix type

* fix bugs

* add retinanet teacher yml

* fix comments
Parent ca1e1efd
......@@ -9,6 +9,7 @@ We reproduce the object detection results in the paper [Generalized Focal Loss:
| Backbone | Model | batch-size/GPU | lr schedule | FPS | Box AP | download | log | config |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | :-----: |
| ResNet50 | GFL | 2 | 1x | ---- | 41.0 | [model](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r50_fpn_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r50_fpn_1x_coco.yml) |
| ResNet50 | GFL + [CWD](../slim/README.md) | 2 | 2x | ---- | 44.0 | [model](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r50_fpn_2x_coco_cwd.log) | [config1](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r50_fpn_1x_coco.yml), [config2](../slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml) |
| ResNet101-vd | GFL | 2 | 2x | ---- | 46.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r101vd_fpn_mstrain_2x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) |
| ResNet34-vd | GFL | 2 | 1x | ---- | 40.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r34vd_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r34vd_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r34vd_1x_coco.yml) |
| ResNet18-vd | GFL | 2 | 1x | ---- | 36.6 | [model](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r18vd_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r18vd_1x_coco.yml) |
......
......@@ -26,6 +26,19 @@ LD stands for [Localization Distillation for Dense Object Detection](https://arxiv.
| GFL_ResNet18-vd | student | 36.6 | [model](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams), [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r18vd_1x_coco.yml) |
| GFL_ResNet18-vd + LD | student | 38.2 | [model](https://bj.bcebos.com/v1/paddledet/models/gfl_slim_ld_r18vd_1x_coco.pdparams), [config1](../../gfl/gfl_slim_ld_r18vd_1x_coco.yml), [config2](./gfl_ld_distill.yml) |
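## CWD Model Distillation

CWD stands for [Channel-wise Knowledge Distillation for Dense Prediction](https://arxiv.org/pdf/2011.13256.pdf). It minimizes the Kullback-Leibler (KL) divergence between the channel-wise probability maps of the teacher and student networks, so that distillation focuses on the most salient regions of each channel, which improves accuracy on dense prediction tasks such as object detection and semantic segmentation. PaddleDetection implements CWD and validates it on the GFL and RetinaNet models. The results are as follows:
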
| algorithm | model | AP | download|
|:-:| :-: | :-: | :-:|
|retinaNet_r101_fpn_2x | teacher | 40.6 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams) |
|retinaNet_r50_fpn_1x| student | 37.5 |[download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_1x_coco.pdparams) |
|retinaNet_r50_fpn_2x + CWD| student | 40.5 |[download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_cwd.pdparams) |
|gfl_r101_fpn_2x | teacher | 46.8 | [download](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) |
|gfl_r50_fpn_1x| student | 41.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) |
|gfl_r50_fpn_2x + CWD| student | 44.0 |[download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams) |
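
Read off the `ChannelWiseDivergence` class added in this commit, the loss for one pair of student and teacher feature maps can be written as below. This is a transcription of the code, not of the paper, and the symbols $z^{S}$, $z^{T}$, $p^{S}$, $p^{T}$ and $\epsilon$ are notation introduced here: $z$ is a feature map of shape $N \times C \times H \times W$ with its spatial positions flattened to $i = 1,\dots,HW$, $\tau$ is `tau`, $w$ is `weight`, and $\epsilon = 10^{-5}$.

```latex
p^{S}_{n,c,i} = \frac{\exp\left(z^{S}_{n,c,i}/\tau\right)}{\sum_{j=1}^{HW}\exp\left(z^{S}_{n,c,j}/\tau\right)},
\qquad
p^{T}_{n,c,i} = \frac{\exp\left(z^{T}_{n,c,i}/\tau\right)}{\sum_{j=1}^{HW}\exp\left(z^{T}_{n,c,j}/\tau\right)}

L_{\mathrm{CWD}} = \frac{w}{N\,C}\sum_{n=1}^{N}\sum_{c=1}^{C}\sum_{i=1}^{HW}
p^{T}_{n,c,i}\left[\log\left(\epsilon + p^{T}_{n,c,i}\right) - \log\left(\epsilon + p^{S}_{n,c,i}\right)\right]
```

Each channel is normalized over its own spatial positions, so the term inside the double sum over $n$ and $c$ is the KL divergence between one teacher channel map and the matching student channel map, up to the $\epsilon$ smoothing.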
## Citations
```
@article{mehta2018object,
......@@ -51,4 +64,12 @@ LD stands for [Localization Distillation for Dense Object Detection](https://arxiv.
booktitle={CVPR},
year={2022}
}
@inproceedings{shu2021channel,
  title={Channel-wise knowledge distillation for dense prediction},
  author={Shu, Changyong and Liu, Yifan and Gao, Jianfei and Yan, Zheng and Shen, Chunhua},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={5311--5320},
  year={2021}
}
```
_BASE_: [
  '../../gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams

slim: Distill
slim_method: CWD
distill_loss: ChannelWiseDivergence
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']

ChannelWiseDivergence:
  student_channels: 80
  teacher_channels: 80
  tau: 1.0
  weight: 5.0

_BASE_: [
  '../../retinanet/retinanet_r101_fpn_2x_coco.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams

slim: Distill
slim_method: CWD
distill_loss: ChannelWiseDivergence
distill_loss_name: ['cls_f_4', 'cls_f_3', 'cls_f_2', 'cls_f_1', 'cls_f_0']

ChannelWiseDivergence:
  student_channels: 80
  teacher_channels: 80
  name: cwdloss
  tau: 1.0
  weight: 5.0
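
Both configs set `tau: 1.0`, the temperature that the loss divides the features by before the spatial softmax. The following is a minimal plain-Python sketch of what that knob does, not code from the commit: `softmax_with_temperature` is a hypothetical helper and the sample values are made up for illustration.

```python
import math


def softmax_with_temperature(logits, tau=1.0):
    """Softmax of logits / tau over a flat list (hypothetical helper)."""
    scaled = [v / tau for v in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# One flattened channel map with a single strong activation (illustrative values).
channel = [0.0, 1.0, 4.0, 1.0]

sharp = softmax_with_temperature(channel, tau=1.0)
soft = softmax_with_temperature(channel, tau=4.0)

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

A larger `tau` flattens the distribution, so the student is asked to match more of the teacher's low-activation positions. With `tau: 1.0` the softmax is applied to the raw features.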
......@@ -42,6 +42,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
        elif "slim_method" in slim_load_cfg and slim_load_cfg[
                'slim_method'] == "LD":
            model = LDDistillModel(cfg, slim_cfg)
        elif "slim_method" in slim_load_cfg and slim_load_cfg[
                'slim_method'] == "CWD":
            model = CWDDistillModel(cfg, slim_cfg)
        else:
            model = DistillModel(cfg, slim_cfg)
        cfg['model'] = model
......
......@@ -170,6 +170,183 @@ class FGDDistillModel(nn.Layer):
raise ValueError(f"Unsupported model {self.arch}")
class CWDDistillModel(nn.Layer):
    """
    Build CWD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(CWDDistillModel, self).__init__()

        self.is_inherit = False
        # build student model before load slim config
        self.student_model = create(cfg.architecture)
        self.arch = cfg.architecture
        if self.arch not in ['GFL', 'RetinaNet']:
            raise ValueError(
                f"The arch can only be one of ['GFL', 'RetinaNet'], but received {self.arch}"
            )

        # read the student's weights path first; the slim config loaded next
        # may set its own pretrain_weights
        stu_pretrain = cfg['pretrain_weights']
        slim_cfg = load_config(slim_cfg)
        self.teacher_cfg = slim_cfg
        self.loss_cfg = slim_cfg

        # the teacher is frozen: eval mode and no trainable parameters
        self.teacher_model = create(self.teacher_cfg.architecture)
        self.teacher_model.eval()

        for param in self.teacher_model.parameters():
            param.trainable = False

        if 'pretrain_weights' in cfg and stu_pretrain:
            if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
                load_pretrain_weight(self.student_model,
                                     self.teacher_cfg.pretrain_weights)
                logger.debug(
                    "Inheriting! loading teacher weights to student model!")

            load_pretrain_weight(self.student_model, stu_pretrain)

        if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
            load_pretrain_weight(self.teacher_model,
                                 self.teacher_cfg.pretrain_weights)

        self.loss_dic = self.build_loss(
            self.loss_cfg.distill_loss,
            name_list=self.loss_cfg['distill_loss_name'])

    def build_loss(self,
                   cfg,
                   name_list=('neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
                              'neck_f_0')):
        # one loss instance per distilled feature level
        loss_func = dict()
        for k in name_list:
            loss_func[k] = create(cfg)
        return loss_func

    def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
        loss = self.student_model.head(stu_fea_list, inputs)
        distill_loss = {}
        # cwd kd loss
        for idx, k in enumerate(self.loss_dic):
            distill_loss[k] = self.loss_dic[k](stu_fea_list[idx],
                                               tea_fea_list[idx])

            loss['loss'] += distill_loss[k]
            loss[k] = distill_loss[k]
        return loss

    def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
        loss = {}
        head_outs = self.student_model.head(stu_fea_list)
        loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
        loss.update(loss_gfl)
        total_loss = paddle.add_n(list(loss.values()))
        loss.update({'loss': total_loss})
        # cwd kd loss
        feat_loss = {}
        loss_dict = {}
        s_cls_feat, t_cls_feat = [], []
        # classification scores of student and teacher at each FPN level
        for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
            conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
            cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
            t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
            t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
            s_cls_feat.append(cls_score)
            t_cls_feat.append(t_cls_score)

        # distill both the classification scores and the raw neck features
        for idx, k in enumerate(self.loss_dic):
            loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx])
            feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx],
                                                          tea_fea_list[idx])

        for k in feat_loss:
            loss['loss'] += feat_loss[k]
            loss[k] = feat_loss[k]

        for k in loss_dict:
            loss['loss'] += loss_dict[k]
            loss[k] = loss_dict[k]
        return loss

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            if self.arch == "RetinaNet":
                loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
                                               inputs)
            elif self.arch == "GFL":
                loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
            else:
                raise ValueError(f"unsupported arch {self.arch}")
            return loss
        else:
            # inference uses the student only
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "GFL":
                bbox_pred, bbox_num = head_outs
                output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
                return output
            else:
                raise ValueError(f"unsupported arch {self.arch}")
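
Both `get_loss_*` methods follow the same bookkeeping: every distillation term is added to the total under `'loss'` and also stored under its own key so that it can be logged separately. A minimal plain-Python sketch of that pattern follows. `add_distill_terms` is a hypothetical helper, and the key names `loss_cls` and `loss_reg` as well as all numbers are made up for illustration.

```python
def add_distill_terms(base_losses, distill_terms):
    """Fold named distillation terms into a loss dict (hypothetical helper)."""
    merged = dict(base_losses)  # copy so the caller's dict is not mutated
    for name, value in distill_terms.items():
        merged['loss'] = merged['loss'] + value  # contributes to the total
        merged[name] = value                     # kept separately for logging
    return merged


# Illustrative detection losses and per-level distillation terms.
base = {'loss': 1.5, 'loss_cls': 1.0, 'loss_reg': 0.5}
distill = {'cls_f_4': 0.25, 'cls_f_3': 0.125}

merged = add_distill_terms(base, distill)
print(merged)
```

The commit updates the loss dict in place instead of copying it. The copy here only keeps the example easy to check.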


@register
class ChannelWiseDivergence(nn.Layer):
    def __init__(self, student_channels, teacher_channels, tau=1.0, weight=1.0):
        super(ChannelWiseDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = weight

        # 1x1 conv to match channel counts when student and teacher differ
        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

    def distill_softmax(self, x, t):
        # x is NCHW: flatten each (sample, channel) map into one row
        _, _, h, w = paddle.shape(x)
        x = paddle.reshape(x, [-1, h * w])
        # out-of-place division, so the caller's tensor is left untouched
        x = x / t
        return F.softmax(x, axis=1)

    def forward(self, preds_s, preds_t):
        assert preds_s.shape[-2:] == preds_t.shape[
            -2:], 'the output dim of teacher and student differ'
        eps = 1e-5
        if self.align is not None:
            preds_s = self.align(preds_s)
        # read N and C after alignment so they match the softmax rows
        N, C = preds_s.shape[:2]

        softmax_pred_s = self.distill_softmax(preds_s, self.tau)
        softmax_pred_t = self.distill_softmax(preds_t, self.tau)

        # KL(teacher || student) summed over all channel maps
        loss = paddle.sum(-softmax_pred_t * paddle.log(eps + softmax_pred_s) +
                          softmax_pred_t * paddle.log(eps + softmax_pred_t))
        return self.loss_weight * loss / (C * N)
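
To sanity-check the arithmetic of `ChannelWiseDivergence.forward`, here is a small reference version in plain Python that needs no Paddle install. It is a sketch under stated assumptions, not the commit's implementation: `channel_softmax` and `cwd_loss` are hypothetical names, inputs are nested lists shaped `[N][C][H*W]` with the spatial dimensions already flattened, channel alignment is left out, and the sample values are made up.

```python
import math


def channel_softmax(values, tau):
    """Softmax over one flattened channel map, with temperature tau."""
    scaled = [v / tau for v in values]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cwd_loss(student, teacher, tau=1.0, weight=1.0, eps=1e-5):
    """Channel-wise KL(teacher || student), averaged over N * C channel maps."""
    n, c = len(student), len(student[0])
    total = 0.0
    for s_sample, t_sample in zip(student, teacher):
        for s_map, t_map in zip(s_sample, t_sample):
            p_s = channel_softmax(s_map, tau)
            p_t = channel_softmax(t_map, tau)
            total += sum(pt * (math.log(eps + pt) - math.log(eps + ps))
                         for ps, pt in zip(p_s, p_t))
    return weight * total / (c * n)


# One sample, two channels, four spatial positions (illustrative values).
teacher = [[[0.0, 2.0, 0.0, 0.0], [1.0, 1.0, 1.0, 3.0]]]
student = [[[2.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 3.0]]]

same = cwd_loss(teacher, teacher)   # identical maps
diff = cwd_loss(student, teacher)   # first channel peaks at a different position
print(same, diff)
```

Identical maps give a loss of zero, and moving the peak of one channel to a different position makes the loss positive, which is the behaviour the distillation term relies on.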


@register
class DistillYOLOv3Loss(nn.Layer):
    def __init__(self, weight=1000):
......