Unverified commit 975e8d4f, authored by shangliang Xu, committed by GitHub

[dev] fix shared weights in ppyoloe head (#7265)

Parent 5e4d3ccc
@@ -47,7 +47,8 @@ class ESEAttn(nn.Layer):
 @register
 class PPYOLOEHead(nn.Layer):
     __shared__ = [
-        'num_classes', 'eval_size', 'trt', 'exclude_nms', 'exclude_post_process'
+        'num_classes', 'eval_size', 'trt', 'exclude_nms',
+        'exclude_post_process', 'use_shared_conv'
     ]
     __inject__ = ['static_assigner', 'assigner', 'nms']
@@ -72,7 +73,8 @@ class PPYOLOEHead(nn.Layer):
                  },
                  trt=False,
                  exclude_nms=False,
-                 exclude_post_process=False):
+                 exclude_post_process=False,
+                 use_shared_conv=True):
         super(PPYOLOEHead, self).__init__()
         assert len(in_channels) > 0, "len(in_channels) should > 0"
         self.in_channels = in_channels
@@ -94,6 +96,8 @@ class PPYOLOEHead(nn.Layer):
             self.nms.trt = trt
         self.exclude_nms = exclude_nms
         self.exclude_post_process = exclude_post_process
+        self.use_shared_conv = use_shared_conv
         # stem
         self.stem_cls = nn.LayerList()
         self.stem_reg = nn.LayerList()
@@ -200,14 +204,22 @@ class PPYOLOEHead(nn.Layer):
             reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
             reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
                 [0, 2, 3, 1])
-            reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)).squeeze(1)
+            if self.use_shared_conv:
+                reg_dist = self.proj_conv(F.softmax(
+                    reg_dist, axis=1)).squeeze(1)
+            else:
+                reg_dist = F.softmax(reg_dist, axis=1)
             # cls and reg
             cls_score = F.sigmoid(cls_logit)
             cls_score_list.append(cls_score.reshape([-1, self.num_classes, l]))
             reg_dist_list.append(reg_dist)
         cls_score_list = paddle.concat(cls_score_list, axis=-1)
-        reg_dist_list = paddle.concat(reg_dist_list, axis=1)
+        if self.use_shared_conv:
+            reg_dist_list = paddle.concat(reg_dist_list, axis=1)
+        else:
+            reg_dist_list = paddle.concat(reg_dist_list, axis=2)
+            reg_dist_list = self.proj_conv(reg_dist_list).squeeze(1)
         return cls_score_list, reg_dist_list, anchor_points, stride_tensor
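A minimal, self-contained sketch (not part of this patch) of why the two branches give the same eval output: proj_conv is a fixed 1x1 convolution over the reg_max + 1 distribution bins, so projecting each FPN level inside the loop (the use_shared_conv=True path) is equivalent to concatenating the softmaxed distributions along the length axis first and projecting once (the use_shared_conv=False path). The reg_max value and linspace projection weights mirror the head's defaults; the input tensors below are random stand-ins, not real network outputs.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

reg_max = 16

# Fixed projection conv, set up the way the head initializes proj_conv:
# a 1x1 conv with linspace(0, reg_max) weights and no bias.
proj_conv = nn.Conv2D(reg_max + 1, 1, 1, bias_attr=False)
proj_conv.weight.set_value(
    paddle.linspace(0, reg_max, reg_max + 1).reshape([1, reg_max + 1, 1, 1]))
proj_conv.weight.stop_gradient = True

# Fake per-level regression logits, shaped [batch, reg_max + 1, h * w, 4].
levels = [paddle.rand([2, reg_max + 1, 64, 4]),
          paddle.rand([2, reg_max + 1, 16, 4])]

# use_shared_conv=True branch: project each level inside the loop,
# then concatenate the [batch, h * w, 4] results along the length axis.
per_level = [proj_conv(F.softmax(x, axis=1)).squeeze(1) for x in levels]
out_true = paddle.concat(per_level, axis=1)               # [2, 80, 4]

# use_shared_conv=False branch: only softmax inside the loop, concatenate
# along axis=2, then apply the projection once on the concatenated tensor.
dist = paddle.concat([F.softmax(x, axis=1) for x in levels], axis=2)
out_false = proj_conv(dist).squeeze(1)                     # [2, 80, 4]

print(paddle.allclose(out_true, out_false))                # bool Tensor holding True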