# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr

from ppdet.core.workspace import register, create, load_config
from ppdet.modeling import ops
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)


class DistillModel(nn.Layer):
    """
    Build common distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(DistillModel, self).__init__()

        self.student_model = create(cfg.architecture)
        logger.debug('Load student model pretrain_weights:{}'.format(
            cfg.pretrain_weights))
        load_pretrain_weight(self.student_model, cfg.pretrain_weights)

        slim_cfg = load_config(slim_cfg)
        self.teacher_model = create(slim_cfg.architecture)
        self.distill_loss = create(slim_cfg.distill_loss)
        logger.debug('Load teacher model pretrain_weights:{}'.format(
            slim_cfg.pretrain_weights))
        load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)

        for param in self.teacher_model.parameters():
            param.trainable = False

    def parameters(self):
        return self.student_model.parameters()

    def forward(self, inputs):
        if self.training:
            teacher_loss = self.teacher_model(inputs)
            student_loss = self.student_model(inputs)
            loss = self.distill_loss(self.teacher_model, self.student_model)
            student_loss['distill_loss'] = loss
            student_loss['teacher_loss'] = teacher_loss['loss']
            student_loss['loss'] += student_loss['distill_loss']
            return student_loss
        else:
            return self.student_model(inputs)
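
# A minimal usage sketch for DistillModel (the config paths below are
# hypothetical; in practice the model is built by tools/train.py from a
# --slim_config, this only illustrates the wiring):
#
#   cfg = load_config('configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml')
#   model = DistillModel(cfg, 'configs/slim/distill/yolov3_distill.yml')
#   out = model(batch)  # training: student losses + 'distill_loss' + 'teacher_loss'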


class FGDDistillModel(nn.Layer):
    """
    Build FGD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(FGDDistillModel, self).__init__()

        self.is_inherit = True
        # build student model before load slim config
        self.student_model = create(cfg.architecture)
        self.arch = cfg.architecture
        stu_pretrain = cfg['pretrain_weights']
        slim_cfg = load_config(slim_cfg)
        self.teacher_cfg = slim_cfg
        self.loss_cfg = slim_cfg
        tea_pretrain = self.teacher_cfg['pretrain_weights']

        self.teacher_model = create(self.teacher_cfg.architecture)
        self.teacher_model.eval()

        for param in self.teacher_model.parameters():
            param.trainable = False

        if 'pretrain_weights' in cfg and stu_pretrain:
            if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
                load_pretrain_weight(self.student_model,
                                     self.teacher_cfg.pretrain_weights)
                logger.debug(
                    "Inheriting: loading teacher pretrain weights into the student model first.")

            load_pretrain_weight(self.student_model, stu_pretrain)

        if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
            load_pretrain_weight(self.teacher_model,
                                 self.teacher_cfg.pretrain_weights)

        self.fgd_loss_dic = self.build_loss(
            self.loss_cfg.distill_loss,
            name_list=self.loss_cfg['distill_loss_name'])

    def build_loss(self,
                   cfg,
                   name_list=('neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
                              'neck_f_0')):
        # one distill-loss instance per feature level, keyed by name
        loss_func = dict()
        for k in name_list:
            loss_func[k] = create(cfg)
        return loss_func
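
    # With the default name_list, build_loss returns e.g.
    # {'neck_f_4': <loss>, ..., 'neck_f_0': <loss>} - one independently
    # created distill-loss instance per neck level. The keys are only names;
    # insertion order is what pairs each loss with a neck output in forward().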

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)

            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            loss_dict = {}
            for idx, k in enumerate(self.fgd_loss_dic):
                loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
                                                    t_neck_feats[idx], inputs)
            if self.arch == "RetinaNet":
                loss = self.student_model.head(s_neck_feats, inputs)
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    s_neck_feats, self.student_model.export_post_process)
                loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
                total_loss = paddle.add_n(list(loss_gfl.values()))
                loss = {}
                loss.update(loss_gfl)
                loss.update({'loss': total_loss})
            else:
                raise ValueError(f"Unsupported model {self.arch}")
            for k in loss_dict:
                loss['loss'] += loss_dict[k]
                loss[k] = loss_dict[k]
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    neck_feats, self.student_model.export_post_process)
                scale_factor = inputs['scale_factor']
                bboxes, bbox_num = self.student_model.head.post_process(
                    head_outs,
                    scale_factor,
                    export_nms=self.student_model.export_nms)
                return {'bbox': bboxes, 'bbox_num': bbox_num}
            else:
                raise ValueError(f"Unsupported model {self.arch}")
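
# A rough usage sketch for FGDDistillModel (yml paths are hypothetical; the
# slim config is expected to provide architecture, pretrain_weights,
# distill_loss and distill_loss_name entries):
#
#   cfg = load_config('configs/retinanet/retinanet_r50_fpn_1x_coco.yml')
#   model = FGDDistillModel(cfg, 'configs/slim/distill/retinanet_fgd.yml')
#   losses = model(batch)  # training: head loss + one FGD loss per neck level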


@register
class DistillYOLOv3Loss(nn.Layer):
    def __init__(self, weight=1000):
        super(DistillYOLOv3Loss, self).__init__()
        self.weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
        loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
        loss_w = paddle.abs(sw - tw)
        loss_h = paddle.abs(sh - th)
        loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
        return weighted_loss

    def obj_weighted_cls(self, scls, tcls, tobj):
        loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_loss(self, sobj, tobj):
        obj_mask = paddle.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True
        loss = paddle.mean(
            ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    def forward(self, teacher_model, student_model):
        teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
        student_distill_pairs = student_model.yolo_head.loss.distill_pairs
        distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
        for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
            distill_reg_loss.append(
                self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2],
                                      s_pair[3], t_pair[0], t_pair[1],
                                      t_pair[2], t_pair[3], t_pair[4]))
            distill_cls_loss.append(
                self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
            distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
        distill_reg_loss = paddle.add_n(distill_reg_loss)
        distill_cls_loss = paddle.add_n(distill_cls_loss)
        distill_obj_loss = paddle.add_n(distill_obj_loss)
        loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                ) * self.weight
        return loss
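
# Sketch of the slim-config entries that would select DistillYOLOv3Loss via
# the register/create machinery above (values are assumptions, not a
# verified config file):
#
#   slim: Distill
#   distill_loss: DistillYOLOv3Loss
#   DistillYOLOv3Loss:
#     weight: 1000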


def parameter_init(mode="kaiming", value=0.):
    if mode == "kaiming":
        weight_attr = paddle.nn.initializer.KaimingUniform()
    elif mode == "constant":
        weight_attr = paddle.nn.initializer.Constant(value=value)
    else:
        weight_attr = paddle.nn.initializer.KaimingUniform()

    weight_init = ParamAttr(initializer=weight_attr)
    return weight_init
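
# Example usage, mirroring FGDFeatureLoss below:
#
#   kaiming_init = parameter_init("kaiming")
#   zeros_init = parameter_init("constant", 0.0)
#
# Note that unrecognized mode strings silently fall back to KaimingUniform.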


@register
class FGDFeatureLoss(nn.Layer):
    """
    The code is reference from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
    Paddle version of `Focal and Global Knowledge Distillation for Detectors`
   
    Args:
U
user1018 已提交
237 238 239 240 241 242 243 244
        student_channels(int): Number of channels in the student's feature map.
        teacher_channels(int): Number of channels in the teacher's feature map. 
        temp (float, optional): Temperature coefficient. Defaults to 0.5.
        name (str): the loss name of the layer
        alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
        beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
        gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
        lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
245 246
    """

    def __init__(
            self,
            student_channels,
            teacher_channels,
            name=None,
            temp=0.5,
            alpha_fgd=0.001,
            beta_fgd=0.0005,
            gamma_fgd=0.001,
            lambda_fgd=0.000005):
        super(FGDFeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        kaiming_init = parameter_init("kaiming")
        zeros_init = parameter_init("constant", 0.0)

        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=kaiming_init)
        else:
            self.align = None

        self.conv_mask_s = nn.Conv2D(
            teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)
        self.conv_mask_t = nn.Conv2D(
            teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2D(
                teacher_channels,
                teacher_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels // 2,
                teacher_channels,
                kernel_size=1,
                weight_attr=zeros_init))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2D(
                teacher_channels,
                teacher_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels // 2,
                teacher_channels,
                kernel_size=1,
                weight_attr=zeros_init))

    def gc_block(self, feature, t=0.5):
        """Calculate spatial and channel attention of a feature map.
        Args:
            feature (Tensor): feature map of shape [N, C, H, W].
            t (float): temperature coefficient, defaults to 0.5.
        """
        shape = paddle.shape(feature)
        N, C, H, W = shape

        _f = paddle.abs(feature)
        # spatial attention: channel-mean of |feature|, softmax over H*W,
        # rescaled by H*W
        s_map = paddle.reshape(
            paddle.mean(
                _f, axis=1, keepdim=True) / t, [N, -1])
        s_map = F.softmax(s_map, axis=1, dtype="float32") * H * W
        s_attention = paddle.reshape(s_map, [N, H, W])

        # channel attention: spatial-mean of |feature|, softmax over C,
        # rescaled by C
        c_map = paddle.mean(
            paddle.mean(
                _f, axis=2, keepdim=False), axis=2, keepdim=False)
        c_attention = F.softmax(c_map / t, axis=1, dtype="float32") * C
        return s_attention, c_attention
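
    # Shape sketch for gc_block (sizes assumed for illustration): a feature
    # of shape [2, 256, 32, 32] yields a spatial attention of shape
    # [2, 32, 32] and a channel attention of shape [2, 256]:
    #
    #   s_att, c_att = self.gc_block(paddle.rand([2, 256, 32, 32]), t=0.5)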

    def spatial_pool(self, x, in_type):
        batch, channel, height, width = x.shape
        input_x = x
        # [N, C, H * W]
        input_x = paddle.reshape(input_x, [batch, channel, height * width])
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, axis=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = paddle.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = paddle.reshape(context, [batch, channel, 1, 1])

        return context

    def get_mask_loss(self, C_s, C_t, S_s, S_t):
        mask_loss = paddle.sum(paddle.abs((C_s - C_t))) / len(C_s) + paddle.sum(
            paddle.abs((S_s - S_t))) / len(S_s)
        return mask_loss

    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s,
                     S_t):
        Mask_fg = Mask_fg.unsqueeze(axis=1)
        Mask_bg = Mask_bg.unsqueeze(axis=1)

        C_t = C_t.unsqueeze(axis=-1)
        C_t = C_t.unsqueeze(axis=-1)

        S_t = S_t.unsqueeze(axis=1)

        fea_t = paddle.multiply(preds_T, paddle.sqrt(S_t))
        fea_t = paddle.multiply(fea_t, paddle.sqrt(C_t))
        fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
        bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))

        fea_s = paddle.multiply(preds_S, paddle.sqrt(S_t))
        fea_s = paddle.multiply(fea_s, paddle.sqrt(C_t))
        fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
        bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))

        fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
        bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)

        return fg_loss, bg_loss
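
    # In effect, get_fea_loss computes attention-masked MSE terms (notation
    # ours, following the FGD paper, not identifiers in this file):
    #   fg_loss = || (S - T) * sqrt(S_t * C_t * Mask_fg) ||^2 / N
    #   bg_loss = || (S - T) * sqrt(S_t * C_t * Mask_bg) ||^2 / N
    # where S_t/C_t are the teacher's spatial/channel attentions.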

    def get_rela_loss(self, preds_S, preds_T):
        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)

        out_s = preds_S
        out_t = preds_T

        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s

        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t

        rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)

        return rela_loss

    def mask_value(self, mask, xl, xr, yl, yr, value):
        mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
        return mask

    def forward(self, preds_S, preds_T, inputs):
        """Forward function.
        Args:
            preds_S(Tensor): Bs*C*H*W, student's feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map
            inputs: The inputs with gt bbox and input shape info.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:], \
            f'The shape of Student feature {preds_S.shape} and Teacher feature {preds_T.shape} should be the same.'
        gt_bboxes = inputs['gt_bbox']
        assert len(gt_bboxes) == preds_S.shape[0], \
            "the number of gt_bbox entries must match the feature batch size"

        # select the indices of samples that contain at least one gt box
        index_gt = []
        for i in range(len(gt_bboxes)):
            if gt_bboxes[i].size > 2:
                index_gt.append(i)
        index_gt_t = paddle.to_tensor(index_gt)  # to tensor
        preds_S = paddle.index_select(preds_S, index_gt_t)
        preds_T = paddle.index_select(preds_T, index_gt_t)
        assert preds_S.shape == preds_T.shape, \
            "preds_S and preds_T must have the same shape after selection"

        img_metas_tmp = [{
            'img_shape': inputs['im_shape'][i]
        } for i in range(inputs['im_shape'].shape[0])]
        img_metas = [img_metas_tmp[c] for c in index_gt]
        gt_bboxes = [gt_bboxes[c] for c in index_gt]
        assert len(gt_bboxes) == len(img_metas), \
            "gt_bboxes and img_metas must have the same length"
        assert len(gt_bboxes) == preds_T.shape[0], \
            "the number of gt_bboxes must match the feature batch size"

        if self.align is not None:
            preds_S = self.align(preds_S)

        N, C, H, W = preds_S.shape

        S_attention_t, C_attention_t = self.gc_block(preds_T, self.temp)
        S_attention_s, C_attention_s = self.gc_block(preds_S, self.temp)

        Mask_fg = paddle.zeros(S_attention_t.shape)
        Mask_bg = paddle.ones_like(S_attention_t)
        one_tmp = paddle.ones([*S_attention_t.shape[1:]])
        zero_tmp = paddle.zeros([*S_attention_t.shape[1:]])
        wmin, wmax, hmin, hmax, area = [], [], [], [], []
        for i in range(N):
            new_boxxes = paddle.ones_like(gt_bboxes[i])
            new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][
                1] * W
            new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][
                1] * W
            new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][
                0] * H
            new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][
                0] * H
            zero = paddle.zeros_like(new_boxxes[:, 0], dtype="int32")
            ones = paddle.ones_like(new_boxxes[:, 2], dtype="int32")
            wmin.append(
                paddle.cast(paddle.floor(new_boxxes[:, 0]), "int32").maximum(
                    zero))
            wmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 2]), "int32"))
            hmin.append(
                paddle.cast(paddle.floor(new_boxxes[:, 1]), "int32").maximum(
                    zero))
            hmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 3]), "int32"))

            area = 1.0 / (
                hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
                    wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))
            for j in range(len(gt_bboxes[i])):
                Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
                                             hmax[i][j] + 1, wmin[i][j],
                                             wmax[i][j] + 1, area[0][j])
            Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)

            if paddle.sum(Mask_bg[i]):
                Mask_bg[i] /= paddle.sum(Mask_bg[i])

        fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
                                             C_attention_s, C_attention_t,
                                             S_attention_s, S_attention_t)
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t,
                                       S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
               + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

        return loss
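

# A minimal smoke-test sketch for FGDFeatureLoss (shapes and the inputs dict
# are assumptions inferred from forward() above, not a verified test):
#
#   loss_fn = FGDFeatureLoss(student_channels=256, teacher_channels=256)
#   preds_S = paddle.rand([2, 256, 32, 32])
#   preds_T = paddle.rand([2, 256, 32, 32])
#   inputs = {
#       'gt_bbox': [paddle.to_tensor([[10., 10., 100., 120.]]),
#                   paddle.to_tensor([[30., 40., 200., 180.]])],
#       'im_shape': paddle.to_tensor([[320., 320.], [320., 320.]]),
#   }
#   loss = loss_fn(preds_S, preds_T, inputs)  # scalar Tensor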