Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
f24e1f9b
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f24e1f9b
编写于
2月 23, 2023
作者:
D
dolcexu
提交者:
zengshao0622
2月 23, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cae config fix
上级
fb763b70
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
52 addition
and
48 deletion
+52
-48
ppcls/arch/backbone/model_zoo/cae.py
ppcls/arch/backbone/model_zoo/cae.py
+18
-18
ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
+17
-15
ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml
ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml
+17
-15
未找到文件。
ppcls/arch/backbone/model_zoo/cae.py
浏览文件 @
f24e1f9b
...
@@ -630,17 +630,17 @@ def _load_pretrained(pretrained,
...
@@ -630,17 +630,17 @@ def _load_pretrained(pretrained,
model
,
model
,
model_keys
,
model_keys
,
model_ema_configs
,
model_ema_configs
,
abs_pos_emb
,
use_
abs_pos_emb
,
rel_pos_bias
,
use_
rel_pos_bias
,
use_ssld
=
False
):
use_ssld
=
False
):
if
pretrained
is
False
:
if
pretrained
is
False
:
pass
return
elif
pretrained
is
True
:
elif
pretrained
is
True
:
local_weight_path
=
get_weights_path_from_url
(
pretrained_url
).
replace
(
local_weight_path
=
get_weights_path_from_url
(
pretrained_url
).
replace
(
".pdparams"
,
""
)
".pdparams"
,
""
)
checkpoint
=
paddle
.
load
(
local_weight_path
+
".pdparams"
)
checkpoint
=
paddle
.
load
(
local_weight_path
+
".pdparams"
)
elif
isinstance
(
pretrained
,
str
):
elif
isinstance
(
pretrained
,
str
):
checkpoint
=
paddle
.
load
(
local_weight_path
+
".pdparams"
)
checkpoint
=
paddle
.
load
(
pretrained
+
".pdparams"
)
checkpoint_model
=
None
checkpoint_model
=
None
for
model_key
in
model_keys
.
split
(
'|'
):
for
model_key
in
model_keys
.
split
(
'|'
):
...
@@ -693,10 +693,10 @@ def _load_pretrained(pretrained,
...
@@ -693,10 +693,10 @@ def _load_pretrained(pretrained,
if
"relative_position_index"
in
key
:
if
"relative_position_index"
in
key
:
checkpoint_model
.
pop
(
key
)
checkpoint_model
.
pop
(
key
)
if
"relative_position_bias_table"
in
key
and
rel_pos_bias
:
if
"relative_position_bias_table"
in
key
and
use_
rel_pos_bias
:
rel_pos_bias
=
checkpoint_model
[
key
]
rel_pos_bias
=
checkpoint_model
[
key
]
src_num_pos
,
num_attn_heads
=
rel_pos_bias
.
s
ize
()
src_num_pos
,
num_attn_heads
=
rel_pos_bias
.
s
hape
dst_num_pos
,
_
=
model
.
state_dict
()[
key
].
s
ize
()
dst_num_pos
,
_
=
model
.
state_dict
()[
key
].
s
hape
dst_patch_shape
=
model
.
patch_embed
.
patch_shape
dst_patch_shape
=
model
.
patch_embed
.
patch_shape
if
dst_patch_shape
[
0
]
!=
dst_patch_shape
[
1
]:
if
dst_patch_shape
[
0
]
!=
dst_patch_shape
[
1
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -742,8 +742,8 @@ def _load_pretrained(pretrained,
...
@@ -742,8 +742,8 @@ def _load_pretrained(pretrained,
src_size
).
float
().
numpy
()
src_size
).
float
().
numpy
()
f
=
interpolate
.
interp2d
(
x
,
y
,
z
,
kind
=
'cubic'
)
f
=
interpolate
.
interp2d
(
x
,
y
,
z
,
kind
=
'cubic'
)
all_rel_pos_bias
.
append
(
all_rel_pos_bias
.
append
(
paddle
.
Tensor
(
f
(
dx
,
dy
)).
contiguous
().
view
(
-
1
,
1
).
to
(
paddle
.
Tensor
(
f
(
dx
,
dy
)).
astype
(
'float32'
).
reshape
(
rel_pos_bias
.
device
))
[
-
1
,
1
]
))
rel_pos_bias
=
paddle
.
concat
(
all_rel_pos_bias
,
axis
=-
1
)
rel_pos_bias
=
paddle
.
concat
(
all_rel_pos_bias
,
axis
=-
1
)
...
@@ -752,7 +752,7 @@ def _load_pretrained(pretrained,
...
@@ -752,7 +752,7 @@ def _load_pretrained(pretrained,
checkpoint_model
[
key
]
=
new_rel_pos_bias
checkpoint_model
[
key
]
=
new_rel_pos_bias
# interpolate position embedding
# interpolate position embedding
if
'pos_embed'
in
checkpoint_model
and
abs_pos_emb
:
if
'pos_embed'
in
checkpoint_model
and
use_
abs_pos_emb
:
pos_embed_checkpoint
=
checkpoint_model
[
'pos_embed'
]
pos_embed_checkpoint
=
checkpoint_model
[
'pos_embed'
]
embedding_size
=
pos_embed_checkpoint
.
shape
[
-
1
]
embedding_size
=
pos_embed_checkpoint
.
shape
[
-
1
]
num_patches
=
model
.
patch_embed
.
num_patches
num_patches
=
model
.
patch_embed
.
num_patches
...
@@ -791,8 +791,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
...
@@ -791,8 +791,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
enable_linear_eval
=
config
.
pop
(
'enable_linear_eval'
)
enable_linear_eval
=
config
.
pop
(
'enable_linear_eval'
)
model_keys
=
config
.
pop
(
'model_key'
)
model_keys
=
config
.
pop
(
'model_key'
)
model_ema_configs
=
config
.
pop
(
'model_ema'
)
model_ema_configs
=
config
.
pop
(
'model_ema'
)
abs_pos_emb
=
config
.
pop
(
'abs_pos_emb'
)
use_abs_pos_emb
=
config
.
get
(
'use_abs_pos_emb'
,
False
)
rel_pos_bias
=
config
.
pop
(
'rel_pos_bias'
)
use_rel_pos_bias
=
config
.
get
(
'use_rel_pos_bias'
,
True
)
if
pretrained
in
config
:
if
pretrained
in
config
:
pretrained
=
config
.
pop
(
'pretrained'
)
pretrained
=
config
.
pop
(
'pretrained'
)
...
@@ -816,8 +816,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
...
@@ -816,8 +816,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
model
,
model
,
model_keys
,
model_keys
,
model_ema_configs
,
model_ema_configs
,
abs_pos_emb
,
use_
abs_pos_emb
,
rel_pos_bias
,
use_
rel_pos_bias
,
use_ssld
=
False
)
use_ssld
=
False
)
return
model
return
model
...
@@ -828,8 +828,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
...
@@ -828,8 +828,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
enable_linear_eval
=
config
.
pop
(
'enable_linear_eval'
)
enable_linear_eval
=
config
.
pop
(
'enable_linear_eval'
)
model_keys
=
config
.
pop
(
'model_key'
)
model_keys
=
config
.
pop
(
'model_key'
)
model_ema_configs
=
config
.
pop
(
'model_ema'
)
model_ema_configs
=
config
.
pop
(
'model_ema'
)
abs_pos_emb
=
config
.
pop
(
'abs_pos_emb'
)
use_abs_pos_emb
=
config
.
get
(
'use_abs_pos_emb'
,
False
)
rel_pos_bias
=
config
.
pop
(
'rel_pos_bias'
)
use_rel_pos_bias
=
config
.
get
(
'use_rel_pos_bias'
,
True
)
if
pretrained
in
config
:
if
pretrained
in
config
:
pretrained
=
config
.
pop
(
'pretrained'
)
pretrained
=
config
.
pop
(
'pretrained'
)
...
@@ -853,8 +853,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
...
@@ -853,8 +853,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
model
,
model
,
model_keys
,
model_keys
,
model_ema_configs
,
model_ema_configs
,
abs_pos_emb
,
use_
abs_pos_emb
,
rel_pos_bias
,
use_
rel_pos_bias
,
use_ssld
=
False
)
use_ssld
=
False
)
return
model
return
model
ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
浏览文件 @
f24e1f9b
...
@@ -31,10 +31,8 @@ Arch:
...
@@ -31,10 +31,8 @@ Arch:
sin_pos_emb
:
True
sin_pos_emb
:
True
abs_pos_emb
:
False
enable_linear_eval
:
False
enable_linear_eval
:
False
model_key
:
model|module|state_dict
model_key
:
model|module|state_dict
rel_pos_bias
:
True
model_ema
:
model_ema
:
enable_model_ema
:
False
enable_model_ema
:
False
model_ema_decay
:
0.9999
model_ema_decay
:
0.9999
...
@@ -83,23 +81,27 @@ DataLoader:
...
@@ -83,23 +81,27 @@ DataLoader:
-
DecodeImage
:
-
DecodeImage
:
to_rgb
:
True
to_rgb
:
True
channel_first
:
False
channel_first
:
False
-
Rand
CropImage
:
-
Rand
omResizedCrop
:
size
:
224
size
:
224
interpolation
:
bilinear
-
RandomHorizontalFlip
:
-
RandFlipImage
:
prob
:
0.5
flip_code
:
1
-
TimmAutoAugment
:
-
RandAugment
:
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1.0/255.0
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.229
,
0.224
,
0.225
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
order
:
'
'
-
RandomErasing
:
-
RandomErasing
:
EPSILON
:
0.5
EPSILON
:
0.
2
5
sl
:
0.02
sl
:
0.02
sh
:
0.3
sh
:
1.0/3.0
r1
:
0.3
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
sampler
:
name
:
DistributedBatchSampler
name
:
DistributedBatchSampler
batch_size
:
16
batch_size
:
16
...
@@ -110,7 +112,7 @@ DataLoader:
...
@@ -110,7 +112,7 @@ DataLoader:
use_shared_memory
:
True
use_shared_memory
:
True
Eval
:
Eval
:
dataset
:
dataset
:
name
:
ImageNetDataset
name
:
ImageNetDataset
image_root
:
./dataset/flowers102/
image_root
:
./dataset/flowers102/
cls_label_path
:
./dataset/flowers102/val_list.txt
cls_label_path
:
./dataset/flowers102/val_list.txt
...
@@ -124,8 +126,8 @@ DataLoader:
...
@@ -124,8 +126,8 @@ DataLoader:
size
:
224
size
:
224
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1.0/255.0
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.229
,
0.224
,
0.225
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
order
:
'
'
sampler
:
sampler
:
name
:
DistributedBatchSampler
name
:
DistributedBatchSampler
...
...
ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml
浏览文件 @
f24e1f9b
...
@@ -31,10 +31,8 @@ Arch:
...
@@ -31,10 +31,8 @@ Arch:
sin_pos_emb
:
True
sin_pos_emb
:
True
abs_pos_emb
:
False
enable_linear_eval
:
False
enable_linear_eval
:
False
model_key
:
model|module|state_dict
model_key
:
model|module|state_dict
rel_pos_bias
:
True
model_ema
:
model_ema
:
enable_model_ema
:
False
enable_model_ema
:
False
model_ema_decay
:
0.9999
model_ema_decay
:
0.9999
...
@@ -83,23 +81,27 @@ DataLoader:
...
@@ -83,23 +81,27 @@ DataLoader:
-
DecodeImage
:
-
DecodeImage
:
to_rgb
:
True
to_rgb
:
True
channel_first
:
False
channel_first
:
False
-
Rand
CropImage
:
-
Rand
omResizedCrop
:
size
:
224
size
:
224
interpolation
:
bilinear
-
RandomHorizontalFlip
:
-
RandFlipImage
:
prob
:
0.5
flip_code
:
1
-
TimmAutoAugment
:
-
RandAugment
:
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1.0/255.0
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.229
,
0.224
,
0.225
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
order
:
'
'
-
RandomErasing
:
-
RandomErasing
:
EPSILON
:
0.5
EPSILON
:
0.
2
5
sl
:
0.02
sl
:
0.02
sh
:
0.3
sh
:
1.0/3.0
r1
:
0.3
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
sampler
:
name
:
DistributedBatchSampler
name
:
DistributedBatchSampler
batch_size
:
16
batch_size
:
16
...
@@ -110,7 +112,7 @@ DataLoader:
...
@@ -110,7 +112,7 @@ DataLoader:
use_shared_memory
:
True
use_shared_memory
:
True
Eval
:
Eval
:
dataset
:
dataset
:
name
:
ImageNetDataset
name
:
ImageNetDataset
image_root
:
./dataset/flowers102/
image_root
:
./dataset/flowers102/
cls_label_path
:
./dataset/flowers102/val_list.txt
cls_label_path
:
./dataset/flowers102/val_list.txt
...
@@ -124,8 +126,8 @@ DataLoader:
...
@@ -124,8 +126,8 @@ DataLoader:
size
:
224
size
:
224
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1.0/255.0
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.229
,
0.224
,
0.225
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
'
order
:
'
'
sampler
:
sampler
:
name
:
DistributedBatchSampler
name
:
DistributedBatchSampler
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录