Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
ba2aad26
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ba2aad26
编写于
10月 31, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
10月 31, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dynamic shape of reshape op when export model (#7230)
上级
b0620a7b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
12 addition
and
11 deletion
+12
-11
ppdet/modeling/heads/gfl_head.py
ppdet/modeling/heads/gfl_head.py
+1
-1
ppdet/modeling/heads/pico_head.py
ppdet/modeling/heads/pico_head.py
+7
-6
ppdet/modeling/heads/ppyoloe_head.py
ppdet/modeling/heads/ppyoloe_head.py
+4
-4
未找到文件。
ppdet/modeling/heads/gfl_head.py
浏览文件 @
ba2aad26
...
@@ -260,7 +260,7 @@ class GFLHead(nn.Layer):
...
@@ -260,7 +260,7 @@ class GFLHead(nn.Layer):
center_points
=
paddle
.
stack
([
x
,
y
],
axis
=-
1
)
center_points
=
paddle
.
stack
([
x
,
y
],
axis
=-
1
)
cls_score
=
cls_score
.
reshape
([
b
,
-
1
,
self
.
cls_out_channels
])
cls_score
=
cls_score
.
reshape
([
b
,
-
1
,
self
.
cls_out_channels
])
bbox_pred
=
self
.
distribution_project
(
bbox_pred
)
*
stride
bbox_pred
=
self
.
distribution_project
(
bbox_pred
)
*
stride
bbox_pred
=
bbox_pred
.
reshape
([
b
,
cell_h
*
cell_w
,
4
])
bbox_pred
=
bbox_pred
.
reshape
([
-
1
,
cell_h
*
cell_w
,
4
])
# NOTE: If keep_ratio=False and image shape value that
# NOTE: If keep_ratio=False and image shape value that
# multiples of 32, distance2bbox not set max_shapes parameter
# multiples of 32, distance2bbox not set max_shapes parameter
...
...
ppdet/modeling/heads/pico_head.py
浏览文件 @
ba2aad26
...
@@ -353,13 +353,13 @@ class PicoHead(OTAVFLHead):
...
@@ -353,13 +353,13 @@ class PicoHead(OTAVFLHead):
bbox_pred
=
bbox_pred
.
reshape
([
1
,
(
self
.
reg_max
+
1
)
*
4
,
bbox_pred
=
bbox_pred
.
reshape
([
1
,
(
self
.
reg_max
+
1
)
*
4
,
-
1
]).
transpose
([
0
,
2
,
1
])
-
1
]).
transpose
([
0
,
2
,
1
])
else
:
else
:
b
,
_
,
h
,
w
=
fpn_feat
.
shape
_
,
_
,
h
,
w
=
fpn_feat
.
shape
l
=
h
*
w
l
=
h
*
w
cls_score_out
=
F
.
sigmoid
(
cls_score_out
=
F
.
sigmoid
(
cls_score
.
reshape
([
b
,
self
.
cls_out_channels
,
l
]))
cls_score
.
reshape
([
-
1
,
self
.
cls_out_channels
,
l
]))
bbox_pred
=
bbox_pred
.
transpose
([
0
,
2
,
3
,
1
])
bbox_pred
=
bbox_pred
.
transpose
([
0
,
2
,
3
,
1
])
bbox_pred
=
self
.
distribution_project
(
bbox_pred
)
bbox_pred
=
self
.
distribution_project
(
bbox_pred
)
bbox_pred
=
bbox_pred
.
reshape
([
b
,
l
,
4
])
bbox_pred
=
bbox_pred
.
reshape
([
-
1
,
l
,
4
])
cls_logits_list
.
append
(
cls_score_out
)
cls_logits_list
.
append
(
cls_score_out
)
bboxes_reg_list
.
append
(
bbox_pred
)
bboxes_reg_list
.
append
(
bbox_pred
)
...
@@ -597,7 +597,7 @@ class PicoHeadV2(GFLHead):
...
@@ -597,7 +597,7 @@ class PicoHeadV2(GFLHead):
anchor_points
,
stride_tensor
=
self
.
_generate_anchors
(
fpn_feats
)
anchor_points
,
stride_tensor
=
self
.
_generate_anchors
(
fpn_feats
)
cls_score_list
,
box_list
=
[],
[]
cls_score_list
,
box_list
=
[],
[]
for
i
,
(
fpn_feat
,
stride
)
in
enumerate
(
zip
(
fpn_feats
,
self
.
fpn_stride
)):
for
i
,
(
fpn_feat
,
stride
)
in
enumerate
(
zip
(
fpn_feats
,
self
.
fpn_stride
)):
b
,
_
,
h
,
w
=
fpn_feat
.
shape
_
,
_
,
h
,
w
=
fpn_feat
.
shape
# task decomposition
# task decomposition
conv_cls_feat
,
se_feat
=
self
.
conv_feat
(
fpn_feat
,
i
)
conv_cls_feat
,
se_feat
=
self
.
conv_feat
(
fpn_feat
,
i
)
cls_logit
=
self
.
head_cls_list
[
i
](
se_feat
)
cls_logit
=
self
.
head_cls_list
[
i
](
se_feat
)
...
@@ -620,10 +620,11 @@ class PicoHeadV2(GFLHead):
...
@@ -620,10 +620,11 @@ class PicoHeadV2(GFLHead):
[
0
,
2
,
1
]))
[
0
,
2
,
1
]))
else
:
else
:
l
=
h
*
w
l
=
h
*
w
cls_score_out
=
cls_score
.
reshape
([
b
,
self
.
cls_out_channels
,
l
])
cls_score_out
=
cls_score
.
reshape
(
[
-
1
,
self
.
cls_out_channels
,
l
])
bbox_pred
=
reg_pred
.
transpose
([
0
,
2
,
3
,
1
])
bbox_pred
=
reg_pred
.
transpose
([
0
,
2
,
3
,
1
])
bbox_pred
=
self
.
distribution_project
(
bbox_pred
)
bbox_pred
=
self
.
distribution_project
(
bbox_pred
)
bbox_pred
=
bbox_pred
.
reshape
([
b
,
l
,
4
])
bbox_pred
=
bbox_pred
.
reshape
([
-
1
,
l
,
4
])
cls_score_list
.
append
(
cls_score_out
)
cls_score_list
.
append
(
cls_score_out
)
box_list
.
append
(
bbox_pred
)
box_list
.
append
(
bbox_pred
)
...
...
ppdet/modeling/heads/ppyoloe_head.py
浏览文件 @
ba2aad26
...
@@ -192,7 +192,7 @@ class PPYOLOEHead(nn.Layer):
...
@@ -192,7 +192,7 @@ class PPYOLOEHead(nn.Layer):
anchor_points
,
stride_tensor
=
self
.
_generate_anchors
(
feats
)
anchor_points
,
stride_tensor
=
self
.
_generate_anchors
(
feats
)
cls_score_list
,
reg_dist_list
=
[],
[]
cls_score_list
,
reg_dist_list
=
[],
[]
for
i
,
feat
in
enumerate
(
feats
):
for
i
,
feat
in
enumerate
(
feats
):
b
,
_
,
h
,
w
=
feat
.
shape
_
,
_
,
h
,
w
=
feat
.
shape
l
=
h
*
w
l
=
h
*
w
avg_feat
=
F
.
adaptive_avg_pool2d
(
feat
,
(
1
,
1
))
avg_feat
=
F
.
adaptive_avg_pool2d
(
feat
,
(
1
,
1
))
cls_logit
=
self
.
pred_cls
[
i
](
self
.
stem_cls
[
i
](
feat
,
avg_feat
)
+
cls_logit
=
self
.
pred_cls
[
i
](
self
.
stem_cls
[
i
](
feat
,
avg_feat
)
+
...
@@ -203,7 +203,7 @@ class PPYOLOEHead(nn.Layer):
...
@@ -203,7 +203,7 @@ class PPYOLOEHead(nn.Layer):
reg_dist
=
self
.
proj_conv
(
F
.
softmax
(
reg_dist
,
axis
=
1
)).
squeeze
(
1
)
reg_dist
=
self
.
proj_conv
(
F
.
softmax
(
reg_dist
,
axis
=
1
)).
squeeze
(
1
)
# cls and reg
# cls and reg
cls_score
=
F
.
sigmoid
(
cls_logit
)
cls_score
=
F
.
sigmoid
(
cls_logit
)
cls_score_list
.
append
(
cls_score
.
reshape
([
b
,
self
.
num_classes
,
l
]))
cls_score_list
.
append
(
cls_score
.
reshape
([
-
1
,
self
.
num_classes
,
l
]))
reg_dist_list
.
append
(
reg_dist
)
reg_dist_list
.
append
(
reg_dist
)
cls_score_list
=
paddle
.
concat
(
cls_score_list
,
axis
=-
1
)
cls_score_list
=
paddle
.
concat
(
cls_score_list
,
axis
=-
1
)
...
@@ -238,8 +238,8 @@ class PPYOLOEHead(nn.Layer):
...
@@ -238,8 +238,8 @@ class PPYOLOEHead(nn.Layer):
return
loss
return
loss
def
_bbox_decode
(
self
,
anchor_points
,
pred_dist
):
def
_bbox_decode
(
self
,
anchor_points
,
pred_dist
):
b
,
l
,
_
=
get_static_shape
(
pred_dist
)
_
,
l
,
_
=
get_static_shape
(
pred_dist
)
pred_dist
=
F
.
softmax
(
pred_dist
.
reshape
([
b
,
l
,
4
,
self
.
reg_max
+
1
]))
pred_dist
=
F
.
softmax
(
pred_dist
.
reshape
([
-
1
,
l
,
4
,
self
.
reg_max
+
1
]))
pred_dist
=
self
.
proj_conv
(
pred_dist
.
transpose
([
0
,
3
,
1
,
2
])).
squeeze
(
1
)
pred_dist
=
self
.
proj_conv
(
pred_dist
.
transpose
([
0
,
3
,
1
,
2
])).
squeeze
(
1
)
return
batch_distance2bbox
(
anchor_points
,
pred_dist
)
return
batch_distance2bbox
(
anchor_points
,
pred_dist
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录