From 975e8d4f4361dc794c59fee69d8729d8c46e4d2a Mon Sep 17 00:00:00 2001
From: shangliang Xu
Date: Fri, 4 Nov 2022 16:11:21 +0800
Subject: [PATCH] [dev] fix shared weights in ppyoloe head (#7265)

---
 ppdet/modeling/heads/ppyoloe_head.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py
index 279412066..5986b9c33 100644
--- a/ppdet/modeling/heads/ppyoloe_head.py
+++ b/ppdet/modeling/heads/ppyoloe_head.py
@@ -47,7 +47,8 @@ class ESEAttn(nn.Layer):
 @register
 class PPYOLOEHead(nn.Layer):
     __shared__ = [
-        'num_classes', 'eval_size', 'trt', 'exclude_nms', 'exclude_post_process'
+        'num_classes', 'eval_size', 'trt', 'exclude_nms',
+        'exclude_post_process', 'use_shared_conv'
     ]
     __inject__ = ['static_assigner', 'assigner', 'nms']
 
@@ -72,7 +73,8 @@ class PPYOLOEHead(nn.Layer):
                  },
                  trt=False,
                  exclude_nms=False,
-                 exclude_post_process=False):
+                 exclude_post_process=False,
+                 use_shared_conv=True):
         super(PPYOLOEHead, self).__init__()
         assert len(in_channels) > 0, "len(in_channels) should > 0"
         self.in_channels = in_channels
@@ -94,6 +96,8 @@ class PPYOLOEHead(nn.Layer):
         self.nms.trt = trt
         self.exclude_nms = exclude_nms
         self.exclude_post_process = exclude_post_process
+        self.use_shared_conv = use_shared_conv
+
         # stem
         self.stem_cls = nn.LayerList()
         self.stem_reg = nn.LayerList()
@@ -200,14 +204,22 @@
             reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
             reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
                 [0, 2, 3, 1])
-            reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)).squeeze(1)
+            if self.use_shared_conv:
+                reg_dist = self.proj_conv(F.softmax(
+                    reg_dist, axis=1)).squeeze(1)
+            else:
+                reg_dist = F.softmax(reg_dist, axis=1)
             # cls and reg
             cls_score = F.sigmoid(cls_logit)
             cls_score_list.append(cls_score.reshape([-1, self.num_classes, l]))
             reg_dist_list.append(reg_dist)
 
         cls_score_list = paddle.concat(cls_score_list, axis=-1)
-        reg_dist_list = paddle.concat(reg_dist_list, axis=1)
+        if self.use_shared_conv:
+            reg_dist_list = paddle.concat(reg_dist_list, axis=1)
+        else:
+            reg_dist_list = paddle.concat(reg_dist_list, axis=2)
+            reg_dist_list = self.proj_conv(reg_dist_list).squeeze(1)
 
         return cls_score_list, reg_dist_list, anchor_points, stride_tensor
 
--
GitLab
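
Note (not part of the patch): the two branches introduced here produce the same
regression output, since proj_conv is a 1x1 convolution and therefore commutes
with concatenation along the anchor-point axis; use_shared_conv only controls
whether it runs once per level inside the loop or once on the concatenated
distributions. Below is a minimal standalone sketch checking that equivalence.
It assumes Paddle 2.x; the batch size, the two per-level point counts (64 and
16), reg_max=16, and the random proj_conv weights are all illustrative, not
values taken from the patch.

    import paddle
    import paddle.nn.functional as F

    reg_max = 16
    # Stand-in for the head's shared 1x1 projection over the reg_max+1 bins;
    # random weights suffice to check that the two code paths agree.
    proj_conv = paddle.nn.Conv2D(reg_max + 1, 1, 1, bias_attr=False)

    # Two pyramid levels, each already in the [N, reg_max+1, l, 4] layout
    # produced by the reshape/transpose in forward_eval (shapes illustrative).
    levels = [paddle.rand([2, reg_max + 1, 64, 4]),
              paddle.rand([2, reg_max + 1, 16, 4])]

    # use_shared_conv=True path: conv per level, then concat on the point axis.
    shared = paddle.concat(
        [proj_conv(F.softmax(x, axis=1)).squeeze(1) for x in levels], axis=1)

    # use_shared_conv=False path: softmax per level, concat along axis=2,
    # then apply proj_conv once on the concatenated distributions.
    deferred = proj_conv(
        paddle.concat([F.softmax(x, axis=1) for x in levels], axis=2)).squeeze(1)

    print(paddle.allclose(shared, deferred, atol=1e-6))  # True

Keeping proj_conv out of the per-level loop means the shared weights appear at
a single call site in the traced graph, which is presumably why the deferred
path was added as an option.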