Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
975e8d4f
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
975e8d4f
编写于
11月 04, 2022
作者:
S
shangliang Xu
提交者:
GitHub
11月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] fix shared weights in ppyoloe head (#7265)
上级
5e4d3ccc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
16 addition
and
4 deletion
+16
-4
ppdet/modeling/heads/ppyoloe_head.py
ppdet/modeling/heads/ppyoloe_head.py
+16
-4
未找到文件。
ppdet/modeling/heads/ppyoloe_head.py
浏览文件 @
975e8d4f
...
...
@@ -47,7 +47,8 @@ class ESEAttn(nn.Layer):
@
register
class
PPYOLOEHead
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'eval_size'
,
'trt'
,
'exclude_nms'
,
'exclude_post_process'
'num_classes'
,
'eval_size'
,
'trt'
,
'exclude_nms'
,
'exclude_post_process'
,
'use_shared_conv'
]
__inject__
=
[
'static_assigner'
,
'assigner'
,
'nms'
]
...
...
@@ -72,7 +73,8 @@ class PPYOLOEHead(nn.Layer):
},
trt
=
False
,
exclude_nms
=
False
,
exclude_post_process
=
False
):
exclude_post_process
=
False
,
use_shared_conv
=
True
):
super
(
PPYOLOEHead
,
self
).
__init__
()
assert
len
(
in_channels
)
>
0
,
"len(in_channels) should > 0"
self
.
in_channels
=
in_channels
...
...
@@ -94,6 +96,8 @@ class PPYOLOEHead(nn.Layer):
self
.
nms
.
trt
=
trt
self
.
exclude_nms
=
exclude_nms
self
.
exclude_post_process
=
exclude_post_process
self
.
use_shared_conv
=
use_shared_conv
# stem
self
.
stem_cls
=
nn
.
LayerList
()
self
.
stem_reg
=
nn
.
LayerList
()
...
...
@@ -200,14 +204,22 @@ class PPYOLOEHead(nn.Layer):
reg_dist
=
self
.
pred_reg
[
i
](
self
.
stem_reg
[
i
](
feat
,
avg_feat
))
reg_dist
=
reg_dist
.
reshape
([
-
1
,
4
,
self
.
reg_max
+
1
,
l
]).
transpose
(
[
0
,
2
,
3
,
1
])
reg_dist
=
self
.
proj_conv
(
F
.
softmax
(
reg_dist
,
axis
=
1
)).
squeeze
(
1
)
if
self
.
use_shared_conv
:
reg_dist
=
self
.
proj_conv
(
F
.
softmax
(
reg_dist
,
axis
=
1
)).
squeeze
(
1
)
else
:
reg_dist
=
F
.
softmax
(
reg_dist
,
axis
=
1
)
# cls and reg
cls_score
=
F
.
sigmoid
(
cls_logit
)
cls_score_list
.
append
(
cls_score
.
reshape
([
-
1
,
self
.
num_classes
,
l
]))
reg_dist_list
.
append
(
reg_dist
)
cls_score_list
=
paddle
.
concat
(
cls_score_list
,
axis
=-
1
)
reg_dist_list
=
paddle
.
concat
(
reg_dist_list
,
axis
=
1
)
if
self
.
use_shared_conv
:
reg_dist_list
=
paddle
.
concat
(
reg_dist_list
,
axis
=
1
)
else
:
reg_dist_list
=
paddle
.
concat
(
reg_dist_list
,
axis
=
2
)
reg_dist_list
=
self
.
proj_conv
(
reg_dist_list
).
squeeze
(
1
)
return
cls_score_list
,
reg_dist_list
,
anchor_points
,
stride_tensor
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录