Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
d7a11275
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d7a11275
编写于
9月 16, 2022
作者:
Y
Yang Nie
提交者:
Tingquan Gao
4月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CvT
上级
e4740c84
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
876 addition
and
2 deletion
+876
-2
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/model_zoo/cvt.py
ppcls/arch/backbone/model_zoo/cvt.py
+657
-0
ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml
ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml
+162
-0
ppcls/data/preprocess/ops/operators.py
ppcls/data/preprocess/ops/operators.py
+2
-2
test_tipc/configs/CvT/cvt_13_224x224_train_infer_python.txt
test_tipc/configs/CvT/cvt_13_224x224_train_infer_python.txt
+54
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
d7a11275
...
...
@@ -75,6 +75,7 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p
from
.model_zoo.convnext
import
ConvNeXt_tiny
,
ConvNeXt_small
,
ConvNeXt_base_224
,
ConvNeXt_base_384
,
ConvNeXt_large_224
,
ConvNeXt_large_384
from
.model_zoo.nextvit
import
NextViT_small_224
,
NextViT_base_224
,
NextViT_large_224
,
NextViT_small_384
,
NextViT_base_384
,
NextViT_large_384
from
.model_zoo.cae
import
cae_base_patch16_224
,
cae_large_patch16_224
from
.model_zoo.cvt
import
cvt_13_224x224
,
cvt_13_384x384
,
cvt_21_224x224
,
cvt_21_384x384
from
.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
.variant_models.resnet_variant
import
ResNet50_adaptive_max_pool2d
...
...
ppcls/arch/backbone/model_zoo/cvt.py
0 → 100644
浏览文件 @
d7a11275
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Code was heavily based on https://github.com/microsoft/CvT
# reference: https://arxiv.org/abs/2103.15808
import
paddle
import
paddle.nn
as
nn
from
paddle.nn.initializer
import
XavierUniform
,
TruncatedNormal
,
Constant
from
....utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
# Pretrained-weight URLs for each CvT variant.
# NOTE(review): all four entries are still empty placeholders (marked TODO in
# the original commit) — `_load_pretrained(pretrained=True, ...)` cannot work
# until real URLs are filled in.
MODEL_URLS = {
    "cvt_13_224x224": "",  # TODO
    "cvt_13_384x384": "",  # TODO
    "cvt_21_224x224": "",  # TODO
    "cvt_21_384x384": "",  # TODO
}

# Public API of this module: exactly the factory-function names above.
__all__ = list(MODEL_URLS.keys())

# Shared parameter initializers, reused across the module.
xavier_uniform_ = XavierUniform()
trunc_normal_ = TruncatedNormal(std=.02)
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    The original name is misleading as 'Drop Connect' is a different form of
    dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor of any rank >= 1; the drop decision is made per sample
            (first axis), and broadcast over all remaining axes.
        drop_prob: probability of zeroing a whole sample's residual branch.
            ``None`` is accepted and treated as 0 (``DropPath`` defaults its
            ``drop_prob`` to ``None``).
        training: only drop while training; identity at eval time.

    Returns:
        Tensor of the same shape as ``x``; kept samples are rescaled by
        ``1 / keep_prob`` so the expectation is unchanged.
    """
    # Fix: the previous implementation crashed with `1 - None` when a
    # DropPath() constructed with the default drop_prob=None ran in training.
    if drop_prob is None:
        drop_prob = 0.
    if drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    # One Bernoulli draw per sample: shape (batch, 1, 1, ...).
    mask_shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(mask_shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize to 0/1
    # Rescale survivors so E[output] == input.
    return x.divide(keep_prob) * random_tensor
class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin layer wrapper around :func:`drop_path`; the drop is only active while
    ``self.training`` is True.
    """

    def __init__(self, drop_prob=None):
        """Args:
            drop_prob: per-sample drop probability; ``None`` means no drop.
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        # Fix: with the default drop_prob=None the previous
        # f'drop_prob={self.drop_prob:.3f}' raised TypeError when the layer
        # was printed; format a 0. placeholder instead.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return f'drop_prob={drop_prob:.3f}'
def rearrange(x, pattern, **axes_lengths):
    """Minimal einops-style ``rearrange`` covering the four patterns CvT uses.

    Args:
        x: tensor-like object exposing ``.shape``, ``.transpose(perm)`` and
            ``.reshape(shape)`` (paddle Tensors; numpy arrays also work).
        pattern: one of the four literal pattern strings below.
        **axes_lengths: axis sizes referenced by the pattern (``h``/``w``).

    Raises:
        NotImplementedError: for any unsupported pattern string.
    """
    if pattern == 'b (h w) c -> b c h w':
        # Token sequence -> spatial map; missing h/w fall back to -1 so
        # reshape infers them.
        batch, _, channels = x.shape
        height = axes_lengths.pop('h', -1)
        width = axes_lengths.pop('w', -1)
        return x.transpose([0, 2, 1]).reshape(
            [batch, channels, height, width])

    if pattern == 'b c h w -> b (h w) c':
        # Spatial map -> token sequence.
        batch, channels, height, width = x.shape
        return x.reshape([batch, channels, height * width]).transpose(
            [0, 2, 1])

    if pattern == 'b t (h d) -> b h t d':
        # Split the channel axis into attention heads.
        batch, tokens, heads_times_dim = x.shape
        heads = axes_lengths['h']
        per_head = heads_times_dim // heads
        return x.reshape([batch, tokens, heads, per_head]).transpose(
            [0, 2, 1, 3])

    if pattern == 'b h t d -> b t (h d)':
        # Merge attention heads back into the channel axis.
        batch, heads, tokens, per_head = x.shape
        return x.transpose([0, 2, 1, 3]).reshape(
            [batch, tokens, heads * per_head])

    raise NotImplementedError(
        f"Rearrangement '{pattern}' has not been implemented.")
class Rearrange(nn.Layer):
    """Layer form of :func:`rearrange`, so a pattern can sit in ``nn.Sequential``."""

    def __init__(self, pattern, **axes_lengths):
        super().__init__()
        self.pattern = pattern
        self.axes_lengths = axes_lengths

    def forward(self, x):
        # Delegate to the functional implementation with the stored pattern.
        return rearrange(x, self.pattern, **self.axes_lengths)

    def extra_repr(self):
        # Show the pattern in repr() output for easier debugging.
        return self.pattern
class QuickGELU(nn.Layer):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = nn.functional.sigmoid(1.702 * x)
        return x * gate
class Mlp(nn.Layer):
    """Transformer feed-forward block: Linear -> act -> drop -> Linear -> drop."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        """Args:
            in_features: input channel count.
            hidden_features: width of the hidden layer; defaults to
                ``in_features`` when None/0.
            out_features: output channel count; defaults to ``in_features``.
            act_layer: activation layer class, instantiated with no args.
            drop: dropout probability applied after each linear layer.
        """
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Same dropout layer is applied twice, matching the usual ViT MLP.
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
class Attention(nn.Layer):
    """Multi-head self-attention with convolutional q/k/v projections (CvT).

    Tokens are reshaped back to a (h, w) map, run through a depthwise-conv
    (or average-pool) projection, then flattened again before the usual
    scaled-dot-product attention.

    Args:
        dim_in: input channel count.
        dim_out: output channel count (also the attention embedding dim).
        num_heads: number of attention heads.
        qkv_bias: whether the q/k/v linear projections carry a bias.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
        method: conv projection type — 'dw_bn' (depthwise conv + BN),
            'avg' (average pool) or 'linear' (no conv projection).
        kernel_size / stride_kv / stride_q / padding_kv / padding_q:
            geometry of the conv projections (k/v may be strided to shrink
            the key/value token count).
        with_cls_token: whether a leading class token is present; it bypasses
            the conv projections and is re-attached afterwards.
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        # NOTE(review): scale uses the full dim_out, not the per-head dim —
        # this matches the upstream Microsoft CvT implementation.
        self.scale = dim_out**-0.5
        self.with_cls_token = with_cls_token

        # The query projection never uses 'avg'; it degrades to 'linear'.
        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q, stride_q, 'linear'
            if method == 'avg' else method)
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv, stride_kv, method)
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv, stride_kv, method)

        self.proj_q = nn.Linear(dim_in, dim_out, bias_attr=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias_attr=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias_attr=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self, dim_in, dim_out, kernel_size, padding, stride,
                          method):
        """Build the spatial q/k/v projection for one of q, k, v.

        Returns an ``nn.Sequential`` mapping (B, C, H, W) -> (B, H'*W', C),
        or None for the 'linear' method (flattening then happens in
        ``forward_conv``).
        """
        if method == 'dw_bn':
            # Depthwise conv (groups=dim_in) + BatchNorm, then flatten to tokens.
            proj = nn.Sequential(
                ('conv', nn.Conv2D(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias_attr=False,
                    groups=dim_in)),
                ('bn', nn.BatchNorm2D(dim_in)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')))
        elif method == 'avg':
            # Parameter-free pooling projection (used for k/v).
            proj = nn.Sequential(
                ('avg', nn.AvgPool2D(
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    ceil_mode=True)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        """Apply the conv projections; returns the projected (q, k, v) tokens.

        Args:
            x: tokens of shape (B, [1 +] h*w, C); the optional leading class
                token is split off, kept aside, and re-concatenated.
            h, w: spatial size of the token grid.
        """
        if self.with_cls_token:
            cls_token, x = paddle.split(x, [1, h * w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            # The class token attends to (and is attended by) everything.
            q = paddle.concat((cls_token, q), axis=1)
            k = paddle.concat((cls_token, k), axis=1)
            v = paddle.concat((cls_token, v), axis=1)

        return q, k, v

    def forward(self, x, h, w):
        """Standard scaled-dot-product attention over conv-projected tokens."""
        # NOTE(review): if all three conv projections were None (method
        # 'linear' everywhere), q/k/v would be unbound below; the upstream
        # CvT code shares this structure and never hits that configuration.
        if (self.conv_proj_q is not None or self.conv_proj_k is not None or
                self.conv_proj_v is not None):
            q, k, v = self.forward_conv(x, h, w)

        # Split channels into heads: (B, T, H*D) -> (B, H, T, D).
        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = nn.functional.softmax(attn_score, axis=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        # Merge heads back: (B, H, T, D) -> (B, T, H*D).
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x
class Block(nn.Layer):
    """One CvT transformer block: pre-norm attention and MLP, each with a
    residual connection and optional stochastic depth."""

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 **kwargs):
        super().__init__()

        # 'with_cls_token' is mandatory in kwargs (KeyError otherwise),
        # mirroring how VisionTransformer builds its blocks.
        self.with_cls_token = kwargs['with_cls_token']

        self.norm1 = norm_layer(dim_in)
        self.attn = Attention(dim_in, dim_out, num_heads, qkv_bias, attn_drop,
                              drop, **kwargs)

        # Identity when no stochastic depth is requested.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(
            drop_path)

        self.norm2 = norm_layer(dim_out)
        hidden_width = int(dim_out * mlp_ratio)
        self.mlp = Mlp(in_features=dim_out,
                       hidden_features=hidden_width,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x, h, w):
        # Attention sub-block (needs the token grid size h x w).
        x = x + self.drop_path(self.attn(self.norm1(x), h, w))
        # MLP sub-block.
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class ConvEmbed(nn.Layer):
    """Convolutional patch embedding: overlapping Conv2D followed by an
    optional LayerNorm applied in token layout."""

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding)
        # Norm acts on (B, HW, C) tokens, hence the rearranges in forward.
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        _, _, height, width = x.shape
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            tokens = self.norm(tokens)
        # Back to the spatial layout expected by the next stage.
        return rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)
class VisionTransformer(nn.Layer):
    """ Vision Transformer with support for patch or hybrid CNN input stage.

    One CvT "stage": convolutional patch embedding, optional class token,
    and ``depth`` transformer blocks. ``forward`` returns the spatial feature
    map plus the (possibly None) class tokens.
    """

    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 **kwargs):
        super().__init__()
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        self.rearrage = None

        self.patch_embed = ConvEmbed(
            # img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer)

        # 'with_cls_token' must be supplied by the caller (CvT spec).
        with_cls_token = kwargs['with_cls_token']
        if with_cls_token:
            self.cls_token = self.create_parameter(
                shape=[1, 1, embed_dim], default_initializer=trunc_normal_)
        else:
            self.cls_token = None

        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic depth decay rule: linearly increasing drop-path rate
        # from 0 up to drop_path_rate over the blocks of this stage.
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]

        blocks = []
        for j in range(depth):
            blocks.append(
                Block(
                    dim_in=embed_dim,
                    dim_out=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    **kwargs))
        self.blocks = nn.LayerList(blocks)

        if init == 'xavier':
            self.apply(self._init_weights_xavier)
        else:
            self.apply(self._init_weights_trunc_normal)

    def _init_weights_trunc_normal(self, m):
        """Truncated-normal init for Linear; zeros/ones for norm layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zeros_(m.bias)
            ones_(m.weight)

    def _init_weights_xavier(self, m):
        """Xavier-uniform init for Linear; zeros/ones for norm layers."""
        if isinstance(m, nn.Linear):
            xavier_uniform_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        """Run one stage.

        Returns:
            (x, cls_tokens): x is the (B, C, H, W) feature map; cls_tokens is
            the (B, 1, C) class-token slice, or None when the stage has no
            class token.
        """
        x = self.patch_embed(x)
        B, C, H, W = x.shape

        x = rearrange(x, 'b c h w -> b (h w) c')

        cls_tokens = None
        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand([B, -1, -1])
            x = paddle.concat((cls_tokens, x), axis=1)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, H, W)

        if self.cls_token is not None:
            cls_tokens, x = paddle.split(x, [1, H * W], 1)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x, cls_tokens
class ConvolutionalVisionTransformer(nn.Layer):
    """CvT backbone: a pyramid of ``VisionTransformer`` stages driven by a
    ``spec`` dict (per-stage lists keyed by upper-case names), followed by a
    LayerNorm and a linear classifier head.
    """

    def __init__(self,
                 in_chans=3,
                 class_num=1000,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 spec=None):
        super().__init__()
        self.class_num = class_num

        self.num_stages = spec['NUM_STAGES']
        for i in range(self.num_stages):
            # Per-stage hyper-parameters, indexed out of the spec lists.
            kwargs = {
                'patch_size': spec['PATCH_SIZE'][i],
                'patch_stride': spec['PATCH_STRIDE'][i],
                'patch_padding': spec['PATCH_PADDING'][i],
                'embed_dim': spec['DIM_EMBED'][i],
                'depth': spec['DEPTH'][i],
                'num_heads': spec['NUM_HEADS'][i],
                'mlp_ratio': spec['MLP_RATIO'][i],
                'qkv_bias': spec['QKV_BIAS'][i],
                'drop_rate': spec['DROP_RATE'][i],
                'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
                'drop_path_rate': spec['DROP_PATH_RATE'][i],
                'with_cls_token': spec['CLS_TOKEN'][i],
                'method': spec['QKV_PROJ_METHOD'][i],
                'kernel_size': spec['KERNEL_QKV'][i],
                'padding_q': spec['PADDING_Q'][i],
                'padding_kv': spec['PADDING_KV'][i],
                'stride_kv': spec['STRIDE_KV'][i],
                'stride_q': spec['STRIDE_Q'][i],
            }

            stage = VisionTransformer(
                in_chans=in_chans,
                init=init,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs)
            # Stages are stored as attributes stage0, stage1, ... so their
            # parameter names match the checkpoint layout.
            setattr(self, f'stage{i}', stage)

            # Next stage consumes the embedding channels of this one.
            in_chans = spec['DIM_EMBED'][i]

        dim_embed = spec['DIM_EMBED'][-1]
        self.norm = norm_layer(dim_embed)
        # Whether the LAST stage carries a class token (pooling vs cls-token
        # readout in forward_features).
        self.cls_token = spec['CLS_TOKEN'][-1]

        # Classifier head
        self.head = nn.Linear(dim_embed,
                              class_num) if class_num > 0 else nn.Identity()
        # NOTE(review): the two initializer calls below assume a Linear head;
        # with class_num <= 0 (nn.Identity) they would fail — confirm that
        # headless use is not expected here.
        trunc_normal_(self.head.weight)

        bound = 1 / dim_embed**.5
        nn.initializer.Uniform(-bound, bound)(self.head.bias)

    def no_weight_decay(self):
        """Parameter-name set excluded from weight decay (pos_embed / cls_token
        of every stage)."""
        layers = set()
        for i in range(self.num_stages):
            layers.add(f'stage{i}.pos_embed')
            layers.add(f'stage{i}.cls_token')

        return layers

    def forward_features(self, x):
        """Run all stages and reduce to a (B, C) feature vector.

        Uses the class token of the last stage when available, otherwise
        global average pooling over the normalized tokens.
        """
        for i in range(self.num_stages):
            x, cls_tokens = getattr(self, f'stage{i}')(x)

        if self.cls_token:
            x = self.norm(cls_tokens)
            x = paddle.squeeze(x, axis=1)
        else:
            x = rearrange(x, 'b c h w -> b (h w) c')
            x = self.norm(x)
            x = paddle.mean(x, axis=1)

        return x

    def forward(self, x):
        """Full forward pass: features then classifier logits."""
        x = self.forward_features(x)
        x = self.head(x)

        return x
def
_load_pretrained
(
pretrained
,
model
,
model_url
,
use_ssld
=
False
):
if
pretrained
is
False
:
pass
elif
pretrained
is
True
:
load_dygraph_pretrain_from_url
(
model
,
model_url
,
use_ssld
=
use_ssld
)
elif
isinstance
(
pretrained
,
str
):
load_dygraph_pretrain
(
model
,
pretrained
)
else
:
raise
RuntimeError
(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def cvt_13_224x224(pretrained=False, use_ssld=False, **kwargs):
    """Build CvT-13 (depths 1/2/10) for 224x224 input.

    Args:
        pretrained: False, True (download) or a local weights path.
        use_ssld: prefer SSLD-distilled weights when downloading.
        **kwargs: forwarded to ``ConvolutionalVisionTransformer``.
    """
    msvit_spec = {
        'INIT': 'trunc_norm',
        'NUM_STAGES': 3,
        'PATCH_SIZE': [7, 3, 3],
        'PATCH_STRIDE': [4, 2, 2],
        'PATCH_PADDING': [2, 1, 1],
        'DIM_EMBED': [64, 192, 384],
        'NUM_HEADS': [1, 3, 6],
        'DEPTH': [1, 2, 10],
        'MLP_RATIO': [4.0, 4.0, 4.0],
        'ATTN_DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_PATH_RATE': [0.0, 0.0, 0.1],
        'QKV_BIAS': [True, True, True],
        'CLS_TOKEN': [False, False, True],
        'POS_EMBED': [False, False, False],
        'QKV_PROJ_METHOD': ['dw_bn', 'dw_bn', 'dw_bn'],
        'KERNEL_QKV': [3, 3, 3],
        'PADDING_KV': [1, 1, 1],
        'STRIDE_KV': [2, 2, 2],
        'PADDING_Q': [1, 1, 1],
        'STRIDE_Q': [1, 1, 1],
    }
    model = ConvolutionalVisionTransformer(
        in_chans=3,
        act_layer=QuickGELU,
        init=msvit_spec.get('INIT', 'trunc_norm'),
        spec=msvit_spec,
        **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["cvt_13_224x224"], use_ssld=use_ssld)
    return model
def cvt_13_384x384(pretrained=False, use_ssld=False, **kwargs):
    """Build CvT-13 (depths 1/2/10) for 384x384 input.

    Same architecture spec as the 224x224 variant; only the pretrained
    weights differ.

    Args:
        pretrained: False, True (download) or a local weights path.
        use_ssld: prefer SSLD-distilled weights when downloading.
        **kwargs: forwarded to ``ConvolutionalVisionTransformer``.
    """
    msvit_spec = {
        'INIT': 'trunc_norm',
        'NUM_STAGES': 3,
        'PATCH_SIZE': [7, 3, 3],
        'PATCH_STRIDE': [4, 2, 2],
        'PATCH_PADDING': [2, 1, 1],
        'DIM_EMBED': [64, 192, 384],
        'NUM_HEADS': [1, 3, 6],
        'DEPTH': [1, 2, 10],
        'MLP_RATIO': [4.0, 4.0, 4.0],
        'ATTN_DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_PATH_RATE': [0.0, 0.0, 0.1],
        'QKV_BIAS': [True, True, True],
        'CLS_TOKEN': [False, False, True],
        'POS_EMBED': [False, False, False],
        'QKV_PROJ_METHOD': ['dw_bn', 'dw_bn', 'dw_bn'],
        'KERNEL_QKV': [3, 3, 3],
        'PADDING_KV': [1, 1, 1],
        'STRIDE_KV': [2, 2, 2],
        'PADDING_Q': [1, 1, 1],
        'STRIDE_Q': [1, 1, 1],
    }
    model = ConvolutionalVisionTransformer(
        in_chans=3,
        act_layer=QuickGELU,
        init=msvit_spec.get('INIT', 'trunc_norm'),
        spec=msvit_spec,
        **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["cvt_13_384x384"], use_ssld=use_ssld)
    return model
def cvt_21_224x224(pretrained=False, use_ssld=False, **kwargs):
    """Build CvT-21 (depths 1/4/16) for 224x224 input.

    Args:
        pretrained: False, True (download) or a local weights path.
        use_ssld: prefer SSLD-distilled weights when downloading.
        **kwargs: forwarded to ``ConvolutionalVisionTransformer``.
    """
    msvit_spec = {
        'INIT': 'trunc_norm',
        'NUM_STAGES': 3,
        'PATCH_SIZE': [7, 3, 3],
        'PATCH_STRIDE': [4, 2, 2],
        'PATCH_PADDING': [2, 1, 1],
        'DIM_EMBED': [64, 192, 384],
        'NUM_HEADS': [1, 3, 6],
        'DEPTH': [1, 4, 16],
        'MLP_RATIO': [4.0, 4.0, 4.0],
        'ATTN_DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_PATH_RATE': [0.0, 0.0, 0.1],
        'QKV_BIAS': [True, True, True],
        'CLS_TOKEN': [False, False, True],
        'POS_EMBED': [False, False, False],
        'QKV_PROJ_METHOD': ['dw_bn', 'dw_bn', 'dw_bn'],
        'KERNEL_QKV': [3, 3, 3],
        'PADDING_KV': [1, 1, 1],
        'STRIDE_KV': [2, 2, 2],
        'PADDING_Q': [1, 1, 1],
        'STRIDE_Q': [1, 1, 1],
    }
    model = ConvolutionalVisionTransformer(
        in_chans=3,
        act_layer=QuickGELU,
        init=msvit_spec.get('INIT', 'trunc_norm'),
        spec=msvit_spec,
        **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["cvt_21_224x224"], use_ssld=use_ssld)
    return model
def cvt_21_384x384(pretrained=False, use_ssld=False, **kwargs):
    """Build CvT-21 (depths 1/4/16) for 384x384 input.

    Same architecture spec as the 224x224 variant; only the pretrained
    weights differ.

    Args:
        pretrained: False, True (download) or a local weights path.
        use_ssld: prefer SSLD-distilled weights when downloading.
        **kwargs: forwarded to ``ConvolutionalVisionTransformer``.
    """
    msvit_spec = {
        'INIT': 'trunc_norm',
        'NUM_STAGES': 3,
        'PATCH_SIZE': [7, 3, 3],
        'PATCH_STRIDE': [4, 2, 2],
        'PATCH_PADDING': [2, 1, 1],
        'DIM_EMBED': [64, 192, 384],
        'NUM_HEADS': [1, 3, 6],
        'DEPTH': [1, 4, 16],
        'MLP_RATIO': [4.0, 4.0, 4.0],
        'ATTN_DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_RATE': [0.0, 0.0, 0.0],
        'DROP_PATH_RATE': [0.0, 0.0, 0.1],
        'QKV_BIAS': [True, True, True],
        'CLS_TOKEN': [False, False, True],
        'POS_EMBED': [False, False, False],
        'QKV_PROJ_METHOD': ['dw_bn', 'dw_bn', 'dw_bn'],
        'KERNEL_QKV': [3, 3, 3],
        'PADDING_KV': [1, 1, 1],
        'STRIDE_KV': [2, 2, 2],
        'PADDING_Q': [1, 1, 1],
        'STRIDE_Q': [1, 1, 1],
    }
    model = ConvolutionalVisionTransformer(
        in_chans=3,
        act_layer=QuickGELU,
        init=msvit_spec.get('INIT', 'trunc_norm'),
        spec=msvit_spec,
        **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["cvt_21_384x384"], use_ssld=use_ssld)
    return model
ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml
0 → 100644
浏览文件 @
d7a11275
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 300
  print_batch_step: 50
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False
  update_freq: 2  # for 8 cards

# model architecture
Arch:
  name: cvt_13_224x224
  class_num: 1000

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  epsilon: 1e-8
  weight_decay: 0.05
  # parameter names (substring match) excluded from weight decay
  no_weight_decay_name: stage1.pos_embed stage2.pos_embed stage0.pos_embed stage0.cls_token stage2.cls_token stage1.cls_token .bias
  one_dim_param_no_weight_decay: True
  lr:
    # for 8 cards
    name: Cosine
    learning_rate: 2e-3  # lr 2e-3 for total_batch_size 2048
    eta_min: 1e-5
    warmup_epoch: 5
    warmup_start_lr: 1e-6
    by_epoch: True

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - RandCropImage:
            size: 224
            interpolation: bicubic
            backend: pil
        - RandFlipImage:
            flip_code: 1
        - TimmAutoAugment:
            config_str: rand-m9-mstd0.5-inc1
            interpolation: bicubic
            img_size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.25
            sl: 0.02
            sh: 1.0/3.0
            r1: 0.3
            attempt: 10
            use_log_aspect: True
            mode: pixel
      batch_transform_ops:
        - OpSampler:
            MixupOperator:
              alpha: 0.8
              prob: 0.5
            CutmixOperator:
              alpha: 1.0
              prob: 0.5
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - ResizeImage:
            resize_short: 256
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
        backend: pil
    - ResizeImage:
        resize_short: 256
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Eval:
    - TopkAcc:
        topk: [1, 5]
ppcls/data/preprocess/ops/operators.py
浏览文件 @
d7a11275
...
...
@@ -188,7 +188,7 @@ class DecodeImage(object):
elif
isinstance
(
img
,
bytes
):
if
self
.
backend
==
"pil"
:
data
=
io
.
BytesIO
(
img
)
img
=
Image
.
open
(
data
)
img
=
Image
.
open
(
data
)
.
convert
(
"RGB"
)
else
:
data
=
np
.
frombuffer
(
img
,
dtype
=
"uint8"
)
img
=
cv2
.
imdecode
(
data
,
1
)
...
...
@@ -197,7 +197,7 @@ class DecodeImage(object):
if
self
.
to_np
:
if
self
.
backend
==
"pil"
:
assert
img
.
mode
==
"RGB"
,
f
"invalid
shape of image[
{
img
.
shap
e
}
]"
assert
img
.
mode
==
"RGB"
,
f
"invalid
mode of image[
{
img
.
mod
e
}
]"
img
=
np
.
asarray
(
img
)[:,
:,
::
-
1
]
# BRG
if
self
.
to_rgb
:
...
...
test_tipc/configs/CvT/cvt_13_224x224_train_infer_python.txt
0 → 100644
浏览文件 @
d7a11275
===========================train_params===========================
model_name:cvt_13_224x224
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
inference_dir:null
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.interpolation=bicubic -o PreProcess.transform_ops.0.ResizeImage.backend=pil
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录