diff --git a/configs/retinanet/README.md b/configs/retinanet/README.md
index 8a9ec59b9ca377576c0f0b5f8ec0d01ff7684a17..8490dd907ddc4443d717e62076884a19e472d401 100644
--- a/configs/retinanet/README.md
+++ b/configs/retinanet/README.md
@@ -5,6 +5,10 @@
 | Backbone | Model | imgs/GPU | lr schedule | FPS | Box AP | download | config |
 | ------------ | --------- | -------- | ----------- | --- | ------ | ---------- | ----------- |
 | ResNet50-FPN | RetinaNet | 2 | 1x | --- | 37.5 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_1x_coco.pdparams) | [config](./retinanet_r50_fpn_1x_coco.yml) |
+| ResNet101-FPN | RetinaNet | 2 | 2x | --- | 40.6 | [model](https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams) | [config](./retinanet_r101_fpn_2x_coco.yml) |
+| ResNet50-FPN | RetinaNet | 2 | 2x | --- | 40.8 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r101_distill_r50_2x_coco.pdparams) | [config](./retinanet_r50_fpn_2x_coco.yml)/[slim_config](../slim/distill/retinanet_resnet101_coco_distill.yml) |
+
+
 **Notes:**
 
 - All above models are trained on COCO train2017 with 8 GPUs and evaluated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
diff --git a/configs/retinanet/_base_/optimizer_2x.yml b/configs/retinanet/_base_/optimizer_2x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..61841433417b9fcc6f29a6c71a72ba23406b55ad
--- /dev/null
+++ b/configs/retinanet/_base_/optimizer_2x.yml
@@ -0,0 +1,19 @@
+epoch: 24
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [16, 22]
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
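As a quick reading aid for the schedule above, here is a tiny plain-Python sketch (not part of the patch; values copied from `optimizer_2x.yml`) of the post-warmup learning rate that `PiecewiseDecay` with `gamma: 0.1` and `milestones: [16, 22]` yields over the 24 epochs; the 500-step `LinearWarmup` additionally ramps the rate from `0.01 * 0.001` up to `0.01` at the very start of training.

```python
# Post-warmup LR implied by optimizer_2x.yml: 0.01 until epoch 16,
# then 0.001 until epoch 22, then 0.0001 for the remaining epochs.
base_lr, gamma, milestones = 0.01, 0.1, [16, 22]

def lr_at_epoch(epoch):
    # Multiply base_lr by gamma once for every milestone already passed.
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

print([lr_at_epoch(e) for e in (0, 15, 16, 21, 22, 23)])
# -> [0.01, 0.01, 0.001, 0.001, 0.0001, 0.0001] (up to float rounding)
```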
diff --git a/configs/retinanet/_base_/retinanet_r101_fpn.yml b/configs/retinanet/_base_/retinanet_r101_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ae5595769d940c2ecb5b857fdc8970da76d572ab
--- /dev/null
+++ b/configs/retinanet/_base_/retinanet_r101_fpn.yml
@@ -0,0 +1,57 @@
+architecture: RetinaNet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
+
+RetinaNet:
+  backbone: ResNet
+  neck: FPN
+  head: RetinaHead
+
+ResNet:
+  depth: 101
+  variant: b
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1,2,3]
+  num_stages: 4
+
+FPN:
+  out_channel: 256
+  spatial_scales: [0.125, 0.0625, 0.03125]
+  extra_stage: 2
+  has_extra_convs: true
+  use_c5: false
+
+RetinaHead:
+  conv_feat:
+    name: RetinaFeat
+    feat_in: 256
+    feat_out: 256
+    num_convs: 4
+    norm_type: null
+    use_dcn: false
+  anchor_generator:
+    name: RetinaAnchorGenerator
+    octave_base_scale: 4
+    scales_per_octave: 3
+    aspect_ratios: [0.5, 1.0, 2.0]
+    strides: [8.0, 16.0, 32.0, 64.0, 128.0]
+  bbox_assigner:
+    name: MaxIoUAssigner
+    positive_overlap: 0.5
+    negative_overlap: 0.4
+    allow_low_quality: true
+  loss_class:
+    name: FocalLoss
+    gamma: 2.0
+    alpha: 0.25
+    loss_weight: 1.0
+  loss_bbox:
+    name: SmoothL1Loss
+    beta: 0.0
+    loss_weight: 1.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.05
+    nms_threshold: 0.5
diff --git a/configs/retinanet/retinanet_r101_distill_r50_2x_coco.yml b/configs/retinanet/retinanet_r101_distill_r50_2x_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bb72cda8e99ac6a597ea5fc9b113378f7954bac3
--- /dev/null
+++ b/configs/retinanet/retinanet_r101_distill_r50_2x_coco.yml
@@ -0,0 +1,9 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/retinanet_r50_fpn.yml',
+  '_base_/optimizer_2x.yml',
+  '_base_/retinanet_reader.yml'
+]
+
+weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams
diff --git a/configs/retinanet/retinanet_r101_fpn_2x_coco.yml b/configs/retinanet/retinanet_r101_fpn_2x_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..854def4ad82ebcc48d904e665b856ec47655d167
--- /dev/null
+++ b/configs/retinanet/retinanet_r101_fpn_2x_coco.yml
@@ -0,0 +1,9 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/retinanet_r101_fpn.yml',
+  '_base_/optimizer_2x.yml',
+  '_base_/retinanet_reader.yml'
+]
+
+weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams
diff --git a/configs/slim/distill/README.md b/configs/slim/distill/README.md
index da5795764cec02ea384f8e063f918b56b4f2b9bb..a08bf35fc350cd7c4284bb00ffdf641d80de9798 100644
--- a/configs/slim/distill/README.md
+++ b/configs/slim/distill/README.md
@@ -5,6 +5,19 @@
 As a training target for object detection, the COCO dataset is considerably harder, which means the teacher network predicts many more background bboxes. Using the teacher's predictions directly as the student's `soft label` therefore causes a severe class-imbalance problem. Solving it requires a new approach; for the detailed background see the paper [Object detection at 200 Frames Per Second](https://arxiv.org/abs/1805.06361).
 
 To decide what to distill, we first need to locate the `x,y,w,h,cls,objness` tensors produced by the student and teacher networks, and use the teacher's results to guide the student's training. See the [code](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/ppdet/slim/distill.py) for the concrete implementation.
+
+## FGD Model Distillation
+
+FGD, short for [Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837v1), is a distillation method for object detection with two parts, `Focal` and `Global`. `Focal` distillation separates the foreground and background of the image, so that the student attends separately to the key pixels of the teacher's foreground and background features; `Global` distillation rebuilds the relations between different pixels and transfers them from teacher to student, compensating for the global information lost in `Focal` distillation. Experiments show that FGD effectively improves accuracy for both anchor-based and anchor-free detectors.
+In PaddleDetection we implement FGD and verify it on RetinaNet; the results are as follows:
+| algorithm | model | AP | download |
+|:-:| :-: | :-: | :-:|
+| retinanet_r101_fpn_2x | teacher | 40.6 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams) |
+| retinanet_r50_fpn_1x | student | 37.5 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_1x_coco.pdparams) |
+| retinanet_r50_fpn_2x + FGD | student | 40.8 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams) |
+
+
+
 ## Citations
 ```
 @article{mehta2018object,
@@ -15,4 +28,12 @@ COCO数据集作为目标检测任务的训练目标难度更大,意味着teac
   archivePrefix={arXiv},
   primaryClass={cs.CV}
 }
+
+@inproceedings{yang2022focal,
+  title={Focal and global knowledge distillation for detectors},
+  author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={4643--4652},
+  year={2022}
+}
 ```
diff --git a/configs/slim/distill/retinanet_resnet101_coco_distill.yml b/configs/slim/distill/retinanet_resnet101_coco_distill.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f55b1b0d2590b1c124c4f3e53ed418538aaef75f
--- /dev/null
+++ b/configs/slim/distill/retinanet_resnet101_coco_distill.yml
@@ -0,0 +1,18 @@
+_BASE_: [
+  '../../retinanet/retinanet_r101_fpn_2x_coco.yml',
+]
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams
+
+slim: Distill
+slim_method: FGD
+distill_loss: FGDFeatureLoss
+
+FGDFeatureLoss:
+  student_channels: 256
+  teacher_channels: 256
+  temp: 0.5
+  alpha_fgd: 0.001
+  beta_fgd: 0.0005
+  gamma_fgd: 0.0005
+  lambda_fgd: 0.000005
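To show how the new configs are meant to fit together, here is a minimal usage sketch (not part of the patch): `build_slim_model`, whose FGD branch is added in `ppdet/slim/__init__.py` further below, reads `slim: Distill` and `slim_method: FGD` from this slim config and wraps the R50 student in `FGDDistillModel` with the R101 teacher. The student config path is the one referenced in the RetinaNet README table above and is assumed to be present in the repo.

```python
# Hypothetical wiring check for the FGD distill configs (paths assumed).
from ppdet.core.workspace import load_config
from ppdet.slim import build_slim_model

# Student config referenced by the README table above.
cfg = load_config('configs/retinanet/retinanet_r50_fpn_2x_coco.yml')

# The slim config added in this patch supplies the teacher and FGDFeatureLoss.
cfg = build_slim_model(
    cfg, 'configs/slim/distill/retinanet_resnet101_coco_distill.yml',
    mode='train')

print(type(cfg['model']).__name__)  # expected: FGDDistillModel
```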
diff --git a/ppdet/optimizer/optimizer.py b/ppdet/optimizer/optimizer.py
index fed47e85a85844c5070e67799d2d0c94c77e0d25..bcba182d3229c74206b8623b003d0e3e72160506 100644
--- a/ppdet/optimizer/optimizer.py
+++ b/ppdet/optimizer/optimizer.py
@@ -362,7 +362,8 @@ class OptimizerBuilder():
         else:
             params = model.parameters()
 
+        train_params = [param for param in params if param.trainable is True]
         return op(learning_rate=learning_rate,
-                  parameters=params,
+                  parameters=train_params,
                   grad_clip=grad_clip,
                   **optim_args)
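A toy sketch (not from the patch) of what this filtering buys: parameters frozen with `trainable = False`, such as the FGD teacher's weights, no longer reach the optimizer.

```python
import paddle
import paddle.nn as nn

# Stand-in for a model that contains frozen (teacher) parameters.
layer = nn.Linear(4, 4)
layer.bias.trainable = False  # pretend this is a frozen teacher weight

# Same filtering as the OptimizerBuilder change above.
train_params = [p for p in layer.parameters() if p.trainable is True]
opt = paddle.optimizer.Momentum(
    learning_rate=0.01, momentum=0.9, parameters=train_params)

print(len(list(layer.parameters())), len(train_params))  # 2 1
```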
diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py
index 5347d60467944f381b7c6d4514169dc82a0bd648..81ced2dd9f50534e1a3794bc2821d4313c5e91c2 100644
--- a/ppdet/slim/__init__.py
+++ b/ppdet/slim/__init__.py
@@ -35,7 +35,11 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
         return cfg
 
     if slim_load_cfg['slim'] == 'Distill':
-        model = DistillModel(cfg, slim_cfg)
+        if "slim_method" in slim_load_cfg and slim_load_cfg[
+                'slim_method'] == "FGD":
+            model = FGDDistillModel(cfg, slim_cfg)
+        else:
+            model = DistillModel(cfg, slim_cfg)
         cfg['model'] = model
         cfg['slim_type'] = cfg.slim
     elif slim_load_cfg['slim'] == 'OFA':
diff --git a/ppdet/slim/distill.py b/ppdet/slim/distill.py
index b808553dd0c0b6a8285b8090385ac6e1cc4b8e69..a0ec9f7b0fad8334d38119d11b8b78e18cbbe133 100644
--- a/ppdet/slim/distill.py
+++ b/ppdet/slim/distill.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle import ParamAttr
 
 from ppdet.core.workspace import register, create, load_config
 from ppdet.modeling import ops
@@ -63,6 +64,95 @@ class DistillModel(nn.Layer):
         return self.student_model(inputs)
 
 
+class FGDDistillModel(nn.Layer):
+    """
+    Build FGD distill model.
+    Args:
+        cfg: The student config.
+        slim_cfg: The teacher and distill config.
+    """
+
+    def __init__(self, cfg, slim_cfg):
+        super(FGDDistillModel, self).__init__()
+        self.student_cfg = cfg
+        slim_cfg = load_config(slim_cfg)
+        self.teacher_cfg = slim_cfg
+        self.loss_cfg = slim_cfg
+        self.is_loaded_weights = True
+        self.is_inherit = True
+
+        self.student_model = create(self.student_cfg.architecture)
+
+        self.teacher_model = create(self.teacher_cfg.architecture)
+        self.teacher_model.eval()
+
+        for param in self.teacher_model.parameters():
+            param.trainable = False
+
+        if 'pretrain_weights' in self.student_cfg and self.student_cfg.pretrain_weights:
+            if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
+                self._load_pretrain_weights(self.student_model,
+                                            self.teacher_cfg.pretrain_weights)
+                print("loading teacher weights to student model!")
+
+            self._load_pretrain_weights(self.student_model,
+                                        self.student_cfg.pretrain_weights)
+            print("loading student model Done")
+
+        if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
+            self._load_pretrain_weights(self.teacher_model,
+                                        self.teacher_cfg.pretrain_weights)
+            print("loading teacher model Done")
+
+        self.fgd_loss_dic = self.build_loss(self.loss_cfg.distill_loss)
+
+    def _load_pretrain_weights(self, model, weights):
+        if self.is_loaded_weights:
+            return
+        self.start_epoch = 0
+        load_pretrain_weight(model, weights)
+        logger.debug("Load weights {} to start training".format(weights))
+
+    def build_loss(self,
+                   cfg,
+                   name_list=[
+                       'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
+                       'neck_f_0'
+                   ]):
+        loss_func = dict()
+        for idx, k in enumerate(name_list):
+            loss_func[k] = create(cfg)
+        return loss_func
+
+    def forward(self, inputs):
+        if self.training:
+            s_body_feats = self.student_model.backbone(inputs)
+            s_neck_feats = self.student_model.neck(s_body_feats)
+
+            with paddle.no_grad():
+                t_body_feats = self.teacher_model.backbone(inputs)
+                t_neck_feats = self.teacher_model.neck(t_body_feats)
+
+            loss_dict = {}
+            for idx, k in enumerate(self.fgd_loss_dic):
+                loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
+                                                    t_neck_feats[idx], inputs)
+
+            loss = self.student_model.head(s_neck_feats, inputs)
+            for k in loss_dict:
+                loss['loss'] += loss_dict[k]
+                loss[k] = loss_dict[k]
+            return loss
+
+        else:
+            body_feats = self.student_model.backbone(inputs)
+            neck_feats = self.student_model.neck(body_feats)
+            head_outs = self.student_model.head(neck_feats)
+            bbox, bbox_num = self.student_model.head.post_process(
+                head_outs, inputs['im_shape'], inputs['scale_factor'])
+            return {'bbox': bbox, 'bbox_num': bbox_num}
+
+
 @register
 class DistillYOLOv3Loss(nn.Layer):
     def __init__(self, weight=1000):
@@ -107,3 +197,254 @@
         loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                 ) * self.weight
         return loss
+
+
+def parameter_init(mode="kaiming", value=0.):
+    if mode == "kaiming":
+        weight_attr = paddle.nn.initializer.KaimingUniform()
+    elif mode == "constant":
+        weight_attr = paddle.nn.initializer.Constant(value=value)
+    else:
+        weight_attr = paddle.nn.initializer.KaimingUniform()
+
+    weight_init = ParamAttr(initializer=weight_attr)
+    return weight_init
+
+
+@register
+class FGDFeatureLoss(nn.Layer):
+    """
+    The code is referenced from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
+    Paddle version of `Focal and Global Knowledge Distillation for Detectors`
+
+    Args:
+        student_channels(int): The number of channels in the student's FPN feature map. Default to 256.
+        teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256.
+        temp (float, optional): The temperature coefficient. Defaults to 0.5.
+        alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001
+        beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005
+        gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001
+        lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005
+    """
+
+    def __init__(self,
+                 student_channels=256,
+                 teacher_channels=256,
+                 temp=0.5,
+                 alpha_fgd=0.001,
+                 beta_fgd=0.0005,
+                 gamma_fgd=0.001,
+                 lambda_fgd=0.000005):
+        super(FGDFeatureLoss, self).__init__()
+        self.temp = temp
+        self.alpha_fgd = alpha_fgd
+        self.beta_fgd = beta_fgd
+        self.gamma_fgd = gamma_fgd
+        self.lambda_fgd = lambda_fgd
+
+        kaiming_init = parameter_init("kaiming")
+        zeros_init = parameter_init("constant", 0.0)
+
+        if student_channels != teacher_channels:
+            self.align = nn.Conv2D(
+                student_channels,
+                teacher_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                weight_attr=kaiming_init)
+            student_channels = teacher_channels
+        else:
+            self.align = None
+
+        self.conv_mask_s = nn.Conv2D(
+            student_channels, 1, kernel_size=1, weight_attr=kaiming_init)
+        self.conv_mask_t = nn.Conv2D(
+            teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)
+
+        self.stu_conv_block = nn.Sequential(
+            nn.Conv2D(
+                student_channels,
+                student_channels // 2,
+                kernel_size=1,
+                weight_attr=zeros_init),
+            nn.LayerNorm([student_channels // 2, 1, 1]),
+            nn.ReLU(),
+            nn.Conv2D(
+                student_channels // 2,
+                student_channels,
+                kernel_size=1,
+                weight_attr=zeros_init))
+        self.tea_conv_block = nn.Sequential(
+            nn.Conv2D(
+                teacher_channels,
+                teacher_channels // 2,
+                kernel_size=1,
+                weight_attr=zeros_init),
+            nn.LayerNorm([teacher_channels // 2, 1, 1]),
+            nn.ReLU(),
+            nn.Conv2D(
+                teacher_channels // 2,
+                teacher_channels,
+                kernel_size=1,
+                weight_attr=zeros_init))
+
+    def spatial_channel_attention(self, x, t=0.5):
+        shape = paddle.shape(x)
+        N, C, H, W = shape
+
+        _f = paddle.abs(x)
+        spatial_map = paddle.reshape(
+            paddle.mean(
+                _f, axis=1, keepdim=True) / t, [N, -1])
+        spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W
+        spatial_att = paddle.reshape(spatial_map, [N, H, W])
+
+        channel_map = paddle.mean(
+            paddle.mean(
+                _f, axis=2, keepdim=False), axis=2, keepdim=False)
+        channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C
+        return [spatial_att, channel_att]
+
+    def spatial_pool(self, x, mode="teacher"):
+        batch, channel, width, height = x.shape
+        x_copy = x
+        x_copy = paddle.reshape(x_copy, [batch, channel, height * width])
+        x_copy = x_copy.unsqueeze(1)
+        if mode.lower() == "student":
+            context_mask = self.conv_mask_s(x)
+        else:
+            context_mask = self.conv_mask_t(x)
+
+        context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
+        context_mask = F.softmax(context_mask, axis=2)
+        context_mask = context_mask.unsqueeze(-1)
+        context = paddle.matmul(x_copy, context_mask)
+        context = paddle.reshape(context, [batch, channel, 1, 1])
+
+        return context
+
+    def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
+                  tea_spatial_att):
+        def _func(a, b):
+            return paddle.sum(paddle.abs(a - b)) / len(a)
+
+        mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
+            stu_spatial_att, tea_spatial_att)
+
+        return mask_loss
+
+    def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg,
+                     tea_channel_att, tea_spatial_att):
+
+        Mask_fg = Mask_fg.unsqueeze(axis=1)
+        Mask_bg = Mask_bg.unsqueeze(axis=1)
+
+        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
+        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
+
+        tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)
+
+        fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
+        fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
+        fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
+        bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))
+
+        fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
+        fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
+        fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
+        bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))
+
+        fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
+        bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)
+
+        return fg_loss, bg_loss
+
+    def relation_loss(self, stu_feature, tea_feature):
+        context_s = self.spatial_pool(stu_feature, "student")
+        context_t = self.spatial_pool(tea_feature, "teacher")
+
+        out_s = stu_feature + self.stu_conv_block(context_s)
+        out_t = tea_feature + self.tea_conv_block(context_t)
+
+        rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
+
+        return rela_loss
+
+    def mask_value(self, mask, xl, xr, yl, yr, value):
+        mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
+        return mask
+
+    def forward(self, stu_feature, tea_feature, inputs):
+        """Forward function.
+        Args:
+            stu_feature(Tensor): Bs*C*H*W, student's feature map
+            tea_feature(Tensor): Bs*C*H*W, teacher's feature map
+            inputs: The inputs with gt bbox and input shape info.
+        """
+        assert stu_feature.shape[-2:] == tea_feature.shape[-2:], \
+            f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.'
+        assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
+        ), "ERROR! FGDFeatureLoss needs gt_bbox and im_shape as inputs."
+        gt_bboxes = inputs['gt_bbox']
+        ins_shape = [
+            inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
+        ]
+
+        if self.align is not None:
+            stu_feature = self.align(stu_feature)
+
+        N, C, H, W = stu_feature.shape
+
+        tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
+            tea_feature, self.temp)
+        stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
+            stu_feature, self.temp)
+
+        Mask_fg = paddle.zeros(tea_spatial_att.shape)
+        Mask_bg = paddle.ones_like(tea_spatial_att)
+        one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
+        zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
+        wmin, wmax, hmin, hmax, area = [], [], [], [], []
+
+        for i in range(N):
+            tmp_box = paddle.ones_like(gt_bboxes[i])
+            tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W
+            tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W
+            tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H
+            tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H
+
+            zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32")
+            ones = paddle.ones_like(tmp_box[:, 2], dtype="int32")
+            wmin.append(
+                paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
+            wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
+            hmin.append(
+                paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero))
+            hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32"))
+
+            area_recip = 1.0 / (
+                hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
+                    wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))
+
+            for j in range(len(gt_bboxes[i])):
+                Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
+                                             hmax[i][j] + 1, wmin[i][j],
+                                             wmax[i][j] + 1, area_recip[0][j])
+
+            Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
+
+            if paddle.sum(Mask_bg[i]):
+                Mask_bg[i] /= paddle.sum(Mask_bg[i])
+
+        fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg,
+                                             Mask_bg, tea_channel_att,
+                                             tea_spatial_att)
+        mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
+                                   stu_spatial_att, tea_spatial_att)
+        rela_loss = self.relation_loss(stu_feature, tea_feature)
+
+        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
+               + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
+
+        return loss
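Finally, a minimal smoke-test sketch (not part of the patch) of `FGDFeatureLoss` on a single FPN level; it assumes a PaddleDetection install with this patch applied, and the feature shapes, boxes, and image sizes below are made up.

```python
import paddle
from ppdet.slim.distill import FGDFeatureLoss

# Dummy student/teacher features for one FPN level (arbitrary shapes).
N, C, H, W = 2, 256, 32, 32
stu_feat = paddle.rand([N, C, H, W])
tea_feat = paddle.rand([N, C, H, W])

# forward() expects per-image gt boxes (x1, y1, x2, y2 in input pixels)
# under 'gt_bbox' and the network input height/width under 'im_shape'.
inputs = {
    'gt_bbox': [
        paddle.to_tensor([[10., 20., 100., 120.]]),
        paddle.to_tensor([[30., 40., 200., 180.]]),
    ],
    'im_shape': paddle.to_tensor([[256., 256.], [256., 256.]]),
}

loss_fn = FGDFeatureLoss(student_channels=256, teacher_channels=256, temp=0.5)
print(float(loss_fn(stu_feat, tea_feat, inputs)))
```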