未验证 提交 b01af8bd 编写于 作者: D Double_V 提交者: GitHub

[distill] add FGD distill algorithm (#6239)

* add FGD distill code
上级 3428d97f
......@@ -5,6 +5,10 @@
| Backbone | Model | imgs/GPU | lr schedule | FPS | Box AP | download | config |
| ------------ | --------- | -------- | ----------- | --- | ------ | ---------- | ----------- |
| ResNet50-FPN | RetinaNet | 2 | 1x | --- | 37.5 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_1x_coco.pdparams) | [config](./retinanet_r50_fpn_1x_coco.yml) |
| ResNet101-FPN| RetinaNet | 2 | 2x | --- | 40.6 | [model](https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams) | [config](./retinanet_r101_fpn_2x_coco.yml) |
| ResNet50-FPN | RetinaNet | 2 | 2x | --- | 40.8 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r101_distill_r50_2x_coco.pdparams) | [config](./retinanet_r50_fpn_2x_coco.yml)/[slim_config](../slim/distill/retinanet_resnet101_coco_distill.yml) |
**Notes:**
- All above models are trained on COCO train2017 with 8 GPUs and evaludated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
......
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.001
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
architecture: RetinaNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
RetinaNet:
backbone: ResNet
neck: FPN
head: RetinaHead
ResNet:
depth: 101
variant: b
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
FPN:
out_channel: 256
spatial_scales: [0.125, 0.0625, 0.03125]
extra_stage: 2
has_extra_convs: true
use_c5: false
RetinaHead:
conv_feat:
name: RetinaFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: null
use_dcn: false
anchor_generator:
name: RetinaAnchorGenerator
octave_base_scale: 4
scales_per_octave: 3
aspect_ratios: [0.5, 1.0, 2.0]
strides: [8.0, 16.0, 32.0, 64.0, 128.0]
bbox_assigner:
name: MaxIoUAssigner
positive_overlap: 0.5
negative_overlap: 0.4
allow_low_quality: true
loss_class:
name: FocalLoss
gamma: 2.0
alpha: 0.25
loss_weight: 1.0
loss_bbox:
name: SmoothL1Loss
beta: 0.0
loss_weight: 1.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/retinanet_r50_fpn.yml',
'_base_/optimizer_2x.yml',
'_base_/retinanet_reader.yml'
]
weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/retinanet_r101_fpn.yml',
'_base_/optimizer_2x.yml',
'_base_/retinanet_reader.yml'
]
weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams
......@@ -5,6 +5,19 @@
COCO数据集作为目标检测任务的训练目标难度更大,意味着teacher网络会预测出更多的背景bbox,如果直接用teacher的预测输出作为student学习的`soft label`会有严重的类别不均衡问题。解决这个问题需要引入新的方法,详细背景请参考论文:[Object detection at 200 Frames Per Second](https://arxiv.org/abs/1805.06361)
为了确定蒸馏的对象,我们首先需要找到student和teacher网络得到的`x,y,w,h,cls,objness`等Tensor,用teacher得到的结果指导student训练。具体实现可参考[代码](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/ppdet/slim/distill.py)
## FGD模型蒸馏
FGD全称为[Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837v1),是目标检测任务的一种蒸馏方法,FGD蒸馏分为两个部分`Focal``Global``Focal`蒸馏分离图像的前景和背景,让学生模型分别关注教师模型的前景和背景部分特征的关键像素;`Global`蒸馏部分重建不同像素之间的关系并将其从教师转移到学生,以补偿`Focal`蒸馏中丢失的全局信息。试验结果表明,FGD蒸馏算法在基于anchor和anchor free的方法上能有效提升模型精度。
在PaddleDetection中,我们实现了FGD算法,并基于retinaNet算法进行验证,实验结果如下:
| algorithm | model | AP | download|
|:-:| :-: | :-: | :-:|
|retinaNet_r101_fpn_2x | teacher | 40.6 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams) |
|retinaNet_r50_fpn_1x| student | 37.5 |[download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_1x_coco.pdparams) |
|retinaNet_r50_fpn_2x + FGD| student | 40.8 |[download](https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams) |
## Citations
```
@article{mehta2018object,
......@@ -15,4 +28,12 @@ COCO数据集作为目标检测任务的训练目标难度更大,意味着teac
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@inproceedings{yang2022focal,
title={Focal and global knowledge distillation for detectors},
author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={4643--4652},
year={2022}
}
```
_BASE_: [
'../../retinanet/retinanet_r101_fpn_2x_coco.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
slim: Distill
slim_method: FGD
distill_loss: FGDFeatureLoss
FGDFeatureLoss:
student_channels: 256
teacher_channels: 256
temp: 0.5
alpha_fgd: 0.001
beta_fgd: 0.0005
gamma_fgd: 0.0005
lambda_fgd: 0.000005
......@@ -362,7 +362,8 @@ class OptimizerBuilder():
else:
params = model.parameters()
train_params = [param for param in params if param.trainable is True]
return op(learning_rate=learning_rate,
parameters=params,
parameters=train_params,
grad_clip=grad_clip,
**optim_args)
......@@ -35,7 +35,11 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
return cfg
if slim_load_cfg['slim'] == 'Distill':
model = DistillModel(cfg, slim_cfg)
if "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "FGD":
model = FGDDistillModel(cfg, slim_cfg)
else:
model = DistillModel(cfg, slim_cfg)
cfg['model'] = model
cfg['slim_type'] = cfg.slim
elif slim_load_cfg['slim'] == 'OFA':
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, create, load_config
from ppdet.modeling import ops
......@@ -63,6 +64,95 @@ class DistillModel(nn.Layer):
return self.student_model(inputs)
class FGDDistillModel(nn.Layer):
"""
Build FGD distill model.
Args:
cfg: The student config.
slim_cfg: The teacher and distill config.
"""
def __init__(self, cfg, slim_cfg):
super(FGDDistillModel, self).__init__()
self.student_cfg = cfg
slim_cfg = load_config(slim_cfg)
self.teacher_cfg = slim_cfg
self.loss_cfg = slim_cfg
self.is_loaded_weights = True
self.is_inherit = True
self.student_model = create(self.student_cfg.architecture)
self.teacher_model = create(self.teacher_cfg.architecture)
self.teacher_model.eval()
for param in self.teacher_model.parameters():
param.trainable = False
if 'pretrain_weights' in self.student_cfg and self.student_cfg.pretrain_weights:
if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
self._load_pretrain_weights(self.student_model,
self.teacher_cfg.pretrain_weights)
print("loading teacher weights to student model!")
self._load_pretrain_weights(self.student_model,
self.student_cfg.pretrain_weights)
print("loading student model Done")
if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
self._load_pretrain_weights(self.teacher_model,
self.teacher_cfg.pretrain_weights)
print("loading teacher model Done")
self.fgd_loss_dic = self.build_loss(self.loss_cfg.distill_loss)
def _load_pretrain_weights(self, model, weights):
if self.is_loaded_weights:
return
self.start_epoch = 0
load_pretrain_weight(model, weights)
logger.debug("Load weights {} to start training".format(weights))
def build_loss(self,
cfg,
name_list=[
'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
'neck_f_0'
]):
loss_func = dict()
for idx, k in enumerate(name_list):
loss_func[k] = create(cfg)
return loss_func
def forward(self, inputs):
if self.training:
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
loss_dict = {}
for idx, k in enumerate(self.fgd_loss_dic):
loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
t_neck_feats[idx], inputs)
loss = self.student_model.head(s_neck_feats, inputs)
for k in loss_dict:
loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k]
return loss
else:
body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats)
bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
@register
class DistillYOLOv3Loss(nn.Layer):
def __init__(self, weight=1000):
......@@ -107,3 +197,254 @@ class DistillYOLOv3Loss(nn.Layer):
loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
) * self.weight
return loss
def parameter_init(mode="kaiming", value=0.):
if mode == "kaiming":
weight_attr = paddle.nn.initializer.KaimingUniform()
elif mode == "constant":
weight_attr = paddle.nn.initializer.Constant(value=value)
else:
weight_attr = paddle.nn.initializer.KaimingUniform()
weight_init = ParamAttr(initializer=weight_attr)
return weight_init
@register
class FGDFeatureLoss(nn.Layer):
"""
The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
Paddle version of `Focal and Global Knowledge Distillation for Detectors`
Args:
student_channels(int): The number of channels in the student's FPN feature map. Default to 256.
teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256.
temp (float, optional): The temperature coefficient. Defaults to 0.5.
alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001
beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005
gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001
lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005
"""
def __init__(self,
student_channels=256,
teacher_channels=256,
temp=0.5,
alpha_fgd=0.001,
beta_fgd=0.0005,
gamma_fgd=0.001,
lambda_fgd=0.000005):
super(FGDFeatureLoss, self).__init__()
self.temp = temp
self.alpha_fgd = alpha_fgd
self.beta_fgd = beta_fgd
self.gamma_fgd = gamma_fgd
self.lambda_fgd = lambda_fgd
kaiming_init = parameter_init("kaiming")
zeros_init = parameter_init("constant", 0.0)
if student_channels != teacher_channels:
self.align = nn.Conv2d(
student_channels,
teacher_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=kaiming_init)
student_channels = teacher_channels
else:
self.align = None
self.conv_mask_s = nn.Conv2D(
student_channels, 1, kernel_size=1, weight_attr=kaiming_init)
self.conv_mask_t = nn.Conv2D(
teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)
self.stu_conv_block = nn.Sequential(
nn.Conv2D(
student_channels,
student_channels // 2,
kernel_size=1,
weight_attr=zeros_init),
nn.LayerNorm([student_channels // 2, 1, 1]),
nn.ReLU(),
nn.Conv2D(
student_channels // 2,
student_channels,
kernel_size=1,
weight_attr=zeros_init))
self.tea_conv_block = nn.Sequential(
nn.Conv2D(
teacher_channels,
teacher_channels // 2,
kernel_size=1,
weight_attr=zeros_init),
nn.LayerNorm([teacher_channels // 2, 1, 1]),
nn.ReLU(),
nn.Conv2D(
teacher_channels // 2,
teacher_channels,
kernel_size=1,
weight_attr=zeros_init))
def spatial_channel_attention(self, x, t=0.5):
shape = paddle.shape(x)
N, C, H, W = shape
_f = paddle.abs(x)
spatial_map = paddle.reshape(
paddle.mean(
_f, axis=1, keepdim=True) / t, [N, -1])
spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W
spatial_att = paddle.reshape(spatial_map, [N, H, W])
channel_map = paddle.mean(
paddle.mean(
_f, axis=2, keepdim=False), axis=2, keepdim=False)
channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C
return [spatial_att, channel_att]
def spatial_pool(self, x, mode="teacher"):
batch, channel, width, height = x.shape
x_copy = x
x_copy = paddle.reshape(x_copy, [batch, channel, height * width])
x_copy = x_copy.unsqueeze(1)
if mode.lower() == "student":
context_mask = self.conv_mask_s(x)
else:
context_mask = self.conv_mask_t(x)
context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
context_mask = F.softmax(context_mask, axis=2)
context_mask = context_mask.unsqueeze(-1)
context = paddle.matmul(x_copy, context_mask)
context = paddle.reshape(context, [batch, channel, 1, 1])
return context
def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
tea_spatial_att):
def _func(a, b):
return paddle.sum(paddle.abs(a - b)) / len(a)
mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
stu_spatial_att, tea_spatial_att)
return mask_loss
def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg,
tea_channel_att, tea_spatial_att):
Mask_fg = Mask_fg.unsqueeze(axis=1)
Mask_bg = Mask_bg.unsqueeze(axis=1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)
fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))
fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))
fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)
return fg_loss, bg_loss
def relation_loss(self, stu_feature, tea_feature):
context_s = self.spatial_pool(stu_feature, "student")
context_t = self.spatial_pool(tea_feature, "teacher")
out_s = stu_feature + self.stu_conv_block(context_s)
out_t = tea_feature + self.tea_conv_block(context_t)
rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
return rela_loss
def mask_value(self, mask, xl, xr, yl, yr, value):
mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
return mask
def forward(self, stu_feature, tea_feature, inputs):
"""Forward function.
Args:
stu_feature(Tensor): Bs*C*H*W, student's feature map
tea_feature(Tensor): Bs*C*H*W, teacher's feature map
inputs: The inputs with gt bbox and input shape info.
"""
assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \
f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.'
assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs."
gt_bboxes = inputs['gt_bbox']
ins_shape = [
inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
]
if self.align is not None:
stu_feature = self.align(stu_feature)
N, C, H, W = stu_feature.shape
tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
tea_feature, self.temp)
stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
stu_feature, self.temp)
Mask_fg = paddle.zeros(tea_spatial_att.shape)
Mask_bg = paddle.ones_like(tea_spatial_att)
one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
wmin, wmax, hmin, hmax, area = [], [], [], [], []
for i in range(N):
tmp_box = paddle.ones_like(gt_bboxes[i])
tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W
tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W
tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H
tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H
zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32")
ones = paddle.ones_like(tmp_box[:, 2], dtype="int32")
wmin.append(
paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
hmin.append(
paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero))
hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32"))
area_recip = 1.0 / (
hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))
for j in range(len(gt_bboxes[i])):
Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
hmax[i][j] + 1, wmin[i][j],
wmax[i][j] + 1, area_recip[0][j])
Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
if paddle.sum(Mask_bg[i]):
Mask_bg[i] /= paddle.sum(Mask_bg[i])
fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg,
Mask_bg, tea_channel_att,
tea_spatial_att)
mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
stu_spatial_att, tea_spatial_att)
rela_loss = self.relation_loss(stu_feature, tea_feature)
loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
return loss
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册