未验证 提交 1d867e82 编写于 作者: U user1018 提交者: GitHub

Update distill (#6693)

上级 20037620
...@@ -135,7 +135,13 @@ class FGDDistillModel(nn.Layer): ...@@ -135,7 +135,13 @@ class FGDDistillModel(nn.Layer):
if self.arch == "RetinaNet": if self.arch == "RetinaNet":
loss = self.student_model.head(s_neck_feats, inputs) loss = self.student_model.head(s_neck_feats, inputs)
elif self.arch == "PicoDet": elif self.arch == "PicoDet":
loss = self.student_model.get_loss() head_outs = self.student_model.head(
s_neck_feats, self.student_model.export_post_process)
loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
total_loss = paddle.add_n(list(loss_gfl.values()))
loss = {}
loss.update(loss_gfl)
loss.update({'loss': total_loss})
else: else:
raise ValueError(f"Unsupported model {self.arch}") raise ValueError(f"Unsupported model {self.arch}")
for k in loss_dict: for k in loss_dict:
...@@ -151,7 +157,14 @@ class FGDDistillModel(nn.Layer): ...@@ -151,7 +157,14 @@ class FGDDistillModel(nn.Layer):
head_outs, inputs['im_shape'], inputs['scale_factor']) head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num} return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "PicoDet": elif self.arch == "PicoDet":
return self.student_model.head.get_pred() head_outs = self.student_model.head(
neck_feats, self.student_model.export_post_process)
scale_factor = inputs['scale_factor']
bboxes, bbox_num = self.student_model.head.post_process(
head_outs,
scale_factor,
export_nms=self.student_model.export_nms)
return {'bbox': bboxes, 'bbox_num': bbox_num}
else: else:
raise ValueError(f"Unsupported model {self.arch}") raise ValueError(f"Unsupported model {self.arch}")
...@@ -221,18 +234,21 @@ class FGDFeatureLoss(nn.Layer): ...@@ -221,18 +234,21 @@ class FGDFeatureLoss(nn.Layer):
Paddle version of `Focal and Global Knowledge Distillation for Detectors` Paddle version of `Focal and Global Knowledge Distillation for Detectors`
Args: Args:
student_channels(int): The number of channels in the student's FPN feature map. Default to 256. student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): The number of channels in the teacher's FPN feature map. Default to 256. teacher_channels(int): Number of channels in the teacher's feature map.
temp (float, optional): The temperature coefficient. Defaults to 0.5. temp (float, optional): Temperature coefficient. Defaults to 0.5.
alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001 name (str): the loss name of the layer
beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005 alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001 beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005 gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
""" """
def __init__(self, def __init__(
student_channels=256, self,
teacher_channels=256, student_channels,
teacher_channels,
name=None,
temp=0.5, temp=0.5,
alpha_fgd=0.001, alpha_fgd=0.001,
beta_fgd=0.0005, beta_fgd=0.0005,
...@@ -256,29 +272,27 @@ class FGDFeatureLoss(nn.Layer): ...@@ -256,29 +272,27 @@ class FGDFeatureLoss(nn.Layer):
stride=1, stride=1,
padding=0, padding=0,
weight_attr=kaiming_init) weight_attr=kaiming_init)
student_channels = teacher_channels
else: else:
self.align = None self.align = None
self.conv_mask_s = nn.Conv2D( self.conv_mask_s = nn.Conv2D(
student_channels, 1, kernel_size=1, weight_attr=kaiming_init) teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)
self.conv_mask_t = nn.Conv2D( self.conv_mask_t = nn.Conv2D(
teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init) teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)
self.channel_add_conv_s = nn.Sequential(
self.stu_conv_block = nn.Sequential(
nn.Conv2D( nn.Conv2D(
student_channels, teacher_channels,
student_channels // 2, teacher_channels // 2,
kernel_size=1, kernel_size=1,
weight_attr=zeros_init), weight_attr=zeros_init),
nn.LayerNorm([student_channels // 2, 1, 1]), nn.LayerNorm([teacher_channels // 2, 1, 1]),
nn.ReLU(), nn.ReLU(),
nn.Conv2D( nn.Conv2D(
student_channels // 2, teacher_channels // 2,
student_channels, teacher_channels,
kernel_size=1, kernel_size=1,
weight_attr=zeros_init)) weight_attr=zeros_init))
self.tea_conv_block = nn.Sequential( self.channel_add_conv_t = nn.Sequential(
nn.Conv2D( nn.Conv2D(
teacher_channels, teacher_channels,
teacher_channels // 2, teacher_channels // 2,
...@@ -292,69 +306,72 @@ class FGDFeatureLoss(nn.Layer): ...@@ -292,69 +306,72 @@ class FGDFeatureLoss(nn.Layer):
kernel_size=1, kernel_size=1,
weight_attr=zeros_init)) weight_attr=zeros_init))
def spatial_channel_attention(self, x, t=0.5): def gc_block(self, feature, t=0.5):
shape = paddle.shape(x) """
"""
shape = paddle.shape(feature)
N, C, H, W = shape N, C, H, W = shape
_f = paddle.abs(x) _f = paddle.abs(feature)
spatial_map = paddle.reshape( s_map = paddle.reshape(
paddle.mean( paddle.mean(
_f, axis=1, keepdim=True) / t, [N, -1]) _f, axis=1, keepdim=True) / t, [N, -1])
spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W s_map = F.softmax(s_map, axis=1, dtype="float32") * H * W
spatial_att = paddle.reshape(spatial_map, [N, H, W]) s_attention = paddle.reshape(s_map, [N, H, W])
channel_map = paddle.mean( c_map = paddle.mean(
paddle.mean( paddle.mean(
_f, axis=2, keepdim=False), axis=2, keepdim=False) _f, axis=2, keepdim=False), axis=2, keepdim=False)
channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C c_attention = F.softmax(c_map / t, axis=1, dtype="float32") * C
return [spatial_att, channel_att] return s_attention, c_attention
def spatial_pool(self, x, mode="teacher"): def spatial_pool(self, x, in_type):
batch, channel, width, height = x.shape batch, channel, width, height = x.shape
x_copy = x input_x = x
x_copy = paddle.reshape(x_copy, [batch, channel, height * width]) # [N, C, H * W]
x_copy = x_copy.unsqueeze(1) input_x = paddle.reshape(input_x, [batch, channel, height * width])
if mode.lower() == "student": # [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
if in_type == 0:
context_mask = self.conv_mask_s(x) context_mask = self.conv_mask_s(x)
else: else:
context_mask = self.conv_mask_t(x) context_mask = self.conv_mask_t(x)
# [N, 1, H * W]
context_mask = paddle.reshape(context_mask, [batch, 1, height * width]) context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
# [N, 1, H * W]
context_mask = F.softmax(context_mask, axis=2) context_mask = F.softmax(context_mask, axis=2)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1) context_mask = context_mask.unsqueeze(-1)
context = paddle.matmul(x_copy, context_mask) # [N, 1, C, 1]
context = paddle.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = paddle.reshape(context, [batch, channel, 1, 1]) context = paddle.reshape(context, [batch, channel, 1, 1])
return context return context
def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att, def get_mask_loss(self, C_s, C_t, S_s, S_t):
tea_spatial_att): mask_loss = paddle.sum(paddle.abs((C_s - C_t))) / len(C_s) + paddle.sum(
def _func(a, b): paddle.abs((S_s - S_t))) / len(S_s)
return paddle.sum(paddle.abs(a - b)) / len(a)
mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
stu_spatial_att, tea_spatial_att)
return mask_loss return mask_loss
def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg, def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s,
tea_channel_att, tea_spatial_att): S_t):
Mask_fg = Mask_fg.unsqueeze(axis=1) Mask_fg = Mask_fg.unsqueeze(axis=1)
Mask_bg = Mask_bg.unsqueeze(axis=1) Mask_bg = Mask_bg.unsqueeze(axis=1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1) C_t = C_t.unsqueeze(axis=-1)
tea_channel_att = tea_channel_att.unsqueeze(axis=-1) C_t = C_t.unsqueeze(axis=-1)
tea_spatial_att = tea_spatial_att.unsqueeze(axis=1) S_t = S_t.unsqueeze(axis=1)
fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att)) fea_t = paddle.multiply(preds_T, paddle.sqrt(S_t))
fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att)) fea_t = paddle.multiply(fea_t, paddle.sqrt(C_t))
fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg)) fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg)) bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))
fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att)) fea_s = paddle.multiply(preds_S, paddle.sqrt(S_t))
fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att)) fea_s = paddle.multiply(fea_s, paddle.sqrt(C_t))
fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg)) fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg)) bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))
...@@ -363,12 +380,18 @@ class FGDFeatureLoss(nn.Layer): ...@@ -363,12 +380,18 @@ class FGDFeatureLoss(nn.Layer):
return fg_loss, bg_loss return fg_loss, bg_loss
def relation_loss(self, stu_feature, tea_feature): def get_rela_loss(self, preds_S, preds_T):
context_s = self.spatial_pool(stu_feature, "student") context_s = self.spatial_pool(preds_S, 0)
context_t = self.spatial_pool(tea_feature, "teacher") context_t = self.spatial_pool(preds_T, 1)
out_s = preds_S
out_t = preds_T
out_s = stu_feature + self.stu_conv_block(context_s) channel_add_s = self.channel_add_conv_s(context_s)
out_t = tea_feature + self.tea_conv_block(context_t) out_s = out_s + channel_add_s
channel_add_t = self.channel_add_conv_t(context_t)
out_t = out_t + channel_add_t
rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s) rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
...@@ -378,74 +401,90 @@ class FGDFeatureLoss(nn.Layer): ...@@ -378,74 +401,90 @@ class FGDFeatureLoss(nn.Layer):
mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value) mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
return mask return mask
def forward(self, stu_feature, tea_feature, inputs): def forward(self, preds_S, preds_T, inputs):
"""Forward function. """Forward function.
Args: Args:
stu_feature(Tensor): Bs*C*H*W, student's feature map preds_S(Tensor): Bs*C*H*W, student's feature map
tea_feature(Tensor): Bs*C*H*W, teacher's feature map preds_T(Tensor): Bs*C*H*W, teacher's feature map
inputs: The inputs with gt bbox and input shape info. inputs: The inputs with gt bbox and input shape info.
""" """
assert stu_feature.shape[-2:] == stu_feature.shape[-2:], \ assert preds_S.shape[-2:] == preds_T.shape[-2:], \
f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.' f'The shape of Student feature {preds_S.shape} and Teacher feature {preds_T.shape} should be the same.'
assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs."
gt_bboxes = inputs['gt_bbox'] gt_bboxes = inputs['gt_bbox']
ins_shape = [ assert len(gt_bboxes) == preds_S.shape[0], "error"
inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
] # select index
index_gt = []
for i in range(len(gt_bboxes)):
if gt_bboxes[i].size > 2:
index_gt.append(i)
index_gt_t = paddle.to_tensor(index_gt) # to tensor
preds_S = paddle.index_select(preds_S, index_gt_t)
preds_T = paddle.index_select(preds_T, index_gt_t)
assert preds_S.shape == preds_T.shape, "error"
img_metas_tmp = [{
'img_shape': inputs['im_shape'][i]
} for i in range(inputs['im_shape'].shape[0])]
img_metas = [img_metas_tmp[c] for c in index_gt]
gt_bboxes = [gt_bboxes[c] for c in index_gt]
assert len(gt_bboxes) == len(img_metas), "error"
assert len(gt_bboxes) == preds_T.shape[0], "error"
if self.align is not None: if self.align is not None:
stu_feature = self.align(stu_feature) preds_S = self.align(preds_S)
N, C, H, W = stu_feature.shape N, C, H, W = preds_S.shape
tea_spatial_att, tea_channel_att = self.spatial_channel_attention( S_attention_t, C_attention_t = self.gc_block(preds_T, self.temp)
tea_feature, self.temp) S_attention_s, C_attention_s = self.gc_block(preds_S, self.temp)
stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
stu_feature, self.temp)
Mask_fg = paddle.zeros(tea_spatial_att.shape) Mask_fg = paddle.zeros(S_attention_t.shape)
Mask_bg = paddle.ones_like(tea_spatial_att) Mask_bg = paddle.ones_like(S_attention_t)
one_tmp = paddle.ones([*tea_spatial_att.shape[1:]]) one_tmp = paddle.ones([*S_attention_t.shape[1:]])
zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]]) zero_tmp = paddle.zeros([*S_attention_t.shape[1:]])
wmin, wmax, hmin, hmax, area = [], [], [], [], [] wmin, wmax, hmin, hmax, area = [], [], [], [], []
for i in range(N): for i in range(N):
tmp_box = paddle.ones_like(gt_bboxes[i]) new_boxxes = paddle.ones_like(gt_bboxes[i])
tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][
tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W 1] * W
tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][
tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H 1] * W
new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][
zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32") 0] * H
ones = paddle.ones_like(tmp_box[:, 2], dtype="int32") new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][
0] * H
zero = paddle.zeros_like(new_boxxes[:, 0], dtype="int32")
ones = paddle.ones_like(new_boxxes[:, 2], dtype="int32")
wmin.append( wmin.append(
paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero)) paddle.cast(paddle.floor(new_boxxes[:, 0]), "int32").maximum(
wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32")) zero))
wmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 2]), "int32"))
hmin.append( hmin.append(
paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero)) paddle.cast(paddle.floor(new_boxxes[:, 1]), "int32").maximum(
hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32")) zero))
hmax.append(paddle.cast(paddle.ceil(new_boxxes[:, 3]), "int32"))
area_recip = 1.0 / ( area = 1.0 / (
hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / ( hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1])) wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))
for j in range(len(gt_bboxes[i])): for j in range(len(gt_bboxes[i])):
Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j], Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
hmax[i][j] + 1, wmin[i][j], hmax[i][j] + 1, wmin[i][j],
wmax[i][j] + 1, area_recip[0][j]) wmax[i][j] + 1, area[0][j])
Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp) Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
if paddle.sum(Mask_bg[i]): if paddle.sum(Mask_bg[i]):
Mask_bg[i] /= paddle.sum(Mask_bg[i]) Mask_bg[i] /= paddle.sum(Mask_bg[i])
fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg, fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
Mask_bg, tea_channel_att, C_attention_s, C_attention_t,
tea_spatial_att) S_attention_s, S_attention_t)
mask_loss = self.mask_loss(stu_channel_att, tea_channel_att, mask_loss = self.get_mask_loss(C_attention_s, C_attention_t,
stu_spatial_att, tea_spatial_att) S_attention_s, S_attention_t)
rela_loss = self.relation_loss(stu_feature, tea_feature) rela_loss = self.get_rela_loss(preds_S, preds_T)
loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册