PaddlePaddle / PaddleClas, commit 9e2a8e86

Authored June 10, 2021 by gaotingquan
Parent: f4f09840

Add LeViT and Twins

Showing 4 changed files with 1,195 additions and 4 deletions:

- ppcls/arch/backbone/__init__.py (+3, -1)
- ppcls/arch/backbone/model_zoo/gvt.py (+659, -0)
- ppcls/arch/backbone/model_zoo/levit.py (+515, -0)
- ppcls/arch/backbone/model_zoo/vision_transformer.py (+18, -3)
ppcls/arch/backbone/__init__.py

```diff
@@ -47,4 +47,6 @@ from ppcls.arch.backbone.model_zoo.distillation_models import ResNet50_vd_distil
 from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
 from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L
 from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
+from ppcls.arch.backbone.model_zoo.gvt import pcpvt_small, pcpvt_base, pcpvt_large, alt_gvt_small, alt_gvt_base, alt_gvt_large
+from ppcls.arch.backbone.model_zoo.levit import LeViT_128S, LeViT_128, LeViT_192, LeViT_256, LeViT_384
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
\ No newline at end of file
```
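
With these imports in place, the new backbones are re-exported from the `ppcls.arch.backbone` package like every other model in the zoo. A minimal sketch of what the registration enables (illustrative, not part of the commit):

```python
# The names below come straight from the diff above; the sketch assumes the
# ppcls package is importable (e.g. PaddleClas checked out on PYTHONPATH).
from ppcls.arch.backbone import (
    LeViT_128S, LeViT_128, LeViT_192, LeViT_256, LeViT_384,  # LeViT family
    pcpvt_small, pcpvt_base, pcpvt_large,  # Twins: PCPVT variants (gvt.py)
    alt_gvt_small, alt_gvt_base, alt_gvt_large,  # Twins: ALT-GVT variants
)
```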
ppcls/arch/backbone/model_zoo/gvt.py (new file, mode 100644)

This diff is collapsed in the original view (+659 additions: the Twins models, pcpvt_* and alt_gvt_*).
ppcls/arch/backbone/model_zoo/levit.py (new file, mode 100644)

```python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import math
import warnings

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant
from paddle.regularizer import L2Decay

from .vision_transformer import trunc_normal_, zeros_, ones_, Identity

__all__ = ["LeViT_128S", "LeViT_128", "LeViT_192", "LeViT_256", "LeViT_384"]


def cal_attention_biases(attention_biases, attention_bias_idxs):
    gather_list = []
    attention_bias_t = paddle.transpose(attention_biases, (1, 0))
    for idx in attention_bias_idxs:
        gather = paddle.gather(attention_bias_t, idx)
        gather_list.append(gather)
    shape0, shape1 = attention_bias_idxs.shape
    return paddle.transpose(paddle.concat(gather_list), (1, 0)).reshape(
        (0, shape0, shape1))
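
# Editorial note: `cal_attention_biases` expands the compact table of learned
# biases (one column per unique relative offset) into a full bias map. For an
# index tensor of shape [N, N] it returns [num_heads, N, N], ready to be
# added to the attention logits in the modules below.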


class Conv2d_BN(nn.Sequential):
    def __init__(self,
                 a,
                 b,
                 ks=1,
                 stride=1,
                 pad=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1,
                 resolution=-10000):
        super().__init__()
        self.add_sublayer(
            'c',
            nn.Conv2D(
                a, b, ks, stride, pad, dilation, groups, bias_attr=False))
        bn = nn.BatchNorm2D(b)
        ones_(bn.weight)
        zeros_(bn.bias)
        self.add_sublayer('bn', bn)


class Linear_BN(nn.Sequential):
    def __init__(self, a, b, bn_weight_init=1):
        super().__init__()
        self.add_sublayer('c', nn.Linear(a, b, bias_attr=False))
        bn = nn.BatchNorm1D(b)
        ones_(bn.weight)
        zeros_(bn.bias)
        self.add_sublayer('bn', bn)

    def forward(self, x):
        l, bn = self._sub_layers.values()
        x = l(x)
        return paddle.reshape(bn(x.flatten(0, 1)), x.shape)


class BN_Linear(nn.Sequential):
    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_sublayer('bn', nn.BatchNorm1D(a))
        l = nn.Linear(a, b, bias_attr=bias)
        trunc_normal_(l.weight)
        if bias:
            zeros_(l.bias)
        self.add_sublayer('l', l)


def b16(n, activation, resolution=224):
    return nn.Sequential(
        Conv2d_BN(
            3, n // 8, 3, 2, 1, resolution=resolution),
        activation(),
        Conv2d_BN(
            n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
        activation(),
        Conv2d_BN(
            n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
        activation(),
        Conv2d_BN(
            n // 2, n, 3, 2, 1, resolution=resolution // 8))
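
# Editorial note: `b16` is the convolutional patch stem. Four stride-2
# Conv2d_BN layers take a 3-channel image to `n` channels while dividing the
# spatial resolution by 16 (224 -> 14), matching the default patch_size=16.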


class Residual(nn.Layer):
    def __init__(self, m, drop):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            # Stochastic depth: drop the residual branch per sample with
            # probability `drop`, rescaling survivors so the expected output
            # is unchanged.
            keep_mask = (paddle.rand([x.shape[0], 1, 1]) >= self.drop
                         ).astype(x.dtype) / (1 - self.drop)
            return x + self.m(x) * keep_mask.detach()
        else:
            return x + self.m(x)


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=4,
                 activation=None,
                 resolution=14):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        self.h = self.dh + nh_kd * 2
        self.qkv = Linear_BN(dim, self.h)
        self.proj = nn.Sequential(
            activation(), Linear_BN(
                self.dh, dim, bn_weight_init=0))
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = self.create_parameter(
            shape=(num_heads, len(attention_offsets)),
            default_initializer=zeros_,
            attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
        tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
        self.register_buffer('attention_bias_idxs',
                             paddle.reshape(tensor_idxs, [N, N]))

    @paddle.no_grad()
    def train(self, mode=True):
        if mode:
            super().train()
        else:
            super().eval()
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = cal_attention_biases(self.attention_biases,
                                           self.attention_bias_idxs)

    def forward(self, x):
        self.training = True  # forces the on-the-fly bias branch below
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = paddle.reshape(qkv,
                             [B, N, self.num_heads, self.h // self.num_heads])
        q, k, v = paddle.split(
            qkv, [self.key_dim, self.key_dim, self.d], axis=3)
        q = paddle.transpose(q, perm=[0, 2, 1, 3])
        k = paddle.transpose(k, perm=[0, 2, 1, 3])
        v = paddle.transpose(v, perm=[0, 2, 1, 3])
        k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2])

        if self.training:
            attention_biases = cal_attention_biases(self.attention_biases,
                                                    self.attention_bias_idxs)
        else:
            attention_biases = self.ab
        attn = (q @ k_transpose) * self.scale + attention_biases
        attn = F.softmax(attn)
        x = paddle.transpose(attn @ v, perm=[0, 2, 1, 3])
        x = paddle.reshape(x, [B, N, self.dh])
        x = self.proj(x)
        return x
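
# Editorial note: `train(mode)` above manages the bias cache: entering eval
# mode precomputes `self.ab` once, while training mode deletes it so biases
# are regathered every step. The `self.training = True` at the top of
# forward() forces the regathering branch unconditionally in this version.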


class Subsample(nn.Layer):
    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        B, N, C = x.shape
        x = paddle.reshape(x, [B, self.resolution, self.resolution,
                               C])[:, ::self.stride, ::self.stride]
        x = paddle.reshape(x, [B, -1, C])
        return x


class AttentionSubsample(nn.Layer):
    def __init__(self,
                 in_dim,
                 out_dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=2,
                 activation=None,
                 stride=2,
                 resolution=14,
                 resolution_=7):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * self.num_heads
        self.attn_ratio = attn_ratio
        self.resolution_ = resolution_
        self.resolution_2 = resolution_**2
        self.training = True
        h = self.dh + nh_kd
        self.kv = Linear_BN(in_dim, h)
        self.q = nn.Sequential(
            Subsample(stride, resolution), Linear_BN(in_dim, nh_kd))
        self.proj = nn.Sequential(activation(), Linear_BN(self.dh, out_dim))

        self.stride = stride
        self.resolution = resolution
        points = list(itertools.product(range(resolution), range(resolution)))
        points_ = list(
            itertools.product(range(resolution_), range(resolution_)))
        N = len(points)
        N_ = len(points_)
        attention_offsets = {}
        idxs = []
        i = 0
        j = 0
        for p1 in points_:
            i += 1
            for p2 in points:
                j += 1
                size = 1
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                          abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = self.create_parameter(
            shape=(num_heads, len(attention_offsets)),
            default_initializer=zeros_,
            attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
        tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
        self.register_buffer('attention_bias_idxs',
                             paddle.reshape(tensor_idxs_, [N_, N]))

    @paddle.no_grad()
    def train(self, mode=True):
        if mode:
            super().train()
        else:
            super().eval()
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = cal_attention_biases(self.attention_biases,
                                           self.attention_bias_idxs)

    def forward(self, x):
        self.training = True  # forces the on-the-fly bias branch below
        B, N, C = x.shape
        kv = self.kv(x)
        kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
        k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
        k = paddle.transpose(k, perm=[0, 2, 1, 3])  # BHNC
        v = paddle.transpose(v, perm=[0, 2, 1, 3])
        q = paddle.reshape(
            self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
        q = paddle.transpose(q, perm=[0, 2, 1, 3])

        if self.training:
            attention_biases = cal_attention_biases(self.attention_biases,
                                                    self.attention_bias_idxs)
        else:
            attention_biases = self.ab

        attn = (q @ paddle.transpose(
            k, perm=[0, 1, 3, 2])) * self.scale + attention_biases
        attn = F.softmax(attn)

        x = paddle.reshape(
            paddle.transpose(
                (attn @ v), perm=[0, 2, 1, 3]), [B, -1, self.dh])
        x = self.proj(x)
        return x


class LeViT(nn.Layer):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 class_dim=1000,
                 embed_dim=[192],
                 key_dim=[64],
                 depth=[12],
                 num_heads=[3],
                 attn_ratio=[2],
                 mlp_ratio=[2],
                 hybrid_backbone=None,
                 down_ops=[],
                 attention_activation=nn.Hardswish,
                 mlp_activation=nn.Hardswish,
                 distillation=True,
                 drop_path=0):
        super().__init__()
        self.class_dim = class_dim
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.distillation = distillation

        self.patch_embed = hybrid_backbone
        self.blocks = []
        down_ops.append([''])
        resolution = img_size // patch_size
        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
                zip(embed_dim, key_dim, depth, num_heads, attn_ratio,
                    mlp_ratio, down_ops)):
            for _ in range(dpth):
                self.blocks.append(
                    Residual(
                        Attention(
                            ed,
                            kd,
                            nh,
                            attn_ratio=ar,
                            activation=attention_activation,
                            resolution=resolution, ),
                        drop_path))
                if mr > 0:
                    h = int(ed * mr)
                    self.blocks.append(
                        Residual(
                            nn.Sequential(
                                Linear_BN(ed, h),
                                mlp_activation(),
                                Linear_BN(
                                    h, ed, bn_weight_init=0), ),
                            drop_path))
            if do[0] == 'Subsample':
                #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
                resolution_ = (resolution - 1) // do[5] + 1
                self.blocks.append(
                    AttentionSubsample(
                        *embed_dim[i:i + 2],
                        key_dim=do[1],
                        num_heads=do[2],
                        attn_ratio=do[3],
                        activation=attention_activation,
                        stride=do[5],
                        resolution=resolution,
                        resolution_=resolution_))
                resolution = resolution_
                if do[4] > 0:  # mlp_ratio
                    h = int(embed_dim[i + 1] * do[4])
                    self.blocks.append(
                        Residual(
                            nn.Sequential(
                                Linear_BN(embed_dim[i + 1], h),
                                mlp_activation(),
                                Linear_BN(
                                    h,
                                    embed_dim[i + 1],
                                    bn_weight_init=0), ),
                            drop_path))
        self.blocks = nn.Sequential(*self.blocks)

        # Classifier head
        self.head = BN_Linear(embed_dim[-1],
                              class_dim) if class_dim > 0 else Identity()
        if distillation:
            self.head_dist = BN_Linear(
                embed_dim[-1], class_dim) if class_dim > 0 else Identity()

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2)
        x = paddle.transpose(x, perm=[0, 2, 1])
        x = self.blocks(x)
        x = x.mean(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x


def model_factory(C, D, X, N, drop_path, class_dim, distillation):
    embed_dim = [int(x) for x in C.split('_')]
    num_heads = [int(x) for x in N.split('_')]
    depth = [int(x) for x in X.split('_')]
    act = nn.Hardswish
    model = LeViT(
        patch_size=16,
        embed_dim=embed_dim,
        num_heads=num_heads,
        key_dim=[D] * 3,
        depth=depth,
        attn_ratio=[2, 2, 2],
        mlp_ratio=[2, 2, 2],
        down_ops=[
            #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
            ['Subsample', D, embed_dim[0] // D, 4, 2, 2],
            ['Subsample', D, embed_dim[1] // D, 4, 2, 2],
        ],
        attention_activation=act,
        mlp_activation=act,
        hybrid_backbone=b16(embed_dim[0], activation=act),
        class_dim=class_dim,
        drop_path=drop_path,
        distillation=distillation)
    return model


specification = {
    'LeViT_128S': {
        'C': '128_256_384',
        'D': 16,
        'N': '4_6_8',
        'X': '2_3_4',
        'drop_path': 0
    },
    'LeViT_128': {
        'C': '128_256_384',
        'D': 16,
        'N': '4_8_12',
        'X': '4_4_4',
        'drop_path': 0
    },
    'LeViT_192': {
        'C': '192_288_384',
        'D': 32,
        'N': '3_5_6',
        'X': '4_4_4',
        'drop_path': 0
    },
    'LeViT_256': {
        'C': '256_384_512',
        'D': 32,
        'N': '4_6_8',
        'X': '4_4_4',
        'drop_path': 0
    },
    'LeViT_384': {
        'C': '384_512_768',
        'D': 32,
        'N': '6_9_12',
        'X': '4_4_4',
        'drop_path': 0.1
    },
}
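
# Editorial note: in `specification`, 'C' packs the per-stage embedding dims,
# 'D' is the key dim shared by all stages, 'N' packs per-stage head counts,
# and 'X' packs per-stage depths, all as underscore-separated strings that
# model_factory() splits back into lists.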


def LeViT_128S(class_dim=1000, distillation=True, pretrained=False):
    return model_factory(
        **specification['LeViT_128S'],
        class_dim=class_dim,
        distillation=distillation)


def LeViT_128(class_dim=1000, distillation=True):
    return model_factory(
        **specification['LeViT_128'],
        class_dim=class_dim,
        distillation=distillation)


def LeViT_192(class_dim=1000, distillation=True):
    return model_factory(
        **specification['LeViT_192'],
        class_dim=class_dim,
        distillation=distillation)


def LeViT_256(class_dim=1000, distillation=False):
    return model_factory(
        **specification['LeViT_256'],
        class_dim=class_dim,
        distillation=distillation)


def LeViT_384(class_dim=1000, distillation=True):
    return model_factory(
        **specification['LeViT_384'],
        class_dim=class_dim,
        distillation=distillation)
```
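
As a quick smoke test of the new file, the factory functions can be called directly. The following sketch is illustrative only, not part of the commit; it assumes a working paddlepaddle install and PaddleClas on `PYTHONPATH`:

```python
import paddle

from ppcls.arch.backbone.model_zoo.levit import LeViT_128S

# Smallest LeViT variant with the default 1000-way head.
model = LeViT_128S(class_dim=1000, distillation=True)
model.eval()  # in eval mode the class and distillation heads are averaged

x = paddle.randn([1, 3, 224, 224])  # one RGB image at the expected size
out = model(x)
print(out.shape)  # [1, 1000]
```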
ppcls/arch/backbone/model_zoo/vision_transformer.py

```diff
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import Callable
+
 import numpy as np
 import paddle
 import paddle.nn as nn
-from paddle.nn.initializer import TruncatedNormal, Constant
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
 
 __all__ = [
     "VisionTransformer", "ViT_small_patch16_224", "ViT_base_patch16_224",
@@ -25,6 +27,7 @@ __all__ = [
 ]
 
 trunc_normal_ = TruncatedNormal(std=.02)
+normal_ = Normal
 zeros_ = Constant(value=0.)
 ones_ = Constant(value=1.)
 
@@ -141,7 +144,13 @@ class Block(nn.Layer):
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5):
         super().__init__()
-        self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+        if isinstance(norm_layer, str):
+            self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+        elif isinstance(norm_layer, Callable):
+            self.norm1 = norm_layer(dim)
+        else:
+            raise TypeError(
+                "The norm_layer must be str or paddle.nn.layer.Layer class")
         self.attn = Attention(
             dim,
             num_heads=num_heads,
@@ -151,7 +160,13 @@ class Block(nn.Layer):
             proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
-        self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+        if isinstance(norm_layer, str):
+            self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+        elif isinstance(norm_layer, Callable):
+            self.norm2 = norm_layer(dim)
+        else:
+            raise TypeError(
+                "The norm_layer must be str or paddle.nn.layer.Layer class")
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim,
                        hidden_features=mlp_hidden_dim,
```
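
With this change, `Block` accepts `norm_layer` either as a string (resolved with `eval`, as before) or as any callable that takes the feature dimension. A hedged sketch of the two call styles; the parameter values here are illustrative and assume `Block`'s other defaults are usable as-is:

```python
import paddle.nn as nn

from ppcls.arch.backbone.model_zoo.vision_transformer import Block

# String form: resolved via eval() inside Block, the pre-existing behavior.
blk_from_str = Block(dim=192, num_heads=3, norm_layer='nn.LayerNorm')

# Callable form: any layer factory taking the dim, enabled by this commit.
blk_from_fn = Block(dim=192, num_heads=3, norm_layer=nn.LayerNorm)
```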