提交 f24e1f9b 编写于 作者: D dolcexu 提交者: zengshao0622

cae config fix

上级 fb763b70
......@@ -630,17 +630,17 @@ def _load_pretrained(pretrained,
model,
model_keys,
model_ema_configs,
abs_pos_emb,
rel_pos_bias,
use_abs_pos_emb,
use_rel_pos_bias,
use_ssld=False):
if pretrained is False:
pass
return
elif pretrained is True:
local_weight_path = get_weights_path_from_url(pretrained_url).replace(
".pdparams", "")
checkpoint = paddle.load(local_weight_path + ".pdparams")
elif isinstance(pretrained, str):
checkpoint = paddle.load(local_weight_path + ".pdparams")
checkpoint = paddle.load(pretrained + ".pdparams")
checkpoint_model = None
for model_key in model_keys.split('|'):
......@@ -693,10 +693,10 @@ def _load_pretrained(pretrained,
if "relative_position_index" in key:
checkpoint_model.pop(key)
if "relative_position_bias_table" in key and rel_pos_bias:
if "relative_position_bias_table" in key and use_rel_pos_bias:
rel_pos_bias = checkpoint_model[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = model.state_dict()[key].size()
src_num_pos, num_attn_heads = rel_pos_bias.shape
dst_num_pos, _ = model.state_dict()[key].shape
dst_patch_shape = model.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
......@@ -742,8 +742,8 @@ def _load_pretrained(pretrained,
src_size).float().numpy()
f = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
paddle.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(
rel_pos_bias.device))
paddle.Tensor(f(dx, dy)).astype('float32').reshape(
[-1, 1]))
rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)
......@@ -752,7 +752,7 @@ def _load_pretrained(pretrained,
checkpoint_model[key] = new_rel_pos_bias
# interpolate position embedding
if 'pos_embed' in checkpoint_model and abs_pos_emb:
if 'pos_embed' in checkpoint_model and use_abs_pos_emb:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
......@@ -791,8 +791,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
enable_linear_eval = config.pop('enable_linear_eval')
model_keys = config.pop('model_key')
model_ema_configs = config.pop('model_ema')
abs_pos_emb = config.pop('abs_pos_emb')
rel_pos_bias = config.pop('rel_pos_bias')
use_abs_pos_emb = config.get('use_abs_pos_emb', False)
use_rel_pos_bias = config.get('use_rel_pos_bias', True)
if pretrained in config:
pretrained = config.pop('pretrained')
......@@ -816,8 +816,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
model,
model_keys,
model_ema_configs,
abs_pos_emb,
rel_pos_bias,
use_abs_pos_emb,
use_rel_pos_bias,
use_ssld=False)
return model
......@@ -828,8 +828,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
enable_linear_eval = config.pop('enable_linear_eval')
model_keys = config.pop('model_key')
model_ema_configs = config.pop('model_ema')
abs_pos_emb = config.pop('abs_pos_emb')
rel_pos_bias = config.pop('rel_pos_bias')
use_abs_pos_emb = config.get('use_abs_pos_emb', False)
use_rel_pos_bias = config.get('use_rel_pos_bias', True)
if pretrained in config:
pretrained = config.pop('pretrained')
......@@ -853,8 +853,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
model,
model_keys,
model_ema_configs,
abs_pos_emb,
rel_pos_bias,
use_abs_pos_emb,
use_rel_pos_bias,
use_ssld=False)
return model
......@@ -31,10 +31,8 @@ Arch:
sin_pos_emb: True
abs_pos_emb: False
enable_linear_eval: False
model_key: model|module|state_dict
rel_pos_bias: True
model_ema:
enable_model_ema: False
model_ema_decay: 0.9999
......@@ -83,23 +81,27 @@ DataLoader:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
- RandomResizedCrop:
size: 224
interpolation: bilinear
- RandFlipImage:
flip_code: 1
- RandAugment:
- RandomHorizontalFlip:
prob: 0.5
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
order: ''
- RandomErasing:
EPSILON: 0.5
EPSILON: 0.25
sl: 0.02
sh: 0.3
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
sampler:
name: DistributedBatchSampler
batch_size: 16
......@@ -110,7 +112,7 @@ DataLoader:
use_shared_memory: True
Eval:
dataset:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/val_list.txt
......@@ -124,8 +126,8 @@ DataLoader:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
order: ''
sampler:
name: DistributedBatchSampler
......
......@@ -31,10 +31,8 @@ Arch:
sin_pos_emb: True
abs_pos_emb: False
enable_linear_eval: False
model_key: model|module|state_dict
rel_pos_bias: True
model_ema:
enable_model_ema: False
model_ema_decay: 0.9999
......@@ -83,23 +81,27 @@ DataLoader:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
- RandomResizedCrop:
size: 224
interpolation: bilinear
- RandFlipImage:
flip_code: 1
- RandAugment:
- RandomHorizontalFlip:
prob: 0.5
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
order: ''
- RandomErasing:
EPSILON: 0.5
EPSILON: 0.25
sl: 0.02
sh: 0.3
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
sampler:
name: DistributedBatchSampler
batch_size: 16
......@@ -110,7 +112,7 @@ DataLoader:
use_shared_memory: True
Eval:
dataset:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/val_list.txt
......@@ -124,8 +126,8 @@ DataLoader:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
order: ''
sampler:
name: DistributedBatchSampler
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册