未验证 提交 99f16815 编写于 作者: D Double_V 提交者: GitHub

fix loss nan and support picodet with fgd (#6420)

上级 8d22b60a
...@@ -2,11 +2,12 @@ _BASE_: [ ...@@ -2,11 +2,12 @@ _BASE_: [
'../../retinanet/retinanet_r101_fpn_2x_coco.yml', '../../retinanet/retinanet_r101_fpn_2x_coco.yml',
] ]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams
slim: Distill slim: Distill
slim_method: FGD slim_method: FGD
distill_loss: FGDFeatureLoss distill_loss: FGDFeatureLoss
distill_loss_name: ['neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1', 'neck_f_0']
FGDFeatureLoss: FGDFeatureLoss:
student_channels: 256 student_channels: 256
......
...@@ -74,14 +74,16 @@ class FGDDistillModel(nn.Layer): ...@@ -74,14 +74,16 @@ class FGDDistillModel(nn.Layer):
def __init__(self, cfg, slim_cfg): def __init__(self, cfg, slim_cfg):
super(FGDDistillModel, self).__init__() super(FGDDistillModel, self).__init__()
self.student_cfg = cfg
self.is_inherit = True
# build student model before load slim config
self.student_model = create(cfg.architecture)
self.arch = cfg.architecture
stu_pretrain = cfg['pretrain_weights']
slim_cfg = load_config(slim_cfg) slim_cfg = load_config(slim_cfg)
self.teacher_cfg = slim_cfg self.teacher_cfg = slim_cfg
self.loss_cfg = slim_cfg self.loss_cfg = slim_cfg
self.is_loaded_weights = True tea_pretrain = cfg['pretrain_weights']
self.is_inherit = True
self.student_model = create(self.student_cfg.architecture)
self.teacher_model = create(self.teacher_cfg.architecture) self.teacher_model = create(self.teacher_cfg.architecture)
self.teacher_model.eval() self.teacher_model.eval()
...@@ -89,29 +91,22 @@ class FGDDistillModel(nn.Layer): ...@@ -89,29 +91,22 @@ class FGDDistillModel(nn.Layer):
for param in self.teacher_model.parameters(): for param in self.teacher_model.parameters():
param.trainable = False param.trainable = False
if 'pretrain_weights' in self.student_cfg and self.student_cfg.pretrain_weights: if 'pretrain_weights' in cfg and stu_pretrain:
if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights: if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
self._load_pretrain_weights(self.student_model, load_pretrain_weight(self.student_model,
self.teacher_cfg.pretrain_weights) self.teacher_cfg.pretrain_weights)
print("loading teacher weights to student model!") logger.debug(
"Inheriting! loading teacher weights to student model!")
self._load_pretrain_weights(self.student_model, load_pretrain_weight(self.student_model, stu_pretrain)
self.student_cfg.pretrain_weights)
print("loading student model Done")
if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights: if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
self._load_pretrain_weights(self.teacher_model, load_pretrain_weight(self.teacher_model,
self.teacher_cfg.pretrain_weights) self.teacher_cfg.pretrain_weights)
print("loading teacher model Done")
self.fgd_loss_dic = self.build_loss(self.loss_cfg.distill_loss) self.fgd_loss_dic = self.build_loss(
self.loss_cfg.distill_loss,
def _load_pretrain_weights(self, model, weights): name_list=self.loss_cfg['distill_loss_name'])
if self.is_loaded_weights:
return
self.start_epoch = 0
load_pretrain_weight(model, weights)
logger.debug("Load weights {} to start training".format(weights))
def build_loss(self, def build_loss(self,
cfg, cfg,
...@@ -137,20 +132,28 @@ class FGDDistillModel(nn.Layer): ...@@ -137,20 +132,28 @@ class FGDDistillModel(nn.Layer):
for idx, k in enumerate(self.fgd_loss_dic): for idx, k in enumerate(self.fgd_loss_dic):
loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx], loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
t_neck_feats[idx], inputs) t_neck_feats[idx], inputs)
if self.arch == "RetinaNet":
loss = self.student_model.head(s_neck_feats, inputs) loss = self.student_model.head(s_neck_feats, inputs)
elif self.arch == "PicoDet":
loss = self.student_model.get_loss()
else:
raise ValueError(f"Unsupported model {self.arch}")
for k in loss_dict: for k in loss_dict:
loss['loss'] += loss_dict[k] loss['loss'] += loss_dict[k]
loss[k] = loss_dict[k] loss[k] = loss_dict[k]
return loss return loss
else: else:
body_feats = self.student_model.backbone(inputs) body_feats = self.student_model.backbone(inputs)
neck_feats = self.student_model.neck(body_feats) neck_feats = self.student_model.neck(body_feats)
head_outs = self.student_model.head(neck_feats) head_outs = self.student_model.head(neck_feats)
if self.arch == "RetinaNet":
bbox, bbox_num = self.student_model.head.post_process( bbox, bbox_num = self.student_model.head.post_process(
head_outs, inputs['im_shape'], inputs['scale_factor']) head_outs, inputs['im_shape'], inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num} return {'bbox': bbox, 'bbox_num': bbox_num}
elif self.arch == "PicoDet":
return self.student_model.head.get_pred()
else:
raise ValueError(f"Unsupported model {self.arch}")
@register @register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册