From 1d867e8222a0ced695bb4c727c95457a67b277b9 Mon Sep 17 00:00:00 2001 From: user1018 <614803115@qq.com> Date: Fri, 19 Aug 2022 16:55:46 +0800 Subject: [PATCH] Update distill (#6693) --- ppdet/slim/distill.py | 257 ++++++++++++++++++++++++------------------ 1 file changed, 148 insertions(+), 109 deletions(-) diff --git a/ppdet/slim/distill.py b/ppdet/slim/distill.py index 808713ffe..11df132d6 100644 --- a/ppdet/slim/distill.py +++ b/ppdet/slim/distill.py @@ -135,7 +135,13 @@ class FGDDistillModel(nn.Layer): if self.arch == "RetinaNet": loss = self.student_model.head(s_neck_feats, inputs) elif self.arch == "PicoDet": - loss = self.student_model.get_loss() + head_outs = self.student_model.head( + s_neck_feats, self.student_model.export_post_process) + loss_gfl = self.student_model.head.get_loss(head_outs, inputs) + total_loss = paddle.add_n(list(loss_gfl.values())) + loss = {} + loss.update(loss_gfl) + loss.update({'loss': total_loss}) else: raise ValueError(f"Unsupported model {self.arch}") for k in loss_dict: @@ -151,7 +157,14 @@ class FGDDistillModel(nn.Layer): head_outs, inputs['im_shape'], inputs['scale_factor']) return {'bbox': bbox, 'bbox_num': bbox_num} elif self.arch == "PicoDet": - return self.student_model.head.get_pred() + head_outs = self.student_model.head( + neck_feats, self.student_model.export_post_process) + scale_factor = inputs['scale_factor'] + bboxes, bbox_num = self.student_model.head.post_process( + head_outs, + scale_factor, + export_nms=self.student_model.export_nms) + return {'bbox': bboxes, 'bbox_num': bbox_num} else: raise ValueError(f"Unsupported model {self.arch}") @@ -221,23 +234,26 @@ class FGDFeatureLoss(nn.Layer): Paddle version of `Focal and Global Knowledge Distillation for Detectors` Args: - student_channels(int): The number of channels in the student's FPN feature map. Default to 256. - teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256. - temp (float, optional): The temperature coefficient. Defaults to 0.5. - alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 - beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 - gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 - lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005 + student_channels(int): Number of channels in the student's feature map. + teacher_channels(int): Number of channels in the teacher's feature map. + temp (float, optional): Temperature coefficient. Defaults to 0.5. + name (str): the loss name of the layer + alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001 + beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005 + gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001 + lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005 """ - def __init__(self, - student_channels=256, - teacher_channels=256, - temp=0.5, - alpha_fgd=0.001, - beta_fgd=0.0005, - gamma_fgd=0.001, - lambda_fgd=0.000005): + def __init__( + self, + student_channels, + teacher_channels, + name=None, + temp=0.5, + alpha_fgd=0.001, + beta_fgd=0.0005, + gamma_fgd=0.001, + lambda_fgd=0.000005): super(FGDFeatureLoss, self).__init__() self.temp = temp self.alpha_fgd = alpha_fgd @@ -256,29 +272,27 @@ class FGDFeatureLoss(nn.Layer): stride=1, padding=0, weight_attr=kaiming_init) - student_channels = teacher_channels else: self.align = None self.conv_mask_s = nn.Conv2D( - student_channels, 1, kernel_size=1, weight_attr=kaiming_init) + teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) self.conv_mask_t = nn.Conv2D( teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) - - self.stu_conv_block = nn.Sequential( + self.channel_add_conv_s = nn.Sequential( nn.Conv2D( - student_channels, - student_channels // 2, + teacher_channels, + teacher_channels // 2, kernel_size=1, weight_attr=zeros_init), - nn.LayerNorm([student_channels // 2, 1, 1]), + nn.LayerNorm([teacher_channels // 2, 1, 1]), nn.ReLU(), nn.Conv2D( - student_channels // 2, - student_channels, + teacher_channels // 2, + teacher_channels, kernel_size=1, weight_attr=zeros_init)) - self.tea_conv_block = nn.Sequential( + self.channel_add_conv_t = nn.Sequential( nn.Conv2D( teacher_channels, teacher_channels // 2, @@ -292,69 +306,72 @@ class FGDFeatureLoss(nn.Layer): kernel_size=1, weight_attr=zeros_init)) - def spatial_channel_attention(self, x, t=0.5): - shape = paddle.shape(x) + def gc_block(self, feature, t=0.5): + """ + """ + shape = paddle.shape(feature) N, C, H, W = shape - _f = paddle.abs(x) - spatial_map = paddle.reshape( + _f = paddle.abs(feature) + s_map = paddle.reshape( paddle.mean( _f, axis=1, keepdim=True) / t, [N, -1]) - spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W - spatial_att = paddle.reshape(spatial_map, [N, H, W]) + s_map = F.softmax(s_map, axis=1, dtype="float32") * H * W + s_attention = paddle.reshape(s_map, [N, H, W]) - channel_map = paddle.mean( + c_map = paddle.mean( paddle.mean( _f, axis=2, keepdim=False), axis=2, keepdim=False) - channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C - return [spatial_att, channel_att] + c_attention = F.softmax(c_map / t, axis=1, dtype="float32") * C + return s_attention, c_attention - def spatial_pool(self, x, mode="teacher"): + def spatial_pool(self, x, in_type): batch, channel, width, height = x.shape - x_copy = x - x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) - x_copy = x_copy.unsqueeze(1) - if mode.lower() == "student": + input_x = x + # [N, C, H * W] + input_x = paddle.reshape(input_x, [batch, channel, height * width]) + # [N, 1, C, H * W] + input_x = input_x.unsqueeze(1) + # [N, 1, H, W] + if in_type == 0: context_mask = self.conv_mask_s(x) else: context_mask = self.conv_mask_t(x) - + # [N, 1, H * W] context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) + # [N, 1, H * W] context_mask = F.softmax(context_mask, axis=2) + # [N, 1, H * W, 1] context_mask = context_mask.unsqueeze(-1) - context = paddle.matmul(x_copy, context_mask) + # [N, 1, C, 1] + context = paddle.matmul(input_x, context_mask) + # [N, C, 1, 1] context = paddle.reshape(context, [batch, channel, 1, 1]) return context - def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, - tea_spatial_att): - def _func(a, b): - return paddle.sum(paddle.abs(a - b)) / len(a) - - mask_loss = _func(stu_channel_att, tea_channel_att) + _func( - stu_spatial_att, tea_spatial_att) - + def get_mask_loss(self, C_s, C_t, S_s, S_t): + mask_loss = paddle.sum(paddle.abs((C_s - C_t))) / len(C_s) + paddle.sum( + paddle.abs((S_s - S_t))) / len(S_s) return mask_loss - def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg, - tea_channel_att, tea_spatial_att): - + def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, + S_t): Mask_fg = Mask_fg.unsqueeze(axis=1) Mask_bg = Mask_bg.unsqueeze(axis=1) - tea_channel_att = tea_channel_att.unsqueeze(axis=-1) - tea_channel_att = tea_channel_att.unsqueeze(axis=-1) + C_t = C_t.unsqueeze(axis=-1) + C_t = C_t.unsqueeze(axis=-1) - tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) + S_t = S_t.unsqueeze(axis=1) - fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) - fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) + fea_t = paddle.multiply(preds_T, paddle.sqrt(S_t)) + fea_t = paddle.multiply(fea_t, paddle.sqrt(C_t)) fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg)) bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg)) - fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) - fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) + fea_s = paddle.multiply(preds_S, paddle.sqrt(S_t)) + fea_s = paddle.multiply(fea_s, paddle.sqrt(C_t)) fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg)) bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg)) @@ -363,12 +380,18 @@ class FGDFeatureLoss(nn.Layer): return fg_loss, bg_loss - def relation_loss(self, stu_feature, tea_feature): - context_s = self.spatial_pool(stu_feature, "student") - context_t = self.spatial_pool(tea_feature, "teacher") + def get_rela_loss(self, preds_S, preds_T): + context_s = self.spatial_pool(preds_S, 0) + context_t = self.spatial_pool(preds_T, 1) + + out_s = preds_S + out_t = preds_T - out_s = stu_feature + self.stu_conv_block(context_s) - out_t = tea_feature + self.tea_conv_block(context_t) + channel_add_s = self.channel_add_conv_s(context_s) + out_s = out_s + channel_add_s + + channel_add_t = self.channel_add_conv_t(context_t) + out_t = out_t + channel_add_t rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) @@ -378,74 +401,90 @@ class FGDFeatureLoss(nn.Layer): mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) return mask - def forward(self, stu_feature, tea_feature, inputs): + def forward(self, preds_S, preds_T, inputs): """Forward function. Args: - stu_feature(Tensor): Bs*C*H*W, student's feature map - tea_feature(Tensor): Bs*C*H*W, teacher's feature map + preds_S(Tensor): Bs*C*H*W, student's feature map + preds_T(Tensor): Bs*C*H*W, teacher's feature map inputs: The inputs with gt bbox and input shape info. """ - assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \ - f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.' - assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys( - ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs." + assert preds_S.shape[-2:] == preds_T.shape[-2:], \ + f'The shape of Student feature {preds_S.shape} and Teacher feature {preds_T.shape} should be the same.' gt_bboxes = inputs['gt_bbox'] - ins_shape = [ - inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) - ] + assert len(gt_bboxes) == preds_S.shape[0], "error" + + # select index + index_gt = [] + for i in range(len(gt_bboxes)): + if gt_bboxes[i].size > 2: + index_gt.append(i) + index_gt_t = paddle.to_tensor(index_gt) # to tensor + preds_S = paddle.index_select(preds_S, index_gt_t) + preds_T = paddle.index_select(preds_T, index_gt_t) + assert preds_S.shape == preds_T.shape, "error" + + img_metas_tmp = [{ + 'img_shape': inputs['im_shape'][i] + } for i in range(inputs['im_shape'].shape[0])] + img_metas = [img_metas_tmp[c] for c in index_gt] + gt_bboxes = [gt_bboxes[c] for c in index_gt] + assert len(gt_bboxes) == len(img_metas), "error" + + assert len(gt_bboxes) == preds_T.shape[0], "error" if self.align is not None: - stu_feature = self.align(stu_feature) + preds_S = self.align(preds_S) - N, C, H, W = stu_feature.shape + N, C, H, W = preds_S.shape - tea_spatial_att, tea_channel_att = self.spatial_channel_attention( - tea_feature, self.temp) - stu_spatial_att, stu_channel_att = self.spatial_channel_attention( - stu_feature, self.temp) + S_attention_t, C_attention_t = self.gc_block(preds_T, self.temp) + S_attention_s, C_attention_s = self.gc_block(preds_S, self.temp) - Mask_fg = paddle.zeros(tea_spatial_att.shape) - Mask_bg = paddle.ones_like(tea_spatial_att) - one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) - zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) + Mask_fg = paddle.zeros(S_attention_t.shape) + Mask_bg = paddle.ones_like(S_attention_t) + one_tmp = paddle.ones([*S_attention_t.shape[1:]]) + zero_tmp = paddle.zeros([*S_attention_t.shape[1:]]) wmin, wmax, hmin, hmax, area = [], [], [], [], [] - for i in range(N): - tmp_box = paddle.ones_like(gt_bboxes[i]) - tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W - tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W - tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H - tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H - - zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") - ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") + new_boxxes = paddle.ones_like(gt_bboxes[i]) + new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][ + 1] * W + new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][ + 1] * W + new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][ + 0] * H + new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][ + 0] * H + zero = paddle.zeros_like(new_boxxes[:, 0], dtype="int32") + ones = paddle.ones_like(new_boxxes[:, 2], dtype="int32") wmin.append( - paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero)) - wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) + paddle.cast(paddle.floor(new_boxxes[:, 0]), "int32").maximum( + zero)) + wmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 2]), "int32")) hmin.append( - paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero)) - hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) + paddle.cast(paddle.floor(new_boxxes[:, 1]), "int32").maximum( + zero)) + hmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 3]), "int32")) - area_recip = 1.0 / ( + area = 1.0 / ( hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) - for j in range(len(gt_bboxes[i])): Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j], - wmax[i][j] + 1, area_recip[0][j]) - + wmax[i][j] + 1, area[0][j]) Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp) if paddle.sum(Mask_bg[i]): Mask_bg[i] /= paddle.sum(Mask_bg[i]) - fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg, - Mask_bg, tea_channel_att, - tea_spatial_att) - mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, - stu_spatial_att, tea_spatial_att) - rela_loss = self.relation_loss(stu_feature, tea_feature) + fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, + C_attention_s, C_attention_t, + S_attention_s, S_attention_t) + mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, + S_attention_s, S_attention_t) + rela_loss = self.get_rela_loss(preds_S, preds_T) + loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss -- GitLab