Commit f24e1f9b authored by dolcexu, committed by zengshao0622

cae config fix

Parent fb763b70
@@ -630,17 +630,17 @@ def _load_pretrained(pretrained,
                      model,
                      model_keys,
                      model_ema_configs,
-                     abs_pos_emb,
-                     rel_pos_bias,
+                     use_abs_pos_emb,
+                     use_rel_pos_bias,
                      use_ssld=False):
     if pretrained is False:
-        pass
+        return
     elif pretrained is True:
         local_weight_path = get_weights_path_from_url(pretrained_url).replace(
             ".pdparams", "")
         checkpoint = paddle.load(local_weight_path + ".pdparams")
     elif isinstance(pretrained, str):
-        checkpoint = paddle.load(local_weight_path + ".pdparams")
+        checkpoint = paddle.load(pretrained + ".pdparams")
     checkpoint_model = None
     for model_key in model_keys.split('|'):
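Two real bugs are fixed in this hunk besides the renames: with "pass", a pretrained=False call fell through to the loading code below with no checkpoint ever bound, and the string branch read the undefined local_weight_path instead of the caller-supplied path. A minimal sketch of the corrected dispatch, using the same paddle utilities the file already imports (_load_checkpoint is a hypothetical helper name for illustration):

    import paddle
    from paddle.utils.download import get_weights_path_from_url

    def _load_checkpoint(pretrained, pretrained_url):
        # hypothetical helper mirroring the fixed branches above
        if pretrained is False:
            return None  # early return: skip weight loading entirely
        if pretrained is True:
            # download (or reuse cached) default weights by URL
            local_weight_path = get_weights_path_from_url(
                pretrained_url).replace(".pdparams", "")
            return paddle.load(local_weight_path + ".pdparams")
        if isinstance(pretrained, str):
            # load from the caller-supplied path, not local_weight_path
            return paddle.load(pretrained + ".pdparams")
        raise ValueError("pretrained must be a bool or a str path")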
@@ -693,10 +693,10 @@ def _load_pretrained(pretrained,
         if "relative_position_index" in key:
             checkpoint_model.pop(key)
-        if "relative_position_bias_table" in key and rel_pos_bias:
+        if "relative_position_bias_table" in key and use_rel_pos_bias:
             rel_pos_bias = checkpoint_model[key]
-            src_num_pos, num_attn_heads = rel_pos_bias.size()
-            dst_num_pos, _ = model.state_dict()[key].size()
+            src_num_pos, num_attn_heads = rel_pos_bias.shape
+            dst_num_pos, _ = model.state_dict()[key].shape
             dst_patch_shape = model.patch_embed.patch_shape
             if dst_patch_shape[0] != dst_patch_shape[1]:
                 raise NotImplementedError()
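The .size() calls were torch idioms that had survived the port: a paddle Tensor exposes its dimensions through the .shape attribute, while Tensor.size is the total element count, so unpacking size() would fail. A quick check, assuming a recent paddle:

    import paddle

    t = paddle.zeros([196, 12])
    src_num_pos, num_attn_heads = t.shape  # paddle: shape is a list
    print(src_num_pos, num_attn_heads)     # 196 12
    print(t.size)                          # 2352, the element count
    # t.size() would raise here: size is an int in paddle, not a method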
@@ -742,8 +742,8 @@ def _load_pretrained(pretrained,
                         src_size).float().numpy()
                 f = interpolate.interp2d(x, y, z, kind='cubic')
                 all_rel_pos_bias.append(
-                    paddle.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(
-                        rel_pos_bias.device))
+                    paddle.Tensor(f(dx, dy)).astype('float32').reshape(
+                        [-1, 1]))
             rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)
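The replaced line chained three more torch-only calls: .contiguous() (a torch memory-layout fix-up), .view() (paddle spells it reshape), and .to(device) (no explicit device move is needed here in paddle). The fix keeps scipy's interp2d for the bicubic resampling and reshapes each head's resampled grid into a column the paddle way. A self-contained sketch with toy sizes, using paddle.to_tensor in place of the raw paddle.Tensor constructor (note interp2d is deprecated in recent SciPy):

    import numpy as np
    import paddle
    from scipy import interpolate

    # toy stand-ins for one head's relative-position-bias grid
    x = y = np.arange(4, dtype=np.float64)
    z = np.random.rand(4, 4)          # source bias values
    dx = dy = np.linspace(0, 3, 7)    # target grid coordinates

    f = interpolate.interp2d(x, y, z, kind='cubic')
    col = paddle.to_tensor(f(dx, dy)).astype('float32').reshape([-1, 1])
    print(col.shape)                  # [49, 1]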
@@ -752,7 +752,7 @@ def _load_pretrained(pretrained,
             checkpoint_model[key] = new_rel_pos_bias
     # interpolate position embedding
-    if 'pos_embed' in checkpoint_model and abs_pos_emb:
+    if 'pos_embed' in checkpoint_model and use_abs_pos_emb:
         pos_embed_checkpoint = checkpoint_model['pos_embed']
         embedding_size = pos_embed_checkpoint.shape[-1]
         num_patches = model.patch_embed.num_patches
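Position-embedding interpolation follows the usual ViT recipe: keep the class/extra tokens untouched, reshape the patch-position rows into their 2D grid, resample bicubically to the new grid size, and flatten back. A sketch under that assumption (not the file's exact code):

    import paddle
    import paddle.nn.functional as F

    def resize_pos_embed(pos_embed, num_extra_tokens, new_size):
        # pos_embed: [1, num_extra_tokens + old_size**2, dim]
        dim = pos_embed.shape[-1]
        extra = pos_embed[:, :num_extra_tokens]   # cls (and dist) tokens
        patch = pos_embed[:, num_extra_tokens:]   # per-patch positions
        old_size = int(patch.shape[1] ** 0.5)
        patch = patch.reshape([1, old_size, old_size, dim]).transpose(
            [0, 3, 1, 2])
        patch = F.interpolate(
            patch, size=(new_size, new_size), mode='bicubic',
            align_corners=False)
        patch = patch.transpose([0, 2, 3, 1]).reshape([1, -1, dim])
        return paddle.concat([extra, patch], axis=1)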
@@ -791,8 +791,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
     enable_linear_eval = config.pop('enable_linear_eval')
     model_keys = config.pop('model_key')
     model_ema_configs = config.pop('model_ema')
-    abs_pos_emb = config.pop('abs_pos_emb')
-    rel_pos_bias = config.pop('rel_pos_bias')
+    use_abs_pos_emb = config.get('use_abs_pos_emb', False)
+    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
     if pretrained in config:
         pretrained = config.pop('pretrained')
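The switch from config.pop to config.get with explicit defaults pairs with the YAML changes further down, which delete the abs_pos_emb and rel_pos_bias keys: pop on a now-missing key would raise KeyError, while get falls back to the new defaults (False and True) and leaves any renamed key in the dict for later consumers. A small demonstration:

    config = {'model_key': 'model|module|state_dict'}  # old keys deleted

    try:
        config.pop('abs_pos_emb')        # old behaviour
    except KeyError as err:
        print('pop raises KeyError:', err)

    use_abs_pos_emb = config.get('use_abs_pos_emb', False)   # new behaviour
    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
    print(use_abs_pos_emb, use_rel_pos_bias)                 # False True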
@@ -816,8 +816,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
         model,
         model_keys,
         model_ema_configs,
-        abs_pos_emb,
-        rel_pos_bias,
+        use_abs_pos_emb,
+        use_rel_pos_bias,
         use_ssld=False)
     return model
@@ -828,8 +828,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
     enable_linear_eval = config.pop('enable_linear_eval')
     model_keys = config.pop('model_key')
     model_ema_configs = config.pop('model_ema')
-    abs_pos_emb = config.pop('abs_pos_emb')
-    rel_pos_bias = config.pop('rel_pos_bias')
+    use_abs_pos_emb = config.get('use_abs_pos_emb', False)
+    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
     if pretrained in config:
         pretrained = config.pop('pretrained')
@@ -853,8 +853,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
         model,
         model_keys,
         model_ema_configs,
-        abs_pos_emb,
-        rel_pos_bias,
+        use_abs_pos_emb,
+        use_rel_pos_bias,
         use_ssld=False)
     return model
@@ -31,10 +31,8 @@ Arch:
   sin_pos_emb: True
-  abs_pos_emb: False
   enable_linear_eval: False
   model_key: model|module|state_dict
-  rel_pos_bias: True
   model_ema:
     enable_model_ema: False
     model_ema_decay: 0.9999
@@ -83,23 +81,27 @@ DataLoader:
         - DecodeImage:
             to_rgb: True
             channel_first: False
-        - RandCropImage:
+        - RandomResizedCrop:
             size: 224
-            interpolation: bilinear
-        - RandFlipImage:
-            flip_code: 1
-        - RandAugment:
+        - RandomHorizontalFlip:
+            prob: 0.5
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
         - RandomErasing:
-            EPSILON: 0.5
+            EPSILON: 0.25
             sl: 0.02
-            sh: 0.3
+            sh: 1.0/3.0
             r1: 0.3
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
     sampler:
       name: DistributedBatchSampler
       batch_size: 16
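The rebuilt training pipeline swaps the RandCropImage/RandFlipImage/RandAugment chain for RandomResizedCrop, RandomHorizontalFlip, and TimmAutoAugment, normalizes to mean/std 0.5, and strengthens RandomErasing (pixel-mode erasing applied with probability 0.25). The policy string rand-m9-mstd0.5-inc1 follows the timm naming convention: RandAugment at magnitude 9, with magnitude noise of std 0.5, using the increasing-severity transform set. A rough parser sketch of that convention (field names are illustrative):

    cfg = 'rand-m9-mstd0.5-inc1'
    policy, *opts = cfg.split('-')
    params = {}
    for opt in opts:
        if opt.startswith('mstd'):
            params['magnitude_std'] = float(opt[len('mstd'):])
        elif opt.startswith('inc'):
            params['increasing'] = bool(int(opt[len('inc'):]))
        elif opt.startswith('m'):
            params['magnitude'] = int(opt[len('m'):])
    print(policy, params)
    # rand {'magnitude': 9, 'magnitude_std': 0.5, 'increasing': True}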
@@ -110,7 +112,7 @@ DataLoader:
       use_shared_memory: True
   Eval:
     dataset:
       name: ImageNetDataset
       image_root: ./dataset/flowers102/
       cls_label_path: ./dataset/flowers102/val_list.txt
@@ -124,8 +126,8 @@ DataLoader:
             size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
     sampler:
       name: DistributedBatchSampler
...
(The same configuration changes are applied, line for line, to a second identical CAE fine-tune config file in this commit.)