diff --git a/ppdet/slim/distill.py b/ppdet/slim/distill.py index 11df132d63772d5bd01fe5d8f647cdcfdc3d1a98..8f7ee7412271bf98fc8e777f1289f897aca6fd2a 100644 --- a/ppdet/slim/distill.py +++ b/ppdet/slim/distill.py @@ -234,26 +234,23 @@ class FGDFeatureLoss(nn.Layer): Paddle version of `Focal and Global Knowledge Distillation for Detectors` Args: - student_channels(int): Number of channels in the student's feature map. - teacher_channels(int): Number of channels in the teacher's feature map. - temp (float, optional): Temperature coefficient. Defaults to 0.5. - name (str): the loss name of the layer - alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001 - beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005 - gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001 - lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005 + student_channels(int): The number of channels in the student's FPN feature map. Default to 256. + teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256. + temp (float, optional): The temperature coefficient. Defaults to 0.5. + alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 + beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 + gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 + lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005 """ - def __init__( - self, - student_channels, - teacher_channels, - name=None, - temp=0.5, - alpha_fgd=0.001, - beta_fgd=0.0005, - gamma_fgd=0.001, - lambda_fgd=0.000005): + def __init__(self, + student_channels=256, + teacher_channels=256, + temp=0.5, + alpha_fgd=0.001, + beta_fgd=0.0005, + gamma_fgd=0.001, + lambda_fgd=0.000005): super(FGDFeatureLoss, self).__init__() self.temp = temp self.alpha_fgd = alpha_fgd @@ -272,27 +269,29 @@ class FGDFeatureLoss(nn.Layer): stride=1, padding=0, weight_attr=kaiming_init) + student_channels = teacher_channels else: self.align = None self.conv_mask_s = nn.Conv2D( - teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) + student_channels, 1, kernel_size=1, weight_attr=kaiming_init) self.conv_mask_t = nn.Conv2D( teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) - self.channel_add_conv_s = nn.Sequential( + + self.stu_conv_block = nn.Sequential( nn.Conv2D( - teacher_channels, - teacher_channels // 2, + student_channels, + student_channels // 2, kernel_size=1, weight_attr=zeros_init), - nn.LayerNorm([teacher_channels // 2, 1, 1]), + nn.LayerNorm([student_channels // 2, 1, 1]), nn.ReLU(), nn.Conv2D( - teacher_channels // 2, - teacher_channels, + student_channels // 2, + student_channels, kernel_size=1, weight_attr=zeros_init)) - self.channel_add_conv_t = nn.Sequential( + self.tea_conv_block = nn.Sequential( nn.Conv2D( teacher_channels, teacher_channels // 2, @@ -306,72 +305,69 @@ class FGDFeatureLoss(nn.Layer): kernel_size=1, weight_attr=zeros_init)) - def gc_block(self, feature, t=0.5): - """ - """ - shape = paddle.shape(feature) + def spatial_channel_attention(self, x, t=0.5): + shape = paddle.shape(x) N, C, H, W = shape - _f = paddle.abs(feature) - s_map = paddle.reshape( + _f = paddle.abs(x) + spatial_map = paddle.reshape( paddle.mean( _f, axis=1, keepdim=True) / t, [N, -1]) - s_map = F.softmax(s_map, axis=1, dtype="float32") * H * W - s_attention = paddle.reshape(s_map, [N, H, W]) + spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W + spatial_att = paddle.reshape(spatial_map, [N, H, W]) - c_map = paddle.mean( + channel_map = paddle.mean( paddle.mean( _f, axis=2, keepdim=False), axis=2, keepdim=False) - c_attention = F.softmax(c_map / t, axis=1, dtype="float32") * C - return s_attention, c_attention + channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C + return [spatial_att, channel_att] - def spatial_pool(self, x, in_type): + def spatial_pool(self, x, mode="teacher"): batch, channel, width, height = x.shape - input_x = x - # [N, C, H * W] - input_x = paddle.reshape(input_x, [batch, channel, height * width]) - # [N, 1, C, H * W] - input_x = input_x.unsqueeze(1) - # [N, 1, H, W] - if in_type == 0: + x_copy = x + x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) + x_copy = x_copy.unsqueeze(1) + if mode.lower() == "student": context_mask = self.conv_mask_s(x) else: context_mask = self.conv_mask_t(x) - # [N, 1, H * W] + context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) - # [N, 1, H * W] context_mask = F.softmax(context_mask, axis=2) - # [N, 1, H * W, 1] context_mask = context_mask.unsqueeze(-1) - # [N, 1, C, 1] - context = paddle.matmul(input_x, context_mask) - # [N, C, 1, 1] + context = paddle.matmul(x_copy, context_mask) context = paddle.reshape(context, [batch, channel, 1, 1]) return context - def get_mask_loss(self, C_s, C_t, S_s, S_t): - mask_loss = paddle.sum(paddle.abs((C_s - C_t))) / len(C_s) + paddle.sum( - paddle.abs((S_s - S_t))) / len(S_s) + def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, + tea_spatial_att): + def _func(a, b): + return paddle.sum(paddle.abs(a - b)) / len(a) + + mask_loss = _func(stu_channel_att, tea_channel_att) + _func( + stu_spatial_att, tea_spatial_att) + return mask_loss - def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, - S_t): + def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg, + tea_channel_att, tea_spatial_att): + Mask_fg = Mask_fg.unsqueeze(axis=1) Mask_bg = Mask_bg.unsqueeze(axis=1) - C_t = C_t.unsqueeze(axis=-1) - C_t = C_t.unsqueeze(axis=-1) + tea_channel_att = tea_channel_att.unsqueeze(axis=-1) + tea_channel_att = tea_channel_att.unsqueeze(axis=-1) - S_t = S_t.unsqueeze(axis=1) + tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) - fea_t = paddle.multiply(preds_T, paddle.sqrt(S_t)) - fea_t = paddle.multiply(fea_t, paddle.sqrt(C_t)) + fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) + fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg)) bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg)) - fea_s = paddle.multiply(preds_S, paddle.sqrt(S_t)) - fea_s = paddle.multiply(fea_s, paddle.sqrt(C_t)) + fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) + fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg)) bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg)) @@ -380,18 +376,12 @@ class FGDFeatureLoss(nn.Layer): return fg_loss, bg_loss - def get_rela_loss(self, preds_S, preds_T): - context_s = self.spatial_pool(preds_S, 0) - context_t = self.spatial_pool(preds_T, 1) + def relation_loss(self, stu_feature, tea_feature): + context_s = self.spatial_pool(stu_feature, "student") + context_t = self.spatial_pool(tea_feature, "teacher") - out_s = preds_S - out_t = preds_T - - channel_add_s = self.channel_add_conv_s(context_s) - out_s = out_s + channel_add_s - - channel_add_t = self.channel_add_conv_t(context_t) - out_t = out_t + channel_add_t + out_s = stu_feature + self.stu_conv_block(context_s) + out_t = tea_feature + self.tea_conv_block(context_t) rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) @@ -401,90 +391,98 @@ class FGDFeatureLoss(nn.Layer): mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) return mask - def forward(self, preds_S, preds_T, inputs): + def forward(self, stu_feature, tea_feature, inputs): """Forward function. Args: - preds_S(Tensor): Bs*C*H*W, student's feature map - preds_T(Tensor): Bs*C*H*W, teacher's feature map + stu_feature(Tensor): Bs*C*H*W, student's feature map + tea_feature(Tensor): Bs*C*H*W, teacher's feature map inputs: The inputs with gt bbox and input shape info. """ - assert preds_S.shape[-2:] == preds_T.shape[-2:], \ - f'The shape of Student feature {preds_S.shape} and Teacher feature {preds_T.shape} should be the same.' + assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \ + f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.' + assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys( + ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs." gt_bboxes = inputs['gt_bbox'] - assert len(gt_bboxes) == preds_S.shape[0], "error" + ins_shape = [ + inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0]) + ] - # select index index_gt = [] for i in range(len(gt_bboxes)): if gt_bboxes[i].size > 2: index_gt.append(i) - index_gt_t = paddle.to_tensor(index_gt) # to tensor - preds_S = paddle.index_select(preds_S, index_gt_t) - preds_T = paddle.index_select(preds_T, index_gt_t) - assert preds_S.shape == preds_T.shape, "error" - - img_metas_tmp = [{ - 'img_shape': inputs['im_shape'][i] - } for i in range(inputs['im_shape'].shape[0])] - img_metas = [img_metas_tmp[c] for c in index_gt] - gt_bboxes = [gt_bboxes[c] for c in index_gt] - assert len(gt_bboxes) == len(img_metas), "error" + # only distill feature with labeled GTbox + if len(index_gt) != len(gt_bboxes): + index_gt_t = paddle.to_tensor(index_gt) + preds_S = paddle.index_select(preds_S, index_gt_t) + preds_T = paddle.index_select(preds_T, index_gt_t) - assert len(gt_bboxes) == preds_T.shape[0], "error" + ins_shape = [ins_shape[c] for c in index_gt] + gt_bboxes = [gt_bboxes[c] for c in index_gt] + assert len(gt_bboxes) == preds_T.shape[ + 0], f"The number of selected GT box [{len(gt_bboxes)}] should be same with first dim of input tensor [{preds_T.shape[0]}]." if self.align is not None: - preds_S = self.align(preds_S) + stu_feature = self.align(stu_feature) - N, C, H, W = preds_S.shape + N, C, H, W = stu_feature.shape - S_attention_t, C_attention_t = self.gc_block(preds_T, self.temp) - S_attention_s, C_attention_s = self.gc_block(preds_S, self.temp) + tea_spatial_att, tea_channel_att = self.spatial_channel_attention( + tea_feature, self.temp) + stu_spatial_att, stu_channel_att = self.spatial_channel_attention( + stu_feature, self.temp) + + Mask_fg = paddle.zeros(tea_spatial_att.shape) + Mask_bg = paddle.ones_like(tea_spatial_att) + one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) + zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) + mask_fg.stop_gradient = True + Mask_bg.stop_gradient = True + one_tmp.stop_gradient = True + zero_tmp.stop_gradient = True - Mask_fg = paddle.zeros(S_attention_t.shape) - Mask_bg = paddle.ones_like(S_attention_t) - one_tmp = paddle.ones([*S_attention_t.shape[1:]]) - zero_tmp = paddle.zeros([*S_attention_t.shape[1:]]) wmin, wmax, hmin, hmax, area = [], [], [], [], [] + for i in range(N): - new_boxxes = paddle.ones_like(gt_bboxes[i]) - new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][ - 1] * W - new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][ - 1] * W - new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][ - 0] * H - new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][ - 0] * H - zero = paddle.zeros_like(new_boxxes[:, 0], dtype="int32") - ones = paddle.ones_like(new_boxxes[:, 2], dtype="int32") + tmp_box = paddle.ones_like(gt_bboxes[i]) + tmp_box.stop_gradient = True + tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W + tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W + tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H + tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H + + zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") + ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") + zero.stop_gradient = True + ones.stop_gradient = True + wmin.append( - paddle.cast(paddle.floor(new_boxxes[:, 0]), "int32").maximum( - zero)) - wmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 2]), "int32")) + paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero)) + wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) hmin.append( - paddle.cast(paddle.floor(new_boxxes[:, 1]), "int32").maximum( - zero)) - hmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 3]), "int32")) + paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero)) + hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) - area = 1.0 / ( + area_recip = 1.0 / ( hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) + for j in range(len(gt_bboxes[i])): Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j], hmax[i][j] + 1, wmin[i][j], - wmax[i][j] + 1, area[0][j]) + wmax[i][j] + 1, area_recip[0][j]) + Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp) if paddle.sum(Mask_bg[i]): Mask_bg[i] /= paddle.sum(Mask_bg[i]) - fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, - C_attention_s, C_attention_t, - S_attention_s, S_attention_t) - mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, - S_attention_s, S_attention_t) - rela_loss = self.get_rela_loss(preds_S, preds_T) - + fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg, + Mask_bg, tea_channel_att, + tea_spatial_att) + mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, + stu_spatial_att, tea_spatial_att) + rela_loss = self.relation_loss(stu_feature, tea_feature) loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss