Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
a27df36d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a27df36d
编写于
11月 07, 2022
作者:
W
Wenyu
提交者:
GitHub
11月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix arange to static for inference (#7279)
上级
e639b354
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
25 addition
and
7 deletion
+25
-7
ppdet/modeling/necks/custom_pan.py
ppdet/modeling/necks/custom_pan.py
+25
-7
未找到文件。
ppdet/modeling/necks/custom_pan.py
浏览文件 @
a27df36d
...
...
@@ -184,7 +184,10 @@ class TransformerEncoder(nn.Layer):
@
register
@
serializable
class
CustomCSPPAN
(
nn
.
Layer
):
__shared__
=
[
'norm_type'
,
'data_format'
,
'width_mult'
,
'depth_mult'
,
'trt'
]
__shared__
=
[
'norm_type'
,
'data_format'
,
'width_mult'
,
'depth_mult'
,
'trt'
,
'eval_size'
]
def
__init__
(
self
,
in_channels
=
[
256
,
512
,
1024
],
...
...
@@ -212,7 +215,8 @@ class CustomCSPPAN(nn.Layer):
attn_dropout
=
None
,
act_dropout
=
None
,
normalize_before
=
False
,
use_trans
=
False
):
use_trans
=
False
,
eval_size
=
None
):
super
(
CustomCSPPAN
,
self
).
__init__
()
out_channels
=
[
max
(
round
(
c
*
width_mult
),
1
)
for
c
in
out_channels
]
...
...
@@ -223,19 +227,29 @@ class CustomCSPPAN(nn.Layer):
self
.
num_blocks
=
len
(
in_channels
)
self
.
data_format
=
data_format
self
.
_out_channels
=
out_channels
self
.
hidden_dim
=
in_channels
[
-
1
]
in_channels
=
in_channels
[::
-
1
]
self
.
nhead
=
nhead
self
.
num_layers
=
num_layers
self
.
use_trans
=
use_trans
self
.
eval_size
=
eval_size
if
use_trans
:
if
eval_size
is
not
None
:
self
.
pos_embed
=
self
.
build_2d_sincos_position_embedding
(
eval_size
[
1
]
//
32
,
eval_size
[
0
]
//
32
,
embed_dim
=
self
.
hidden_dim
)
else
:
self
.
pos_embed
=
None
encoder_layer
=
TransformerEncoderLayer
(
self
.
hidden_dim
,
nhead
,
dim_feedforward
,
dropout
,
activation
,
attn_dropout
,
act_dropout
,
normalize_before
)
encoder_norm
=
nn
.
LayerNorm
(
self
.
hidden_dim
)
if
normalize_before
else
None
self
.
encoder
=
TransformerEncoder
(
encoder_layer
,
self
.
num_layers
,
self
.
encoder
=
TransformerEncoder
(
encoder_layer
,
num_layers
,
encoder_norm
)
fpn_stages
=
[]
fpn_routes
=
[]
for
i
,
(
ch_in
,
ch_out
)
in
enumerate
(
zip
(
in_channels
,
out_channels
)):
...
...
@@ -340,8 +354,12 @@ class CustomCSPPAN(nn.Layer):
# flatten [B, C, H, W] to [B, HxW, C]
src_flatten
=
last_feat
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
pos_embed
=
self
.
build_2d_sincos_position_embedding
(
w
=
w
,
h
=
h
,
embed_dim
=
self
.
hidden_dim
)
if
self
.
eval_size
is
not
None
:
pos_embed
=
self
.
pos_embed
else
:
pos_embed
=
self
.
build_2d_sincos_position_embedding
(
w
=
w
,
h
=
h
,
embed_dim
=
self
.
hidden_dim
)
memory
=
self
.
encoder
(
src_flatten
,
pos_embed
=
pos_embed
)
last_feat_encode
=
memory
.
transpose
([
0
,
2
,
1
]).
reshape
([
n
,
c
,
h
,
w
])
blocks
[
-
1
]
=
last_feat_encode
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录