PaddlePaddle / PaddleClas
Commit e0767460 (unverified)
Authored Jul 21, 2021 by Tingquan Gao; committed via GitHub on Jul 21, 2021
Update gvt.py
Parent: 88d0d4ca
Showing 1 changed file with 12 additions and 12 deletions (+12 −12)
ppcls/arch/backbone/model_zoo/gvt.py (+12 −12)
@@ -78,9 +78,9 @@ class GroupAttention(nn.Layer):
         total_groups = h_group * w_group
         x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
             [0, 1, 3, 2, 4, 5])
-        qkv = self.qkv(x).reshape(
-            [B, total_groups, -1, 3, self.num_heads,
-             C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
+        qkv = self.qkv(x).reshape(
+            [B, total_groups, self.ws**2, 3, self.num_heads,
+             C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
         q, k, v = qkv[0], qkv[1], qkv[2]
         attn = (q @ k.transpose([0, 1, 2, 4, 3])) * self.scale
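The change in this hunk is the inferred reshape dimension: -1 becomes the explicit window area self.ws**2 (tokens per window), so the qkv tensor no longer carries a dynamic dimension. A minimal sketch, assuming PaddlePaddle is installed and using illustrative sizes that are not taken from gvt.py, showing that the two forms agree in shape:

import paddle

# Illustrative sizes only (assumptions, not values from gvt.py).
B, h_group, w_group, ws, C = 2, 4, 4, 7, 64
total_groups = h_group * w_group

# A tensor laid out like the windowed feature map in GroupAttention.forward
# (the real code also transposes; only the reshape arithmetic matters here).
x = paddle.randn([B, h_group, ws, w_group, ws, C])

old = x.reshape([B, total_groups, -1, C])       # window area inferred at runtime
new = x.reshape([B, total_groups, ws * ws, C])  # window area spelled out (self.ws**2)

assert old.shape == new.shape == [B, total_groups, ws * ws, C]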
@@ -135,14 +135,15 @@ class Attention(nn.Layer):
         if self.sr_ratio > 1:
             x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
-            x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+            tmp_n = H * W // self.sr_ratio**2
+            x_ = self.sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
             x_ = self.norm(x_)
             kv = self.kv(x_).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, tmp_n, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         else:
             kv = self.kv(x).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, N, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]
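Here the -1 placeholders become sizes that are known ahead of time: after the self.sr spatial reduction the H*W tokens shrink to tmp_n = H * W // self.sr_ratio**2, and the unreduced branch simply has N tokens. A short sketch of that arithmetic, assuming self.sr is a Conv2D with kernel size and stride equal to sr_ratio (the usual PVT-style reduction; treated as an assumption here) and using illustrative sizes:

import paddle
import paddle.nn as nn

# Illustrative values (assumptions, not from the commit).
B, C, H, W, sr_ratio = 2, 64, 28, 28, 4

# Assumed stand-in for self.sr: a stride-sr_ratio convolution that reduces H x W.
sr = nn.Conv2D(C, C, kernel_size=sr_ratio, stride=sr_ratio)

x_ = paddle.randn([B, C, H, W])
tmp_n = H * W // sr_ratio**2          # explicit token count after the reduction
y = sr(x_).reshape([B, C, tmp_n])     # previously reshape([B, C, -1])

assert y.shape == [B, C, (H // sr_ratio) * (W // sr_ratio)]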
@@ -317,7 +318,6 @@ class PyramidVisionTransformer(nn.Layer):
                 self.create_parameter(
                     shape=[1, patch_num, embed_dims[i]],
                     default_initializer=zeros_))
-            self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i])
             self.pos_drops.append(nn.Dropout(p=drop_rate))

         dpr = [
@@ -433,7 +433,7 @@ class CPVTV2(PyramidVisionTransformer):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256, 512],
                  num_heads=[1, 2, 4, 8],
                  mlp_ratios=[4, 4, 4, 4],
@@ -446,7 +446,7 @@ class CPVTV2(PyramidVisionTransformer):
                  depths=[3, 4, 6, 3],
                  sr_ratios=[8, 4, 2, 1],
                  block_cls=Block):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)
@@ -488,7 +488,7 @@ class CPVTV2(PyramidVisionTransformer):
                     x = self.pos_block[i](x, H, W)  # PEG here

             if i < len(self.depths) - 1:
-                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+                x = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])

         x = self.norm(x)
         return x.mean(axis=1)  # GAP here
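Same idea in the forward pass: the channel dimension is read off the tensor as x.shape[-1] rather than inferred with -1, so the reshape carries only concrete sizes. A tiny sketch with assumed sizes, not taken from the commit:

import paddle

# Assumed sizes for illustration only.
B, H, W, C = 2, 14, 14, 64
x = paddle.randn([B, H * W, C])

old = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])           # channels inferred
new = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])  # channels made explicit

assert old.shape == new.shape == [B, C, H, W]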
@@ -499,7 +499,7 @@ class PCPVT(CPVTV2):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256],
                  num_heads=[1, 2, 4],
                  mlp_ratios=[4, 4, 4],
@@ -512,7 +512,7 @@ class PCPVT(CPVTV2):
                  depths=[4, 4, 4],
                  sr_ratios=[4, 2, 1],
                  block_cls=SBlock):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)