Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
b480d92c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b480d92c
编写于
4月 27, 2021
作者:
L
littletomatodonkey
提交者:
GitHub
4月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add swin transformer (#685)
* add swin transformer
上级
91502d8f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
1059 addition
and
0 deletion
+1059
-0
configs/SwinTransformer/SwinTransformer_base_patch4_window12_384.yaml
...Transformer/SwinTransformer_base_patch4_window12_384.yaml
+72
-0
configs/SwinTransformer/SwinTransformer_base_patch4_window7_224.yaml
...nTransformer/SwinTransformer_base_patch4_window7_224.yaml
+74
-0
configs/SwinTransformer/SwinTransformer_small_patch4_window7_224.yaml
...Transformer/SwinTransformer_small_patch4_window7_224.yaml
+74
-0
configs/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml
...nTransformer/SwinTransformer_tiny_patch4_window7_224.yaml
+74
-0
ppcls/modeling/architectures/__init__.py
ppcls/modeling/architectures/__init__.py
+1
-0
ppcls/modeling/architectures/swin_transformer.py
ppcls/modeling/architectures/swin_transformer.py
+764
-0
未找到文件。
configs/SwinTransformer/SwinTransformer_base_patch4_window12_384.yaml
0 → 100644
浏览文件 @
b480d92c
mode
:
'
train'
ARCHITECTURE
:
name
:
'
SwinTransformer_base_patch4_window12_384'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
384
,
384
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Piecewise'
params
:
lr
:
0.1
decay_epochs
:
[
30
,
60
,
90
]
gamma
:
0.1
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
256
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
256
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
384
,
384
]
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
configs/SwinTransformer/SwinTransformer_base_patch4_window7_224.yaml
0 → 100644
浏览文件 @
b480d92c
mode
:
'
train'
ARCHITECTURE
:
name
:
'
SwinTransformer_base_patch4_window7_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Piecewise'
params
:
lr
:
0.1
decay_epochs
:
[
30
,
60
,
90
]
gamma
:
0.1
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
256
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
64
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
configs/SwinTransformer/SwinTransformer_small_patch4_window7_224.yaml
0 → 100644
浏览文件 @
b480d92c
mode
:
'
train'
ARCHITECTURE
:
name
:
'
SwinTransformer_small_patch4_window7_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Piecewise'
params
:
lr
:
0.1
decay_epochs
:
[
30
,
60
,
90
]
gamma
:
0.1
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
256
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
64
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
configs/SwinTransformer/SwinTransformer_tiny_patch4_window7_224.yaml
0 → 100644
浏览文件 @
b480d92c
mode
:
'
train'
ARCHITECTURE
:
name
:
'
SwinTransformer_tiny_patch4_window7_224'
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
120
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_mix
:
False
ls_epsilon
:
-1
LEARNING_RATE
:
function
:
'
Piecewise'
params
:
lr
:
0.1
decay_epochs
:
[
30
,
60
,
90
]
gamma
:
0.1
OPTIMIZER
:
function
:
'
Momentum'
params
:
momentum
:
0.9
regularizer
:
function
:
'
L2'
factor
:
0.000100
TRAIN
:
batch_size
:
256
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
256
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
ppcls/modeling/architectures/__init__.py
浏览文件 @
b480d92c
...
@@ -47,5 +47,6 @@ from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT
...
@@ -47,5 +47,6 @@ from .vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT
from
.distilled_vision_transformer
import
DeiT_tiny_patch16_224
,
DeiT_small_patch16_224
,
DeiT_base_patch16_224
,
DeiT_tiny_distilled_patch16_224
,
DeiT_small_distilled_patch16_224
,
DeiT_base_distilled_patch16_224
,
DeiT_base_patch16_384
,
DeiT_base_distilled_patch16_384
from
.distilled_vision_transformer
import
DeiT_tiny_patch16_224
,
DeiT_small_patch16_224
,
DeiT_base_patch16_224
,
DeiT_tiny_distilled_patch16_224
,
DeiT_small_distilled_patch16_224
,
DeiT_base_distilled_patch16_224
,
DeiT_base_patch16_384
,
DeiT_base_distilled_patch16_384
from
.distillation_models
import
ResNet50_vd_distill_MobileNetV3_large_x1_0
,
ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from
.distillation_models
import
ResNet50_vd_distill_MobileNetV3_large_x1_0
,
ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from
.repvgg
import
RepVGG_A0
,
RepVGG_A1
,
RepVGG_A2
,
RepVGG_B0
,
RepVGG_B1
,
RepVGG_B2
,
RepVGG_B3
,
RepVGG_B1g2
,
RepVGG_B1g4
,
RepVGG_B2g2
,
RepVGG_B2g4
,
RepVGG_B3g2
,
RepVGG_B3g4
from
.repvgg
import
RepVGG_A0
,
RepVGG_A1
,
RepVGG_A2
,
RepVGG_B0
,
RepVGG_B1
,
RepVGG_B2
,
RepVGG_B3
,
RepVGG_B1g2
,
RepVGG_B1g4
,
RepVGG_B2g2
,
RepVGG_B2g4
,
RepVGG_B3g2
,
RepVGG_B3g4
from
.swin_transformer
import
SwinTransformer_tiny_patch4_window7_224
,
SwinTransformer_small_patch4_window7_224
,
SwinTransformer_base_patch4_window7_224
,
SwinTransformer_base_patch4_window12_384
from
.mixnet
import
MixNet_S
,
MixNet_M
,
MixNet_L
from
.mixnet
import
MixNet_S
,
MixNet_M
,
MixNet_L
from
.rexnet
import
ReXNet_1_0
,
ReXNet_1_3
,
ReXNet_1_5
,
ReXNet_2_0
,
ReXNet_3_0
from
.rexnet
import
ReXNet_1_0
,
ReXNet_1_3
,
ReXNet_1_5
,
ReXNet_2_0
,
ReXNet_3_0
ppcls/modeling/architectures/swin_transformer.py
0 → 100644
浏览文件 @
b480d92c
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference: https://github.com/microsoft/Swin-Transformer
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
TruncatedNormal
,
Constant
from
.vision_transformer
import
trunc_normal_
,
zeros_
,
ones_
,
to_2tuple
,
DropPath
,
Identity
class
Mlp
(
nn
.
Layer
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
def
window_partition
(
x
,
window_size
):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B
,
H
,
W
,
C
=
x
.
shape
x
=
x
.
reshape
(
[
B
,
H
//
window_size
,
window_size
,
W
//
window_size
,
window_size
,
C
])
windows
=
x
.
transpose
([
0
,
1
,
3
,
2
,
4
,
5
]).
reshape
(
[
-
1
,
window_size
,
window_size
,
C
])
return
windows
def
window_reverse
(
windows
,
window_size
,
H
,
W
):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B
=
int
(
windows
.
shape
[
0
]
/
(
H
*
W
/
window_size
/
window_size
))
x
=
windows
.
reshape
(
[
B
,
H
//
window_size
,
W
//
window_size
,
window_size
,
window_size
,
-
1
])
x
=
x
.
transpose
([
0
,
1
,
3
,
2
,
4
,
5
]).
reshape
([
B
,
H
,
W
,
-
1
])
return
x
class
WindowAttention
(
nn
.
Layer
):
r
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def
__init__
(
self
,
dim
,
window_size
,
num_heads
,
qkv_bias
=
True
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
().
__init__
()
self
.
dim
=
dim
self
.
window_size
=
window_size
# Wh, Ww
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**-
0.5
# define a parameter table of relative position bias
# 2*Wh-1 * 2*Ww-1, nH
self
.
relative_position_bias_table
=
self
.
create_parameter
(
shape
=
((
2
*
window_size
[
0
]
-
1
)
*
(
2
*
window_size
[
1
]
-
1
),
num_heads
),
default_initializer
=
zeros_
)
self
.
add_parameter
(
"relative_position_bias_table"
,
self
.
relative_position_bias_table
)
# get pair-wise relative position index for each token inside the window
coords_h
=
paddle
.
arange
(
self
.
window_size
[
0
])
coords_w
=
paddle
.
arange
(
self
.
window_size
[
1
])
coords
=
paddle
.
stack
(
paddle
.
meshgrid
(
[
coords_h
,
coords_w
]))
# 2, Wh, Ww
coords_flatten
=
paddle
.
flatten
(
coords
,
1
)
# 2, Wh*Ww
coords_flatten_1
=
coords_flatten
.
unsqueeze
(
axis
=
2
)
coords_flatten_2
=
coords_flatten
.
unsqueeze
(
axis
=
1
)
relative_coords
=
coords_flatten_1
-
coords_flatten_2
relative_coords
=
relative_coords
.
transpose
(
[
1
,
2
,
0
])
# Wh*Ww, Wh*Ww, 2
relative_coords
[:,
:,
0
]
+=
self
.
window_size
[
0
]
-
1
# shift to start from 0
relative_coords
[:,
:,
1
]
+=
self
.
window_size
[
1
]
-
1
relative_coords
[:,
:,
0
]
*=
2
*
self
.
window_size
[
1
]
-
1
relative_position_index
=
relative_coords
.
sum
(
-
1
)
# Wh*Ww, Wh*Ww
self
.
register_buffer
(
"relative_position_index"
,
relative_position_index
)
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias_attr
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
trunc_normal_
(
self
.
relative_position_bias_table
)
self
.
softmax
=
nn
.
Softmax
(
axis
=-
1
)
def
forward
(
self
,
x
,
mask
=
None
):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
[
B_
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
]).
transpose
(
[
2
,
0
,
3
,
1
,
4
])
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
q
=
q
*
self
.
scale
attn
=
paddle
.
mm
(
q
,
k
.
transpose
([
0
,
1
,
3
,
2
]))
index
=
self
.
relative_position_index
.
reshape
([
-
1
])
relative_position_bias
=
paddle
.
index_select
(
self
.
relative_position_bias_table
,
index
)
relative_position_bias
=
relative_position_bias
.
reshape
([
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
-
1
])
# Wh*Ww,Wh*Ww,nH
relative_position_bias
=
relative_position_bias
.
transpose
(
[
2
,
0
,
1
])
# nH, Wh*Ww, Wh*Ww
attn
=
attn
+
relative_position_bias
.
unsqueeze
(
0
)
if
mask
is
not
None
:
nW
=
mask
.
shape
[
0
]
attn
=
attn
.
reshape
([
B_
//
nW
,
nW
,
self
.
num_heads
,
N
,
N
])
+
mask
.
unsqueeze
(
1
).
unsqueeze
(
0
)
attn
=
attn
.
reshape
([
-
1
,
self
.
num_heads
,
N
,
N
])
attn
=
self
.
softmax
(
attn
)
else
:
attn
=
self
.
softmax
(
attn
)
attn
=
self
.
attn_drop
(
attn
)
# x = (attn @ v).transpose(1, 2).reshape([B_, N, C])
x
=
paddle
.
mm
(
attn
,
v
).
transpose
([
0
,
2
,
1
,
3
]).
reshape
([
B_
,
N
,
C
])
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
def
extra_repr
(
self
):
return
"dim={}, window_size={}, num_heads={}"
.
format
(
self
.
dim
,
self
.
window_size
,
self
.
num_heads
)
def
flops
(
self
,
N
):
# calculate flops for 1 window with token length of N
flops
=
0
# qkv = self.qkv(x)
flops
+=
N
*
self
.
dim
*
3
*
self
.
dim
# attn = (q @ k.transpose(-2, -1))
flops
+=
self
.
num_heads
*
N
*
(
self
.
dim
//
self
.
num_heads
)
*
N
# x = (attn @ v)
flops
+=
self
.
num_heads
*
N
*
N
*
(
self
.
dim
//
self
.
num_heads
)
# x = self.proj(x)
flops
+=
N
*
self
.
dim
*
self
.
dim
return
flops
class
SwinTransformerBlock
(
nn
.
Layer
):
r
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
dim
,
input_resolution
,
num_heads
,
window_size
=
7
,
shift_size
=
0
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
num_heads
=
num_heads
self
.
window_size
=
window_size
self
.
shift_size
=
shift_size
self
.
mlp_ratio
=
mlp_ratio
if
min
(
self
.
input_resolution
)
<=
self
.
window_size
:
# if window size is larger than input resolution, we don't partition windows
self
.
shift_size
=
0
self
.
window_size
=
min
(
self
.
input_resolution
)
assert
0
<=
self
.
shift_size
<
self
.
window_size
,
"shift_size must in 0-window_size"
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
WindowAttention
(
dim
,
window_size
=
to_2tuple
(
self
.
window_size
),
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
if
self
.
shift_size
>
0
:
# calculate attention mask for SW-MSA
H
,
W
=
self
.
input_resolution
img_mask
=
paddle
.
zeros
((
1
,
H
,
W
,
1
))
# 1 H W 1
h_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
w_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
cnt
=
0
for
h
in
h_slices
:
for
w
in
w_slices
:
img_mask
[:,
h
,
w
,
:]
=
cnt
cnt
+=
1
mask_windows
=
window_partition
(
img_mask
,
self
.
window_size
)
# nW, window_size, window_size, 1
mask_windows
=
mask_windows
.
reshape
(
[
-
1
,
self
.
window_size
*
self
.
window_size
])
attn_mask
=
mask_windows
.
unsqueeze
(
1
)
-
mask_windows
.
unsqueeze
(
2
)
huns
=
-
100.0
*
paddle
.
ones_like
(
attn_mask
)
attn_mask
=
huns
*
(
attn_mask
!=
0
).
astype
(
"float32"
)
else
:
attn_mask
=
None
self
.
register_buffer
(
"attn_mask"
,
attn_mask
)
def
forward
(
self
,
x
):
H
,
W
=
self
.
input_resolution
B
,
L
,
C
=
x
.
shape
assert
L
==
H
*
W
,
"input feature has wrong size"
shortcut
=
x
x
=
self
.
norm1
(
x
)
x
=
x
.
reshape
([
B
,
H
,
W
,
C
])
# cyclic shift
if
self
.
shift_size
>
0
:
shifted_x
=
paddle
.
roll
(
x
,
shifts
=
(
-
self
.
shift_size
,
-
self
.
shift_size
),
axis
=
(
1
,
2
))
else
:
shifted_x
=
x
# partition windows
x_windows
=
window_partition
(
shifted_x
,
self
.
window_size
)
# nW*B, window_size, window_size, C
x_windows
=
x_windows
.
reshape
(
[
-
1
,
self
.
window_size
*
self
.
window_size
,
C
])
# nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows
=
self
.
attn
(
x_windows
,
mask
=
self
.
attn_mask
)
# nW*B, window_size*window_size, C
# merge windows
attn_windows
=
attn_windows
.
reshape
(
[
-
1
,
self
.
window_size
,
self
.
window_size
,
C
])
shifted_x
=
window_reverse
(
attn_windows
,
self
.
window_size
,
H
,
W
)
# B H' W' C
# reverse cyclic shift
if
self
.
shift_size
>
0
:
x
=
paddle
.
roll
(
shifted_x
,
shifts
=
(
self
.
shift_size
,
self
.
shift_size
),
axis
=
(
1
,
2
))
else
:
x
=
shifted_x
x
=
x
.
reshape
([
B
,
H
*
W
,
C
])
# FFN
x
=
shortcut
+
self
.
drop_path
(
x
)
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
def
extra_repr
(
self
):
return
"dim={}, input_resolution={}, num_heads={}, window_size={}, shift_size={}, mlp_ratio={}"
.
format
(
self
.
dim
,
self
.
input_resolution
,
self
.
num_heads
,
self
.
window_size
,
self
.
shift_size
,
self
.
mlp_ratio
)
def
flops
(
self
):
flops
=
0
H
,
W
=
self
.
input_resolution
# norm1
flops
+=
self
.
dim
*
H
*
W
# W-MSA/SW-MSA
nW
=
H
*
W
/
self
.
window_size
/
self
.
window_size
flops
+=
nW
*
self
.
attn
.
flops
(
self
.
window_size
*
self
.
window_size
)
# mlp
flops
+=
2
*
H
*
W
*
self
.
dim
*
self
.
dim
*
self
.
mlp_ratio
# norm2
flops
+=
self
.
dim
*
H
*
W
return
flops
class
PatchMerging
(
nn
.
Layer
):
r
""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
input_resolution
,
dim
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
input_resolution
=
input_resolution
self
.
dim
=
dim
self
.
reduction
=
nn
.
Linear
(
4
*
dim
,
2
*
dim
,
bias_attr
=
False
)
self
.
norm
=
norm_layer
(
4
*
dim
)
def
forward
(
self
,
x
):
"""
x: B, H*W, C
"""
H
,
W
=
self
.
input_resolution
B
,
L
,
C
=
x
.
shape
assert
L
==
H
*
W
,
"input feature has wrong size"
assert
H
%
2
==
0
and
W
%
2
==
0
,
"x size ({}*{}) are not even."
.
format
(
H
,
W
)
x
=
x
.
reshape
([
B
,
H
,
W
,
C
])
x0
=
x
[:,
0
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x1
=
x
[:,
1
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x2
=
x
[:,
0
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x3
=
x
[:,
1
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x
=
paddle
.
concat
([
x0
,
x1
,
x2
,
x3
],
-
1
)
# B H/2 W/2 4*C
x
=
x
.
reshape
([
B
,
-
1
,
4
*
C
])
# B H/2*W/2 4*C
x
=
self
.
norm
(
x
)
x
=
self
.
reduction
(
x
)
return
x
def
extra_repr
(
self
):
return
"input_resolution={}, dim={}"
.
format
(
self
.
input_resolution
,
self
.
dim
)
def
flops
(
self
):
H
,
W
=
self
.
input_resolution
flops
=
H
*
W
*
self
.
dim
flops
+=
(
H
//
2
)
*
(
W
//
2
)
*
4
*
self
.
dim
*
2
*
self
.
dim
return
flops
class
BasicLayer
(
nn
.
Layer
):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def
__init__
(
self
,
dim
,
input_resolution
,
depth
,
num_heads
,
window_size
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
downsample
=
None
,
use_checkpoint
=
False
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
depth
=
depth
self
.
use_checkpoint
=
use_checkpoint
# build blocks
self
.
blocks
=
nn
.
LayerList
([
SwinTransformerBlock
(
dim
=
dim
,
input_resolution
=
input_resolution
,
num_heads
=
num_heads
,
window_size
=
window_size
,
shift_size
=
0
if
(
i
%
2
==
0
)
else
window_size
//
2
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop
,
attn_drop
=
attn_drop
,
drop_path
=
drop_path
[
i
]
if
isinstance
(
drop_path
,
list
)
else
drop_path
,
norm_layer
=
norm_layer
)
for
i
in
range
(
depth
)
])
# patch merging layer
if
downsample
is
not
None
:
self
.
downsample
=
downsample
(
input_resolution
,
dim
=
dim
,
norm_layer
=
norm_layer
)
else
:
self
.
downsample
=
None
def
forward
(
self
,
x
):
for
blk
in
self
.
blocks
:
x
=
blk
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
def
extra_repr
(
self
):
return
"dim={}, input_resolution={}, depth={}"
.
format
(
self
.
dim
,
self
.
input_resolution
,
self
.
depth
)
def
flops
(
self
):
flops
=
0
for
blk
in
self
.
blocks
:
flops
+=
blk
.
flops
()
if
self
.
downsample
is
not
None
:
flops
+=
self
.
downsample
.
flops
()
return
flops
class
PatchEmbed
(
nn
.
Layer
):
""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Layer, optional): Normalization layer. Default: None
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
4
,
in_chans
=
3
,
embed_dim
=
96
,
norm_layer
=
None
):
super
().
__init__
()
img_size
=
to_2tuple
(
img_size
)
patch_size
=
to_2tuple
(
patch_size
)
patches_resolution
=
[
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
]
]
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
patches_resolution
=
patches_resolution
self
.
num_patches
=
patches_resolution
[
0
]
*
patches_resolution
[
1
]
self
.
in_chans
=
in_chans
self
.
embed_dim
=
embed_dim
self
.
proj
=
nn
.
Conv2D
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
if
norm_layer
is
not
None
:
self
.
norm
=
norm_layer
(
embed_dim
)
else
:
self
.
norm
=
None
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
# FIXME look at relaxing size constraints
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
\
"Input image size ({H}*{W}) doesn't match model ({}*{})."
.
format
(
H
,
W
,
self
.
img_size
[
0
],
self
.
img_size
[
1
])
x
=
self
.
proj
(
x
)
x
=
x
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
# B Ph*Pw C
if
self
.
norm
is
not
None
:
x
=
self
.
norm
(
x
)
return
x
def
flops
(
self
):
Ho
,
Wo
=
self
.
patches_resolution
flops
=
Ho
*
Wo
*
self
.
embed_dim
*
self
.
in_chans
*
(
self
.
patch_size
[
0
]
*
self
.
patch_size
[
1
])
if
self
.
norm
is
not
None
:
flops
+=
Ho
*
Wo
*
self
.
embed_dim
return
flops
class
SwinTransformer
(
nn
.
Layer
):
""" Swin Transformer
A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
4
,
in_chans
=
3
,
class_dim
=
1000
,
embed_dim
=
96
,
depths
=
[
2
,
2
,
6
,
2
],
num_heads
=
[
3
,
6
,
12
,
24
],
window_size
=
7
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.1
,
norm_layer
=
nn
.
LayerNorm
,
ape
=
False
,
patch_norm
=
True
,
use_checkpoint
=
False
,
**
kwargs
):
super
(
SwinTransformer
,
self
).
__init__
()
self
.
num_classes
=
num_classes
=
class_dim
self
.
num_layers
=
len
(
depths
)
self
.
embed_dim
=
embed_dim
self
.
ape
=
ape
self
.
patch_norm
=
patch_norm
self
.
num_features
=
int
(
embed_dim
*
2
**
(
self
.
num_layers
-
1
))
self
.
mlp_ratio
=
mlp_ratio
# split image into non-overlapping patches
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
,
norm_layer
=
norm_layer
if
self
.
patch_norm
else
None
)
num_patches
=
self
.
patch_embed
.
num_patches
patches_resolution
=
self
.
patch_embed
.
patches_resolution
self
.
patches_resolution
=
patches_resolution
# absolute position embedding
if
self
.
ape
:
self
.
absolute_pos_embed
=
self
.
create_parameter
(
shape
=
(
1
,
num_patches
,
embed_dim
),
default_initializer
=
zeros_
)
self
.
add_parameter
(
"absolute_pos_embed"
,
self
.
absolute_pos_embed
)
trunc_normal_
(
self
.
absolute_pos_embed
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
# stochastic depth
dpr
=
np
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
)).
tolist
()
# stochastic depth decay rule
# build layers
self
.
layers
=
nn
.
LayerList
()
for
i_layer
in
range
(
self
.
num_layers
):
layer
=
BasicLayer
(
dim
=
int
(
embed_dim
*
2
**
i_layer
),
input_resolution
=
(
patches_resolution
[
0
]
//
(
2
**
i_layer
),
patches_resolution
[
1
]
//
(
2
**
i_layer
)),
depth
=
depths
[
i_layer
],
num_heads
=
num_heads
[
i_layer
],
window_size
=
window_size
,
mlp_ratio
=
self
.
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
sum
(
depths
[:
i_layer
]):
sum
(
depths
[:
i_layer
+
1
])],
norm_layer
=
norm_layer
,
downsample
=
PatchMerging
if
(
i_layer
<
self
.
num_layers
-
1
)
else
None
,
use_checkpoint
=
use_checkpoint
)
self
.
layers
.
append
(
layer
)
self
.
norm
=
norm_layer
(
self
.
num_features
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool1D
(
1
)
self
.
head
=
nn
.
Linear
(
self
.
num_features
,
num_classes
)
if
self
.
num_classes
>
0
else
nn
.
Identity
()
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
zeros_
(
m
.
bias
)
ones_
(
m
.
weight
)
def
forward_features
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x
.
transpose
([
0
,
2
,
1
]))
# B C 1
x
=
paddle
.
flatten
(
x
,
1
)
return
x
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
x
=
self
.
head
(
x
)
return
x
def
flops
(
self
):
flops
=
0
flops
+=
self
.
patch_embed
.
flops
()
for
_
,
layer
in
enumerate
(
self
.
layers
):
flops
+=
layer
.
flops
()
flops
+=
self
.
num_features
*
self
.
patches_resolution
[
0
]
*
self
.
patches_resolution
[
1
]
//
(
2
**
self
.
num_layers
)
flops
+=
self
.
num_features
*
self
.
num_classes
return
flops
def
SwinTransformer_tiny_patch4_window7_224
(
**
args
):
model
=
SwinTransformer
(
embed_dim
=
96
,
depths
=
[
2
,
2
,
6
,
2
],
num_heads
=
[
3
,
6
,
12
,
24
],
window_size
=
7
,
drop_path_rate
=
0.2
,
**
args
)
return
model
def
SwinTransformer_small_patch4_window7_224
(
**
args
):
model
=
SwinTransformer
(
embed_dim
=
96
,
depths
=
[
2
,
2
,
18
,
2
],
num_heads
=
[
3
,
6
,
12
,
24
],
window_size
=
7
,
**
args
)
return
model
def
SwinTransformer_base_patch4_window7_224
(
**
args
):
model
=
SwinTransformer
(
embed_dim
=
128
,
depths
=
[
2
,
2
,
18
,
2
],
num_heads
=
[
4
,
8
,
16
,
32
],
window_size
=
7
,
drop_path_rate
=
0.5
,
**
args
)
return
model
def
SwinTransformer_base_patch4_window12_384
(
**
args
):
model
=
SwinTransformer
(
img_size
=
384
,
embed_dim
=
128
,
depths
=
[
2
,
2
,
18
,
2
],
num_heads
=
[
4
,
8
,
16
,
32
],
window_size
=
12
,
drop_path_rate
=
0.5
,
# NOTE: do not appear in offical code
**
args
)
return
model
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录