Commit 3672d1f2 — PaddlePaddle/PaddleClas

add swinV1 22k weights

Authored on Jan 17, 2023 by weixin_46524038
Committed on Jan 20, 2023 by cuicheng01
Parent: 5544dbaf
Showing 1 changed file with 83 additions and 37 deletions:

ppcls/arch/backbone/legendary_models/swin_transformer.py (+83, -37)

@@ -35,9 +35,9 @@ MODEL_URLS = {
     "SwinTransformer_base_patch4_window12_384":
     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_base_patch4_window12_384_pretrained.pdparams",
     "SwinTransformer_large_patch4_window7_224":
-    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window7_224_22kto1k_pretrained.pdparams",
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window7_224_pretrained.pdparams",
     "SwinTransformer_large_patch4_window12_384":
-    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window12_384_22kto1k_pretrained.pdparams",
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SwinTransformer_large_patch4_window12_384_pretrained.pdparams",
 }
 
 __all__ = list(MODEL_URLS.keys())
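
Note: the hardcoded 22kto1k_ segment is dropped from the two large-model URLs because the checkpoint variant is now selected at load time through the new use_imagenet22k_pretrained / use_imagenet22kto1k_pretrained flags (see the _load_pretrained hunk further down). The following is only a hypothetical sketch of that flag-driven URL rewriting; the real behavior lives in load_dygraph_pretrain_from_url, which is not part of this diff, so the helper name and suffix handling are assumptions.

# Hypothetical sketch of flag-driven URL selection; the actual behavior is
# implemented by load_dygraph_pretrain_from_url, which is not shown in this diff.
def resolve_weight_url(base_url,
                       use_ssld=False,
                       use_imagenet22k_pretrained=False,
                       use_imagenet22kto1k_pretrained=False):
    # base_url ends with "..._pretrained.pdparams"; swap in the variant suffix.
    if use_ssld:
        return base_url.replace("_pretrained.pdparams", "_ssld_pretrained.pdparams")
    if use_imagenet22k_pretrained:
        return base_url.replace("_pretrained.pdparams", "_22k_pretrained.pdparams")
    if use_imagenet22kto1k_pretrained:
        return base_url.replace("_pretrained.pdparams", "_22kto1k_pretrained.pdparams")
    return base_url

Under that assumption, use_imagenet22kto1k_pretrained=True reproduces the _22kto1k_pretrained.pdparams URLs that were previously hardcoded above.
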
@@ -45,13 +45,15 @@ __all__ = list(MODEL_URLS.keys())
 
 # The following re-implementation of roll is inspired by
 # https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
 class RollWithIndexSelect(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, input1, index_fp, index_bp):
         N, H, W, C = input1.shape
         ctx.input1 = input1
         ctx.index_bp = index_bp
         result = input1.reshape([N, H * W, C]).index_select(index_fp, 1).reshape([N, H, W, C])
         return result
 
     @staticmethod
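
For readers unfamiliar with the trick shown above: RollWithIndexSelect emulates paddle.roll with a single index_select over the flattened H*W axis, using a permutation built by rolling arange(H*W) (see get_roll_index in the next hunk). A standalone NumPy check of that equivalence, not part of the commit:

import numpy as np

# Standalone check (not from the commit): rolling an (H, W) grid equals gathering
# the flattened tensor with indices produced by rolling arange(H*W).
H, W, C, shifts = 5, 7, 3, (2, -3)
x = np.random.rand(H, W, C)

index = np.arange(H * W).reshape(H, W)
index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape(-1)

rolled_by_gather = x.reshape(H * W, C)[index_fp].reshape(H, W, C)
rolled_directly = np.roll(x, shift=shifts, axis=(0, 1))
assert np.allclose(rolled_by_gather, rolled_directly)
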
@@ -59,14 +61,15 @@ class RollWithIndexSelect(paddle.autograd.PyLayer):
         input1 = ctx.input1
         N, H, W, C = input1.shape
         index_bp = ctx.index_bp
         grad_input = grad.reshape([N, H * W, C]).index_select(index_bp, 1).reshape([N, H, W, C])
         return grad_input, None, None
 
 
 def get_roll_index(H, W, shifts, place):
     index = np.arange(0, H * W, dtype=np.int64).reshape([H, W])
     index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape([-1])
     index_bp = {i: idx for idx, i in enumerate(index_fp.tolist())}
     index_bp = [index_bp[i] for i in range(H * W)]
     index_fp = paddle.to_tensor(index_fp, place=place)
     index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place)
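
The dict-then-list construction of index_bp in get_roll_index builds the inverse of the forward permutation, which is what the backward pass gathers with; it is equivalent to np.argsort(index_fp). A standalone check, not part of the commit:

import numpy as np

H, W, shifts = 4, 6, (1, -2)
index = np.arange(H * W).reshape(H, W)
index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape(-1)

# Inverse permutation built the same way as in get_roll_index ...
pos = {v: i for i, v in enumerate(index_fp.tolist())}
index_bp = np.array([pos[i] for i in range(H * W)])

# ... which is exactly argsort of the forward permutation, and composing the two
# permutations gives back the identity.
assert np.array_equal(index_bp, np.argsort(index_fp))
assert np.array_equal(index_fp[index_bp], np.arange(H * W))
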
@@ -97,7 +100,9 @@ class RollWrapper(object):
     @staticmethod
     def roll(x, shifts, axis):
         if RollWrapper._roll is None:
-            RollWrapper._roll = NpuRollWithIndexSelect() if 'npu' in paddle.device.get_all_custom_device_type() else paddle.roll
+            RollWrapper._roll = NpuRollWithIndexSelect(
+            ) if 'npu' in paddle.device.get_all_custom_device_type(
+            ) else paddle.roll
         return RollWrapper._roll(x, shifts, axis)
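
RollWrapper resolves its backend once, on the first call: the PyLayer-based roll when an 'npu' custom device type is registered, plain paddle.roll otherwise, so both paths must agree numerically. A minimal usage sketch (not from the commit), assuming the module is importable under the file path shown above:

import paddle
from ppcls.arch.backbone.legendary_models.swin_transformer import RollWrapper

# Usage sketch (not from the commit): RollWrapper.roll should match paddle.roll
# regardless of which backend was bound on the first call.
x = paddle.rand([2, 8, 8, 96])  # N, H, W, C as used by the shifted-window attention
shifted = RollWrapper.roll(x, shifts=(-3, -3), axis=(1, 2))
reference = paddle.roll(x, shifts=(-3, -3), axis=(1, 2))
assert bool(paddle.allclose(shifted, reference))
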
@@ -507,7 +512,7 @@ class PatchMerging(nn.Layer):
         # x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
         # x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
         x = x.reshape([B, H // 2, 2, W // 2, 2, C])
         x = x.transpose((0, 1, 3, 4, 2, 5))
         x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
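
Context for the hunk above: the reshape/transpose pair in PatchMerging.forward is an equivalent, concat-free way of stacking the four even/odd sub-grids that the commented-out reference lines describe. A standalone NumPy check, not part of the commit:

import numpy as np

B, H, W, C = 2, 8, 8, 4
x = np.random.rand(B, H, W, C)

# Reference route: concatenate the four strided sub-grids (the commented-out lines).
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
ref = np.concatenate([x0, x1, x2, x3], -1).reshape(B, H * W // 4, 4 * C)

# The reshape/transpose route used in PatchMerging.forward.
y = x.reshape(B, H // 2, 2, W // 2, 2, C)
y = y.transpose((0, 1, 3, 4, 2, 5))
y = y.reshape(B, H * W // 4, 4 * C)
assert np.allclose(y, ref)
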
@@ -703,7 +708,7 @@ class SwinTransformer(TheseusLayer):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 class_num=5,
                  embed_dim=96,
                  depths=[2, 2, 6, 2],
                  num_heads=[3, 6, 12, 24],
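
Since every SwinTransformer_* factory below forwards **kwargs into this constructor, the classification head size can still be set explicitly at call time, independent of the class_num default shown in this hunk. A usage sketch, not part of the commit:

# Usage sketch (not from the commit): set class_num explicitly via **kwargs.
from ppcls.arch.backbone.legendary_models.swin_transformer import (
    SwinTransformer_tiny_patch4_window7_224)

model = SwinTransformer_tiny_patch4_window7_224(pretrained=False, class_num=1000)
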
@@ -822,11 +827,21 @@ class SwinTransformer(TheseusLayer):
         return flops
 
 
-def _load_pretrained(pretrained, model, model_url, use_ssld=False):
+def _load_pretrained(pretrained,
+                     model,
+                     model_url,
+                     use_ssld=False,
+                     use_imagenet22k_pretrained=False,
+                     use_imagenet22kto1k_pretrained=False):
     if pretrained is False:
         pass
     elif pretrained is True:
-        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+        load_dygraph_pretrain_from_url(
+            model,
+            model_url,
+            use_ssld=use_ssld,
+            use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+            use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     elif isinstance(pretrained, str):
         load_dygraph_pretrain(model, pretrained)
     else:
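
_load_pretrained now simply forwards the two new flags to load_dygraph_pretrain_from_url, which is expected to pick the matching .pdparams variant. A hedged usage sketch, not part of the commit; it assumes a 22k checkpoint is actually published for this model:

# Usage sketch (not from the commit): request ImageNet-22k weights for a base model.
# Assumes the corresponding _22k .pdparams file exists on the server.
from ppcls.arch.backbone.legendary_models.swin_transformer import (
    SwinTransformer_base_patch4_window7_224)

model = SwinTransformer_base_patch4_window7_224(
    pretrained=True, use_imagenet22k_pretrained=True)
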
@@ -835,81 +850,105 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
         )
 
 
-def SwinTransformer_tiny_patch4_window7_224(pretrained=False,
-                                            use_ssld=False,
-                                            **kwargs):
+def SwinTransformer_tiny_patch4_window7_224(
+        pretrained=False,
+        use_ssld=False,
+        use_imagenet22k_pretrained=False,
+        use_imagenet22kto1k_pretrained=False,
+        **kwargs):
     model = SwinTransformer(
         embed_dim=96,
         depths=[2, 2, 6, 2],
         num_heads=[3, 6, 12, 24],
         window_size=7,
-        drop_path_rate=0.2,
+        drop_path_rate=0.2,  # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.1
         **kwargs)
     _load_pretrained(
         pretrained,
         model,
         MODEL_URLS["SwinTransformer_tiny_patch4_window7_224"],
-        use_ssld=use_ssld)
+        use_ssld=use_ssld,
+        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     return model
 
 
-def SwinTransformer_small_patch4_window7_224(pretrained=False,
-                                             use_ssld=False,
-                                             **kwargs):
+def SwinTransformer_small_patch4_window7_224(
+        pretrained=False,
+        use_ssld=False,
+        use_imagenet22k_pretrained=False,
+        use_imagenet22kto1k_pretrained=False,
+        **kwargs):
     model = SwinTransformer(
         embed_dim=96,
         depths=[2, 2, 18, 2],
         num_heads=[3, 6, 12, 24],
         window_size=7,
+        drop_path_rate=0.3,  # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
         **kwargs)
     _load_pretrained(
         pretrained,
         model,
         MODEL_URLS["SwinTransformer_small_patch4_window7_224"],
-        use_ssld=use_ssld)
+        use_ssld=use_ssld,
+        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     return model
 
 
-def SwinTransformer_base_patch4_window7_224(pretrained=False,
-                                            use_ssld=False,
-                                            **kwargs):
+def SwinTransformer_base_patch4_window7_224(
+        pretrained=False,
+        use_ssld=False,
+        use_imagenet22k_pretrained=False,
+        use_imagenet22kto1k_pretrained=False,
+        **kwargs):
     model = SwinTransformer(
         embed_dim=128,
         depths=[2, 2, 18, 2],
         num_heads=[4, 8, 16, 32],
         window_size=7,
-        drop_path_rate=0.5,
+        drop_path_rate=0.5,  # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
         **kwargs)
     _load_pretrained(
         pretrained,
         model,
         MODEL_URLS["SwinTransformer_base_patch4_window7_224"],
-        use_ssld=use_ssld)
+        use_ssld=use_ssld,
+        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     return model
 
 
-def SwinTransformer_base_patch4_window12_384(pretrained=False,
-                                             use_ssld=False,
-                                             **kwargs):
+def SwinTransformer_base_patch4_window12_384(
+        pretrained=False,
+        use_ssld=False,
+        use_imagenet22k_pretrained=False,
+        use_imagenet22kto1k_pretrained=False,
+        **kwargs):
     model = SwinTransformer(
         img_size=384,
         embed_dim=128,
         depths=[2, 2, 18, 2],
         num_heads=[4, 8, 16, 32],
         window_size=12,
-        drop_path_rate=0.5,  # NOTE: do not appear in offical code
+        drop_path_rate=0.5,  # if imagenet22k or imagenet22kto1k, set drop_path_rate=0.2
         **kwargs)
     _load_pretrained(
         pretrained,
         model,
         MODEL_URLS["SwinTransformer_base_patch4_window12_384"],
-        use_ssld=use_ssld)
+        use_ssld=use_ssld,
+        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     return model
 
 
-def SwinTransformer_large_patch4_window7_224(pretrained=False,
-                                             use_ssld=False,
-                                             **kwargs):
+def SwinTransformer_large_patch4_window7_224(
+        pretrained=False,
+        use_ssld=False,
+        use_imagenet22k_pretrained=False,
+        use_imagenet22kto1k_pretrained=True,
+        **kwargs):
     model = SwinTransformer(
         embed_dim=192,
         depths=[2, 2, 18, 2],
@@ -920,13 +959,18 @@ def SwinTransformer_large_patch4_window7_224(pretrained=False,
         pretrained,
         model,
         MODEL_URLS["SwinTransformer_large_patch4_window7_224"],
-        use_ssld=use_ssld)
+        use_ssld=use_ssld,
+        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     return model
 
 
-def SwinTransformer_large_patch4_window12_384(pretrained=False,
-                                              use_ssld=False,
-                                              **kwargs):
+def SwinTransformer_large_patch4_window12_384(
+        pretrained=False,
+        use_ssld=False,
+        use_imagenet22k_pretrained=False,
+        use_imagenet22kto1k_pretrained=True,
+        **kwargs):
     model = SwinTransformer(
         img_size=384,
         embed_dim=192,
@@ -938,5 +982,7 @@ def SwinTransformer_large_patch4_window12_384(pretrained=False,
         pretrained,
         model,
         MODEL_URLS["SwinTransformer_large_patch4_window12_384"],
-        use_ssld=use_ssld)
+        use_ssld=use_ssld,
+        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
+        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
     return model
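
Taken together with the MODEL_URLS hunk at the top: the two large-model factories default use_imagenet22kto1k_pretrained=True, so loading them with pretrained=True should still fetch the 22k-to-1k checkpoints that the old hardcoded URLs pointed at, while the flag can be switched off to request another variant. A usage sketch, not part of the commit:

# Usage sketch (not from the commit): with the new defaults the large model keeps
# loading the 22k-to-1k checkpoint; turning the flag off requests the plain
# 1k-pretrained weights (assuming that .pdparams variant is published).
from ppcls.arch.backbone.legendary_models.swin_transformer import (
    SwinTransformer_large_patch4_window7_224)

model_22kto1k = SwinTransformer_large_patch4_window7_224(pretrained=True)
model_1k = SwinTransformer_large_patch4_window7_224(
    pretrained=True, use_imagenet22kto1k_pretrained=False)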