Unverified commit b52580df authored by huangjun12, committed by GitHub

add LD distillation for GFL model (#7101)

* add code

* add ld-vlr

* refactor code

* refine code

* add doc

* refine details

* refine doc

* update download url

* rename config file for ci

* rename config file in README for CI
Parent fa67fb9f
......@@ -12,6 +12,7 @@ We reproduce the object detection results in the paper [Generalized Focal Loss:
| ResNet101-vd | GFL | 2 | 2x | ---- | 46.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r101vd_fpn_mstrain_2x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) |
| ResNet34-vd | GFL | 2 | 1x | ---- | 40.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r34vd_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r34vd_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r34vd_1x_coco.yml) |
| ResNet18-vd | GFL | 2 | 1x | ---- | 36.6 | [model](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r18vd_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r18vd_1x_coco.yml) |
| ResNet18-vd | GFL + [LD](../slim/README.md) | 2 | 1x | ---- | 38.2 | [model](https://bj.bcebos.com/v1/paddledet/models/gfl_slim_ld_r18vd_1x_coco.pdparams) | [log](https://bj.bcebos.com/v1/paddledet/logs/train_gfl_slim_ld_r18vd_1x_coco.log) | [config1](./gfl_slim_ld_r18vd_1x_coco.yml), [config2](../slim/distill/gfl_ld_distill.yml) |
| ResNet50 | GFLv2 | 2 | 1x | ---- | 41.2 | [model](https://paddledet.bj.bcebos.com/models/gflv2_r50_fpn_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gflv2_r50_fpn_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gflv2_r50_fpn_1x_coco.yml) |
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/gfl_reader.yml',
]
weights: output/gfl_r18vd_1x_coco/model_final
find_unused_parameters: True
architecture: GFL
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams
GFL:
backbone: ResNet
neck: FPN
head: LDGFLHead
ResNet:
depth: 18
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
FPN:
out_channel: 256
spatial_scales: [0.125, 0.0625, 0.03125]
extra_stage: 2
has_extra_convs: true
use_c5: false
LDGFLHead: # new head
conv_feat:
name: FCOSFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: "gn"
use_dcn: false
fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01
reg_max: 16
loss_class:
name: QualityFocalLoss
use_sigmoid: True
beta: 2.0
loss_weight: 1.0
loss_dfl:
name: DistributionFocalLoss
loss_weight: 0.25
loss_bbox:
name: GIoULoss
loss_weight: 2.0
loss_ld:
name: KnowledgeDistillationKLDivLoss
loss_weight: 0.25
T: 10
loss_ld_vlr:
name: KnowledgeDistillationKLDivLoss
loss_weight: 0.25
T: 10
loss_kd:
name: KnowledgeDistillationKLDivLoss
loss_weight: 10
T: 2
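  # loss_ld distills the regression distribution on positive locations,
  # loss_ld_vlr applies the same KL loss on the VLR produced by
  # ATSSAssigner.get_vlr_region, and loss_kd distills the classification
  # logits on positive locations (see LDGFLHead.get_loss in this commit).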
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.025
nms_threshold: 0.6
......@@ -17,6 +17,14 @@ FGD全称为[Focal and Global Knowledge Distillation for Detectors](https://arxi
|retinaNet_r50_fpn_2x + FGD| student | 40.8 |[download](https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams) |
## LD Model Distillation
LD, short for [Localization Distillation for Dense Object Detection](https://arxiv.org/abs/2102.12252), represents the regression box as a probability distribution and applies the KD normally used for classification to the localization task. It uses a region-adaptive, divide-and-conquer strategy, learning classification knowledge and localization knowledge separately in different regions. In PaddleDetection we implement the LD algorithm and validate it on the GFL model; the results are as follows:
| algorithm | model | AP | download|
|:-:| :-: | :-: | :-:|
| GFL_ResNet101-vd | teacher | 46.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams), [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) |
| GFL_ResNet18-vd | student | 36.6 | [model](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams), [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r18vd_1x_coco.yml) |
| GFL_ResNet18-vd + LD | student | 38.2 | [model](https://bj.bcebos.com/v1/paddledet/models/gfl_slim_ld_r18vd_1x_coco.pdparams), [config1](../../gfl/gfl_slim_ld_r18vd_1x_coco.yml), [config2](./gfl_ld_distill.yml) |
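
Training with LD pairs the student config with the distillation config, following the usual PaddleDetection slim workflow (assuming the standard training entry point): `python tools/train.py -c configs/gfl/gfl_slim_ld_r18vd_1x_coco.yml --slim_config configs/slim/distill/gfl_ld_distill.yml`.
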
## Citations
```
......@@ -36,4 +44,11 @@ FGD全称为[Focal and Global Knowledge Distillation for Detectors](https://arxi
pages={4643--4652},
year={2022}
}
@Inproceedings{zheng2022LD,
title={Localization Distillation for Dense Object Detection},
author= {Zheng, Zhaohui and Ye, Rongguang and Wang, Ping and Ren, Dongwei and Zuo, Wangmeng and Hou, Qibin and Cheng, Mingming},
booktitle={CVPR},
year={2022}
}
```
_BASE_: [
'../../gfl/gfl_r18vd_1x_coco.yml',
]
# pretrained weights of the teacher model
pretrain_weights: https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams
slim: Distill
slim_method: LD
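# The keys below override the student _BASE_ config to describe the
# ResNet101-vd teacher; `slim: Distill` together with `slim_method: LD`
# makes build_slim_model (changed in this commit) wrap the student in
# LDDistillModel, which loads this teacher and freezes its weights.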
ResNet:
depth: 101
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
......@@ -43,7 +43,8 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
Returns:
Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
"""
assert mode in ['iou', 'iof', 'giou', 'diou'], 'Unsupported mode {}'.format(
mode)
# Either the boxes are empty or the length of the boxes' last dimension is 4
assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0)
assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0)
......@@ -83,6 +84,13 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
if mode == 'giou':
enclosed_lt = np.minimum(bboxes1[..., :2], bboxes2[..., :2])
enclosed_rb = np.maximum(bboxes1[..., 2:], bboxes2[..., 2:])
if mode == 'diou':
enclosed_lt = np.minimum(bboxes1[..., :2], bboxes2[..., :2])
enclosed_rb = np.maximum(bboxes1[..., 2:], bboxes2[..., 2:])
b1_x1, b1_y1 = bboxes1[..., 0], bboxes1[..., 1]
b1_x2, b1_y2 = bboxes1[..., 2], bboxes1[..., 3]
b2_x1, b2_y1 = bboxes2[..., 0], bboxes2[..., 1]
b2_x2, b2_y2 = bboxes2[..., 2], bboxes2[..., 3]
else:
lt = np.maximum(bboxes1[..., :, None, :2],
bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
......@@ -101,6 +109,15 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
bboxes2[..., None, :, :2])
enclosed_rb = np.maximum(bboxes1[..., :, None, 2:],
bboxes2[..., None, :, 2:])
if mode == 'diou':
enclosed_lt = np.minimum(bboxes1[..., :, None, :2],
bboxes2[..., None, :, :2])
enclosed_rb = np.maximum(bboxes1[..., :, None, 2:],
bboxes2[..., None, :, 2:])
b1_x1, b1_y1 = bboxes1[..., :, None, 0], bboxes1[..., :, None, 1]
b1_x2, b1_y2 = bboxes1[..., :, None, 2], bboxes1[..., :, None, 3]
b2_x1, b2_y1 = bboxes2[..., None, :, 0], bboxes2[..., None, :, 1]
b2_x2, b2_y2 = bboxes2[..., None, :, 2], bboxes2[..., None, :, 3]
eps = np.array([eps])
union = np.maximum(union, eps)
......@@ -108,18 +125,32 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
if mode in ['iou', 'iof']:
return ious
if mode in ['giou']:
enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
enclose_area = np.maximum(enclose_area, eps)
gious = ious - (enclose_area - union) / enclose_area
return gious
if mode in ['diou']:
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
rho2 = left + right
enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
enclose_c = enclose_wh[..., 0]**2 + enclose_wh[..., 1]**2
enclose_c = np.maximum(enclose_c, eps)
dious = ious - rho2 / enclose_c
return dious
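# Minimal usage sketch (illustrative only, not used by the assigner below),
# assuming numpy is imported as np at the top of this module as bbox_overlaps
# already requires: with mode='diou' the result is the IoU minus the squared
# center distance divided by the squared diagonal of the smallest enclosing box.
def _diou_example():
    boxes_a = np.array([[0., 0., 10., 10.]])
    boxes_b = np.array([[5., 5., 15., 15.]])
    iou = bbox_overlaps(boxes_a, boxes_b)                # ~0.143
    diou = bbox_overlaps(boxes_a, boxes_b, mode='diou')  # ~0.143 - 50/450 ~ 0.032
    return iou, diou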
def topk_(input, k, axis=1, largest=True):
x = -input if largest else input
if axis == 0:
row_index = np.arange(input.shape[1 - axis])
if k == x.shape[0]: # argpartition requires index < len(input)
topk_index = np.argpartition(x, k - 1, axis=axis)[0:k, :]
else:
topk_index = np.argpartition(x, k, axis=axis)[0:k, :]
topk_data = x[topk_index, row_index]
topk_index_sort = np.argsort(topk_data, axis=axis)
......@@ -267,3 +298,124 @@ class ATSSAssigner(object):
-np.inf] = argmax_overlaps[max_overlaps != -np.inf] + 1
return assigned_gt_inds, max_overlaps
def get_vlr_region(self,
bboxes,
num_level_bboxes,
gt_bboxes,
gt_bboxes_ignore=None,
gt_labels=None):
"""get vlr region for ld distillation.
Args:
bboxes (np.array): Bounding boxes to be assigned, shape (n, 4).
num_level_bboxes (List): num of bboxes in each level
gt_bboxes (np.array): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (np.array, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
gt_labels (np.array, optional): Label of gt_bboxes, shape (k, ).
"""
bboxes = bboxes[:, :4]
num_gt, num_bboxes = gt_bboxes.shape[0], bboxes.shape[0]
# compute iou between all bbox and gt
overlaps = bbox_overlaps(bboxes, gt_bboxes)
# compute diou between all bbox and gt
diou = bbox_overlaps(bboxes, gt_bboxes, mode='diou')
# assign 0 by default
assigned_gt_inds = np.zeros((num_bboxes, ), dtype=np.int64)
vlr_region_iou = (assigned_gt_inds + 0).astype(np.float32)
if num_gt == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
max_overlaps = np.zeros((num_bboxes, ))
if num_gt == 0:
# No truth, assign everything to background
assigned_gt_inds[:] = 0
if not np.any(gt_labels):
assigned_labels = None
else:
assigned_labels = -np.ones((num_bboxes, ), dtype=np.int64)
return assigned_gt_inds, max_overlaps
# compute center distance between all bbox and gt
gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
gt_points = np.stack((gt_cx, gt_cy), axis=1)
bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
bboxes_points = np.stack((bboxes_cx, bboxes_cy), axis=1)
distances = np.sqrt(
np.power((bboxes_points[:, None, :] - gt_points[None, :, :]), 2)
.sum(-1))
# Selecting candidates based on the center distance
candidate_idxs = []
candidate_idxs_t = []
start_idx = 0
for bboxes_per_level in num_level_bboxes:
# on each pyramid level, for each gt,
# select k bboxes whose centers are closest to the gt center
end_idx = start_idx + bboxes_per_level
distances_per_level = distances[start_idx:end_idx, :]
selectable_t = min(self.topk, bboxes_per_level)
selectable_k = bboxes_per_level  # k is all bboxes on this level
_, topt_idxs_per_level = topk_(
distances_per_level, selectable_t, axis=0, largest=False)
_, topk_idxs_per_level = topk_(
distances_per_level, selectable_k, axis=0, largest=False)
candidate_idxs_t.append(topt_idxs_per_level + start_idx)
candidate_idxs.append(topk_idxs_per_level + start_idx)
start_idx = end_idx
candidate_idxs_t = np.concatenate(candidate_idxs_t, axis=0)
candidate_idxs = np.concatenate(candidate_idxs, axis=0)
# get corresponding iou for these candidates, and compute the
# mean and std, set mean + std as the iou threshold
candidate_overlaps_t = overlaps[candidate_idxs_t, np.arange(num_gt)]
# compute tdiou
t_diou = diou[candidate_idxs, np.arange(num_gt)]
overlaps_mean_per_gt = candidate_overlaps_t.mean(0)
overlaps_std_per_gt = candidate_overlaps_t.std(
0, ddof=1) # NOTE: use Bessel correction
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
# compute region
is_pos = (t_diou < overlaps_thr_per_gt[None, :]) & (
t_diou >= 0.25 * overlaps_thr_per_gt[None, :])
# limit the positive sample's center in gt
for gt_idx in range(num_gt):
candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
candidate_idxs = candidate_idxs.reshape(-1)
# if an anchor box is assigned to multiple gts,
# the one with the highest IoU will be selected.
overlaps_inf = -np.inf * np.ones_like(overlaps).T.reshape(-1)
index = candidate_idxs.reshape(-1)[is_pos.reshape(-1)]
overlaps_inf[index] = overlaps.T.reshape(-1)[index]
overlaps_inf = overlaps_inf.reshape(num_gt, -1).T
max_overlaps = overlaps_inf.max(axis=1)
argmax_overlaps = overlaps_inf.argmax(axis=1)
overlaps_inf = -np.inf * np.ones_like(overlaps).T.reshape(-1)
overlaps_inf = overlaps_inf.reshape(num_gt, -1).T
assigned_gt_inds[max_overlaps !=
-np.inf] = argmax_overlaps[max_overlaps != -np.inf] + 1
vlr_region_iou[max_overlaps !=
-np.inf] = max_overlaps[max_overlaps != -np.inf] + 0
return vlr_region_iou
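# The returned vlr_region_iou stores, for each bbox selected into the VLR
# (candidates whose DIoU-based score to a gt lies in [0.25 * thr, thr)), its
# IoU with the assigned gt, and 0 elsewhere. LDGFLHead.get_loss (added in
# this commit) picks these bboxes via (vlr_region > 0) and uses the values
# as weights for loss_ld_vlr.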
......@@ -574,6 +574,11 @@ class Gt2GFLTarget(BaseOperator):
assign_gt_inds, _ = self.assigner(grid_cells, num_level_cells,
gt_bboxes, gt_bboxes_ignore,
gt_labels)
vlr_region = self.assigner.get_vlr_region(grid_cells, num_level_cells,
gt_bboxes, gt_bboxes_ignore,
gt_labels)
pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample(
assign_gt_inds, gt_bboxes)
......@@ -600,6 +605,7 @@ class Gt2GFLTarget(BaseOperator):
sample['label_weights'] = label_weights
sample['bbox_targets'] = bbox_targets
sample['pos_num'] = max(pos_inds.size, 1)
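# per-cell VLR weights from ATSSAssigner.get_vlr_region above; consumed by
# LDGFLHead.get_loss (as gt_meta['vlr_regions']) to weight loss_ld_vlr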
sample['vlr_regions'] = vlr_region
sample.pop('is_crowd', None)
sample.pop('difficult', None)
sample.pop('gt_class', None)
......
......@@ -34,6 +34,7 @@ from . import tood_head
from . import retina_head
from . import ppyoloe_head
from . import fcosr_head
from . import ld_gfl_head
from .bbox_head import *
from .mask_head import *
......@@ -57,3 +58,4 @@ from .tood_head import *
from .retina_head import *
from .ppyoloe_head import *
from .fcosr_head import *
from .ld_gfl_head import *
......@@ -55,7 +55,6 @@ class Integral(nn.Layer):
This layer calculates the target location by :math:`\sum{P(y_i) * y_i}`,
P(y_i) denotes the softmax vector that represents the discrete distribution
y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max}
Args:
reg_max (int): The maximal value of the discrete set. Default: 16. You
may want to reset it according to your new dataset or related
......@@ -88,7 +87,6 @@ class Integral(nn.Layer):
@register
class DGQP(nn.Layer):
"""Distribution-Guided Quality Predictor of GFocal head
Args:
reg_topk (int): top-k statistics of distribution to guide LQE
reg_channels (int): hidden layer unit to generate LQE
......@@ -437,4 +435,4 @@ class GFLHead(nn.Layer):
mlvl_scores = paddle.concat(cls_scores, axis=1)
mlvl_scores = mlvl_scores.transpose([0, 2, 1])
bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores)
return bbox_pred, bbox_num
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is based on:
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/ld_head.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance, batch_distance2bbox
from ppdet.data.transform.atss_assigner import bbox_overlaps
from .gfl_head import GFLHead
@register
class LDGFLHead(GFLHead):
"""
GFLHead used for LD distillation.
Args:
conv_feat (object): Instance of 'FCOSFeat'
num_classes (int): Number of classes
fpn_stride (list): The stride of each FPN Layer
prior_prob (float): Used to set the bias init for the class prediction layer
loss_class (object): Instance of QualityFocalLoss.
loss_dfl (object): Instance of DistributionFocalLoss.
loss_bbox (object): Instance of bbox loss.
reg_max (int): Max value of the integral set :math:`{0, ..., reg_max}`
in QFL setting. Default: 16.
"""
__inject__ = [
'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
'loss_ld', 'loss_ld_vlr', 'loss_kd', 'nms'
]
__shared__ = ['num_classes']
def __init__(self,
conv_feat='FCOSFeat',
dgqp_module=None,
num_classes=80,
fpn_stride=[8, 16, 32, 64, 128],
prior_prob=0.01,
loss_class='QualityFocalLoss',
loss_dfl='DistributionFocalLoss',
loss_bbox='GIoULoss',
loss_ld='KnowledgeDistillationKLDivLoss',
loss_ld_vlr='KnowledgeDistillationKLDivLoss',
loss_kd='KnowledgeDistillationKLDivLoss',
reg_max=16,
feat_in_chan=256,
nms=None,
nms_pre=1000,
cell_offset=0):
super(LDGFLHead, self).__init__(
conv_feat=conv_feat,
dgqp_module=dgqp_module,
num_classes=num_classes,
fpn_stride=fpn_stride,
prior_prob=prior_prob,
loss_class=loss_class,
loss_dfl=loss_dfl,
loss_bbox=loss_bbox,
reg_max=reg_max,
feat_in_chan=feat_in_chan,
nms=nms,
nms_pre=nms_pre,
cell_offset=cell_offset)
self.loss_ld = loss_ld
self.loss_kd = loss_kd
self.loss_ld_vlr = loss_ld_vlr
def forward(self, fpn_feats):
assert len(fpn_feats) == len(
self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride"
cls_logits_list = []
bboxes_reg_list = []
for stride, scale_reg, fpn_feat in zip(self.fpn_stride,
self.scales_regs, fpn_feats):
conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat)
cls_score = self.gfl_head_cls(conv_cls_feat)
bbox_pred = scale_reg(self.gfl_head_reg(conv_reg_feat))
if self.dgqp_module:
quality_score = self.dgqp_module(bbox_pred)
cls_score = F.sigmoid(cls_score) * quality_score
if not self.training:
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = paddle.shape(cls_score)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
bbox_pred = self.distribution_project(bbox_pred) * stride
bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
# NOTE: If keep_ratio=False and the image shape is a multiple of 32,
# batch_distance2bbox is called without max_shapes to speed up model
# prediction. If max_shapes is needed, pass inputs['im_shape'].
bbox_pred = batch_distance2bbox(
center_points, bbox_pred, max_shapes=None)
cls_logits_list.append(cls_score)
bboxes_reg_list.append(bbox_pred)
return (cls_logits_list, bboxes_reg_list)
def get_loss(self, gfl_head_outs, gt_meta, soft_label_list,
soft_targets_list):
cls_logits, bboxes_reg = gfl_head_outs
num_level_anchors = [
featmap.shape[-2] * featmap.shape[-1] for featmap in cls_logits
]
grid_cells_list = self._images_to_levels(gt_meta['grid_cells'],
num_level_anchors)
labels_list = self._images_to_levels(gt_meta['labels'],
num_level_anchors)
label_weights_list = self._images_to_levels(gt_meta['label_weights'],
num_level_anchors)
bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'],
num_level_anchors)
# vlr regions
vlr_regions_list = self._images_to_levels(gt_meta['vlr_regions'],
num_level_anchors)
num_total_pos = sum(gt_meta['pos_num'])
try:
num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
)) / paddle.distributed.get_world_size()
except:
num_total_pos = max(num_total_pos, 1)
loss_bbox_list, loss_dfl_list, loss_qfl_list, loss_ld_list, avg_factor = [], [], [], [], []
loss_ld_vlr_list, loss_kd_list = [], []
for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride, soft_targets,\
soft_label, vlr_region in zip(
cls_logits, bboxes_reg, grid_cells_list, labels_list,
label_weights_list, bbox_targets_list, self.fpn_stride, soft_targets_list,
soft_label_list, vlr_regions_list):
grid_cells = grid_cells.reshape([-1, 4])
cls_score = cls_score.transpose([0, 2, 3, 1]).reshape(
[-1, self.cls_out_channels])
bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
[-1, 4 * (self.reg_max + 1)])
soft_targets = soft_targets.transpose([0, 2, 3, 1]).reshape(
[-1, 4 * (self.reg_max + 1)])
soft_label = soft_label.transpose([0, 2, 3, 1]).reshape(
[-1, self.cls_out_channels])
# feature imitation (kept commented out):
# teacher_x = teacher_x.transpose([0, 2, 3, 1]).reshape([-1, 256])
# x = x.transpose([0, 2, 3, 1]).reshape([-1, 256])
bbox_targets = bbox_targets.reshape([-1, 4])
labels = labels.reshape([-1])
label_weights = label_weights.reshape([-1])
vlr_region = vlr_region.reshape([-1])
bg_class_ind = self.num_classes
pos_inds = paddle.nonzero(
paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
as_tuple=False).squeeze(1)
score = np.zeros(labels.shape)
remain_inds = (vlr_region > 0).nonzero()
if len(pos_inds) > 0:
pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0)
pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0)
pos_grid_cells = paddle.gather(grid_cells, pos_inds, axis=0)
pos_grid_cell_centers = self._grid_cells_to_center(
pos_grid_cells) / stride
weight_targets = F.sigmoid(cls_score.detach())
weight_targets = paddle.gather(
weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
pos_bbox_pred_corners)
pos_decode_bbox_targets = pos_bbox_targets / stride
bbox_iou = bbox_overlaps(
pos_decode_bbox_pred.detach().numpy(),
pos_decode_bbox_targets.detach().numpy(),
is_aligned=True)
score[pos_inds.numpy()] = bbox_iou
pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1])
pos_soft_targets = paddle.gather(soft_targets, pos_inds, axis=0)
soft_corners = pos_soft_targets.reshape([-1, self.reg_max + 1])
target_corners = bbox2distance(pos_grid_cell_centers,
pos_decode_bbox_targets,
self.reg_max).reshape([-1])
# regression loss
loss_bbox = paddle.sum(
self.loss_bbox(pos_decode_bbox_pred,
pos_decode_bbox_targets) * weight_targets)
# dfl loss
loss_dfl = self.loss_dfl(
pred_corners,
target_corners,
weight=weight_targets.expand([-1, 4]).reshape([-1]),
avg_factor=4.0)
# ld loss
loss_ld = self.loss_ld(
pred_corners,
soft_corners,
weight=weight_targets.expand([-1, 4]).reshape([-1]),
avg_factor=4.0)
loss_kd = self.loss_kd(
paddle.gather(
cls_score, pos_inds, axis=0),
paddle.gather(
soft_label, pos_inds, axis=0),
weight=paddle.gather(
label_weights, pos_inds, axis=0),
avg_factor=pos_inds.shape[0])
else:
loss_bbox = bbox_pred.sum() * 0
loss_dfl = bbox_pred.sum() * 0
loss_ld = bbox_pred.sum() * 0
loss_kd = bbox_pred.sum() * 0
weight_targets = paddle.to_tensor([0], dtype='float32')
if len(remain_inds) > 0:
neg_pred_corners = bbox_pred[remain_inds].reshape(
[-1, self.reg_max + 1])
neg_soft_corners = soft_targets[remain_inds].reshape(
[-1, self.reg_max + 1])
remain_targets = vlr_region[remain_inds]
loss_ld_vlr = self.loss_ld_vlr(
neg_pred_corners,
neg_soft_corners,
weight=remain_targets.expand([-1, 4]).reshape([-1]),
avg_factor=16.0)
else:
loss_ld_vlr = bbox_pred.sum() * 0
# qfl loss
score = paddle.to_tensor(score)
loss_qfl = self.loss_qfl(
cls_score, (labels, score),
weight=label_weights,
avg_factor=num_total_pos)
loss_bbox_list.append(loss_bbox)
loss_dfl_list.append(loss_dfl)
loss_qfl_list.append(loss_qfl)
loss_ld_list.append(loss_ld)
loss_ld_vlr_list.append(loss_ld_vlr)
loss_kd_list.append(loss_kd)
avg_factor.append(weight_targets.sum())
avg_factor = sum(avg_factor) # + 1e-6
try:
avg_factor_clone = avg_factor.clone()
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
if tmp_avg_factor is not None:
avg_factor = tmp_avg_factor
else:
avg_factor = avg_factor_clone
avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1)
except:
avg_factor = max(avg_factor.item(), 1)
if avg_factor <= 0:
loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
loss_bbox = paddle.to_tensor(
0, dtype='float32', stop_gradient=False)
loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
loss_ld = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
loss_ld_vlr = paddle.to_tensor(
0, dtype='float32', stop_gradient=False)
loss_kd = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
else:
losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list))
losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list))
loss_qfl = sum(loss_qfl_list)
loss_bbox = sum(losses_bbox)
loss_dfl = sum(losses_dfl)
loss_ld = sum(loss_ld_list)
loss_ld_vlr = sum(loss_ld_vlr_list)
loss_kd = sum(loss_kd_list)
loss_states = dict(
loss_qfl=loss_qfl,
loss_bbox=loss_bbox,
loss_dfl=loss_dfl,
loss_ld=loss_ld,
loss_ld_vlr=loss_ld_vlr,
loss_kd=loss_kd)
return loss_states
......@@ -31,6 +31,7 @@ from ppdet.utils.checkpoint import load_pretrain_weight
def build_slim_model(cfg, slim_cfg, mode='train'):
with open(slim_cfg) as f:
slim_load_cfg = yaml.load(f, Loader=yaml.Loader)
if mode != 'train' and slim_load_cfg['slim'] == 'Distill':
return cfg
......@@ -38,6 +39,9 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
if "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "FGD":
model = FGDDistillModel(cfg, slim_cfg)
elif "slim_method" in slim_load_cfg and slim_load_cfg[
'slim_method'] == "LD":
model = LDDistillModel(cfg, slim_cfg)
else:
model = DistillModel(cfg, slim_cfg)
cfg['model'] = model
......
......@@ -39,6 +39,7 @@ class DistillModel(nn.Layer):
load_pretrain_weight(self.student_model, cfg.pretrain_weights)
slim_cfg = load_config(slim_cfg)
self.teacher_model = create(slim_cfg.architecture)
self.distill_loss = create(slim_cfg.distill_loss)
logger.debug('Load teacher model pretrain_weights:{}'.format(
......@@ -488,3 +489,144 @@ class FGDFeatureLoss(nn.Layer):
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
return loss
class LDDistillModel(nn.Layer):
def __init__(self, cfg, slim_cfg):
super(LDDistillModel, self).__init__()
self.student_model = create(cfg.architecture)
logger.debug('Load student model pretrain_weights:{}'.format(
cfg.pretrain_weights))
load_pretrain_weight(self.student_model, cfg.pretrain_weights)
slim_cfg = load_config(slim_cfg)  # the slim config overrides the student cfg to describe the teacher
self.teacher_model = create(slim_cfg.architecture)
logger.debug('Load teacher model pretrain_weights:{}'.format(
slim_cfg.pretrain_weights))
load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)
for param in self.teacher_model.parameters():
param.trainable = False
def parameters(self):
return self.student_model.parameters()
def forward(self, inputs):
if self.training:
with paddle.no_grad():
t_body_feats = self.teacher_model.backbone(inputs)
t_neck_feats = self.teacher_model.neck(t_body_feats)
t_head_outs = self.teacher_model.head(t_neck_feats)
#student_loss = self.student_model(inputs)
s_body_feats = self.student_model.backbone(inputs)
s_neck_feats = self.student_model.neck(s_body_feats)
s_head_outs = self.student_model.head(s_neck_feats)
soft_label_list = t_head_outs[0]
soft_targets_list = t_head_outs[1]
student_loss = self.student_model.head.get_loss(
s_head_outs, inputs, soft_label_list, soft_targets_list)
total_loss = paddle.add_n(list(student_loss.values()))
student_loss['loss'] = total_loss
return student_loss
else:
return self.student_model(inputs)
@register
class KnowledgeDistillationKLDivLoss(nn.Layer):
"""Loss function for knowledge distilling using KL divergence.
Args:
reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
loss_weight (float): Loss weight of current loss.
T (int): Temperature for distillation.
"""
def __init__(self, reduction='mean', loss_weight=1.0, T=10):
super(KnowledgeDistillationKLDivLoss, self).__init__()
assert reduction in ('none', 'mean', 'sum')
assert T >= 1
self.reduction = reduction
self.loss_weight = loss_weight
self.T = T
def knowledge_distillation_kl_div_loss(self,
pred,
soft_label,
T,
detach_target=True):
r"""Loss function for knowledge distilling using KL divergence.
Args:
pred (Tensor): Predicted logits with shape (N, n + 1).
soft_label (Tensor): Target logits with shape (N, n + 1).
T (int): Temperature for distillation.
detach_target (bool): Whether to detach soft_label from the
computation graph. Default: True.
Returns:
Tensor: Loss tensor with shape (N,).
"""
assert pred.shape == soft_label.shape
target = F.softmax(soft_label / T, axis=1)
if detach_target:
target = target.detach()
kd_loss = F.kl_div(
F.log_softmax(
pred / T, axis=1), target, reduction='none').mean(1) * (T * T)
return kd_loss
def forward(self,
pred,
soft_label,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (Tensor): Predicted logits with shape (N, n + 1).
soft_label (Tensor): Target logits with shape (N, n + 1).
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (reduction_override
if reduction_override else self.reduction)
loss_kd_out = self.knowledge_distillation_kl_div_loss(
pred, soft_label, T=self.T)
if weight is not None:
loss_kd_out = weight * loss_kd_out
if avg_factor is None:
if reduction == 'none':
loss = loss_kd_out
elif reduction == 'mean':
loss = loss_kd_out.mean()
elif reduction == 'sum':
loss = loss_kd_out.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss_kd_out.sum() / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError(
'avg_factor can not be used with reduction="sum"')
loss_kd = self.loss_weight * loss
return loss_kd
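# Minimal usage sketch (illustrative only, not part of the library): how the
# head above applies this loss to per-edge regression distributions, where
# each edge is a distribution over reg_max + 1 bins.
def _ld_kd_loss_example():
    reg_max = 16
    pred_corners = paddle.rand([8, reg_max + 1])   # student distribution logits
    soft_corners = paddle.rand([8, reg_max + 1])   # teacher distribution logits
    loss_ld = KnowledgeDistillationKLDivLoss(loss_weight=0.25, T=10)
    # mean KL divergence over the 8 rows, scaled by T * T and loss_weight
    return loss_ld(pred_corners, soft_corners)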