From f24e1f9bcfc685a2300cf3fa3c9731850187e5a1 Mon Sep 17 00:00:00 2001
From: dolcexu
Date: Thu, 23 Feb 2023 10:40:53 +0800
Subject: [PATCH] cae config fix

---
 ppcls/arch/backbone/model_zoo/cae.py         | 36 +++++++++----------
 .../CAE/cae_base_patch16_224_finetune.yaml   | 32 +++++++++--------
 .../CAE/cae_large_patch16_224_finetune.yaml  | 32 +++++++++--------
 3 files changed, 52 insertions(+), 48 deletions(-)

diff --git a/ppcls/arch/backbone/model_zoo/cae.py b/ppcls/arch/backbone/model_zoo/cae.py
index b6262e9f..12621cf9 100644
--- a/ppcls/arch/backbone/model_zoo/cae.py
+++ b/ppcls/arch/backbone/model_zoo/cae.py
@@ -630,17 +630,17 @@ def _load_pretrained(pretrained,
                      model,
                      model_keys,
                      model_ema_configs,
-                     abs_pos_emb,
-                     rel_pos_bias,
+                     use_abs_pos_emb,
+                     use_rel_pos_bias,
                      use_ssld=False):
     if pretrained is False:
-        pass
+        return
     elif pretrained is True:
         local_weight_path = get_weights_path_from_url(pretrained_url).replace(
             ".pdparams", "")
         checkpoint = paddle.load(local_weight_path + ".pdparams")
     elif isinstance(pretrained, str):
-        checkpoint = paddle.load(local_weight_path + ".pdparams")
+        checkpoint = paddle.load(pretrained + ".pdparams")
 
     checkpoint_model = None
     for model_key in model_keys.split('|'):
@@ -693,10 +693,10 @@ def _load_pretrained(pretrained,
             if "relative_position_index" in key:
                 checkpoint_model.pop(key)
 
-            if "relative_position_bias_table" in key and rel_pos_bias:
+            if "relative_position_bias_table" in key and use_rel_pos_bias:
                 rel_pos_bias = checkpoint_model[key]
-                src_num_pos, num_attn_heads = rel_pos_bias.size()
-                dst_num_pos, _ = model.state_dict()[key].size()
+                src_num_pos, num_attn_heads = rel_pos_bias.shape
+                dst_num_pos, _ = model.state_dict()[key].shape
                 dst_patch_shape = model.patch_embed.patch_shape
                 if dst_patch_shape[0] != dst_patch_shape[1]:
                     raise NotImplementedError()
@@ -742,8 +742,8 @@ def _load_pretrained(pretrained,
                             src_size).float().numpy()
                         f = interpolate.interp2d(x, y, z, kind='cubic')
                         all_rel_pos_bias.append(
-                            paddle.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(
-                                rel_pos_bias.device))
+                            paddle.Tensor(f(dx, dy)).astype('float32').reshape(
+                                [-1, 1]))
 
                     rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)
 
@@ -752,7 +752,7 @@ def _load_pretrained(pretrained,
                 checkpoint_model[key] = new_rel_pos_bias
 
     # interpolate position embedding
-    if 'pos_embed' in checkpoint_model and abs_pos_emb:
+    if 'pos_embed' in checkpoint_model and use_abs_pos_emb:
         pos_embed_checkpoint = checkpoint_model['pos_embed']
         embedding_size = pos_embed_checkpoint.shape[-1]
         num_patches = model.patch_embed.num_patches
@@ -791,8 +791,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
     enable_linear_eval = config.pop('enable_linear_eval')
     model_keys = config.pop('model_key')
     model_ema_configs = config.pop('model_ema')
-    abs_pos_emb = config.pop('abs_pos_emb')
-    rel_pos_bias = config.pop('rel_pos_bias')
+    use_abs_pos_emb = config.get('use_abs_pos_emb', False)
+    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
 
     if pretrained in config:
         pretrained = config.pop('pretrained')
@@ -816,8 +816,8 @@ def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
         model,
         model_keys,
         model_ema_configs,
-        abs_pos_emb,
-        rel_pos_bias,
+        use_abs_pos_emb,
+        use_rel_pos_bias,
         use_ssld=False)
     return model
 
@@ -828,8 +828,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
     enable_linear_eval = config.pop('enable_linear_eval')
     model_keys = config.pop('model_key')
     model_ema_configs = config.pop('model_ema')
-    abs_pos_emb = config.pop('abs_pos_emb')
-    rel_pos_bias = config.pop('rel_pos_bias')
+    use_abs_pos_emb = config.get('use_abs_pos_emb', False)
+    use_rel_pos_bias = config.get('use_rel_pos_bias', True)
 
     if pretrained in config:
         pretrained = config.pop('pretrained')
@@ -853,8 +853,8 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
         model,
         model_keys,
         model_ema_configs,
-        abs_pos_emb,
-        rel_pos_bias,
+        use_abs_pos_emb,
+        use_rel_pos_bias,
         use_ssld=False)
     return model
 
diff --git a/ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml b/ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
index 7ec1c9c4..99447275 100644
--- a/ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
+++ b/ppcls/configs/CAE/cae_base_patch16_224_finetune.yaml
@@ -31,10 +31,8 @@ Arch:
   sin_pos_emb: True
-  abs_pos_emb: False
   enable_linear_eval: False
   model_key: model|module|state_dict
-  rel_pos_bias: True
   model_ema:
     enable_model_ema: False
     model_ema_decay: 0.9999
@@ -83,23 +81,27 @@ DataLoader:
         - DecodeImage:
             to_rgb: True
             channel_first: False
-        - RandCropImage:
+        - RandomResizedCrop:
             size: 224
-            interpolation: bilinear
-        - RandFlipImage:
-            flip_code: 1
-        - RandAugment:
+        - RandomHorizontalFlip:
+            prob: 0.5
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
         - RandomErasing:
-            EPSILON: 0.5
+            EPSILON: 0.25
             sl: 0.02
-            sh: 0.3
+            sh: 1.0/3.0
             r1: 0.3
-
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
     sampler:
       name: DistributedBatchSampler
       batch_size: 16
@@ -110,7 +112,7 @@ DataLoader:
       use_shared_memory: True
 
   Eval:
-    dataset: 
+    dataset:
       name: ImageNetDataset
       image_root: ./dataset/flowers102/
       cls_label_path: ./dataset/flowers102/val_list.txt
@@ -124,8 +126,8 @@ DataLoader:
             size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
     sampler:
       name: DistributedBatchSampler
diff --git a/ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml b/ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml
index f8f7edc5..579163d9 100644
--- a/ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml
+++ b/ppcls/configs/CAE/cae_large_patch16_224_finetune.yaml
@@ -31,10 +31,8 @@ Arch:
   sin_pos_emb: True
-  abs_pos_emb: False
   enable_linear_eval: False
   model_key: model|module|state_dict
-  rel_pos_bias: True
   model_ema:
     enable_model_ema: False
     model_ema_decay: 0.9999
@@ -83,23 +81,27 @@ DataLoader:
         - DecodeImage:
             to_rgb: True
             channel_first: False
-        - RandCropImage:
+        - RandomResizedCrop:
             size: 224
-            interpolation: bilinear
-        - RandFlipImage:
-            flip_code: 1
-        - RandAugment:
+        - RandomHorizontalFlip:
+            prob: 0.5
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224
         - NormalizeImage:
             scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
             order: ''
         - RandomErasing:
-            EPSILON: 0.5
+            EPSILON: 0.25
             sl: 0.02
-            sh: 0.3
+            sh: 1.0/3.0
             r1: 0.3
-
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
     sampler:
       name: DistributedBatchSampler
      batch_size: 16
@@ -110,7 +112,7 @@ DataLoader:
      use_shared_memory: True
 
  Eval:
-    dataset: 
+    dataset:
      name: ImageNetDataset
      image_root: ./dataset/flowers102/
      cls_label_path: ./dataset/flowers102/val_list.txt
@@ -124,8 +126,8 @@ DataLoader:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
-            mean: [0.485, 0.456, 0.406]
-            std: [0.229, 0.224, 0.225]
+            mean: [ 0.5, 0.5, 0.5 ]
+            std: [ 0.5, 0.5, 0.5 ]
            order: ''
    sampler:
      name: DistributedBatchSampler
-- 
GitLab
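
Reviewer note (not part of the patch): a minimal standalone sketch of the checkpoint-loading dispatch this patch repairs in _load_pretrained. The wrapper name load_checkpoint and the trailing ValueError are illustrative only; get_weights_path_from_url and paddle.load are the calls the patched code actually uses, though the import path shown is Paddle's public one and may differ from the one in cae.py. With the old `pass`, pretrained=False fell through to the key-filtering loop and raised NameError on the unbound `checkpoint`; the old string branch likewise read local_weight_path, which is bound only in the pretrained=True branch.

import paddle
from paddle.utils.download import get_weights_path_from_url

def load_checkpoint(pretrained, pretrained_url):
    # pretrained=False: nothing to load, so return early; the old
    # `pass` fell through to code that dereferenced `checkpoint`.
    if pretrained is False:
        return None
    # pretrained=True: download by URL, then load the cached weights.
    elif pretrained is True:
        local_weight_path = get_weights_path_from_url(
            pretrained_url).replace(".pdparams", "")
        return paddle.load(local_weight_path + ".pdparams")
    # pretrained=<path>: the old code reused `local_weight_path`,
    # which is bound only in the branch above; load the given path.
    elif isinstance(pretrained, str):
        return paddle.load(pretrained + ".pdparams")
    raise ValueError("pretrained must be a bool or a path string")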
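
Reviewer note (not part of the patch): why the model factories switch from config.pop to config.get with a default. The dict literal below is illustrative; the real config comes from the YAML files above. pop raises KeyError once the YAML stops defining abs_pos_emb/rel_pos_bias, which this patch removes from both configs; get falls back to the model defaults instead, and any use_abs_pos_emb/use_rel_pos_bias keys still present are left in config, so (assuming the factory forwards config to the model constructor, as the pop calls for the other bookkeeping keys suggest) they also reach the model.

# Illustrative config; in the repo it is parsed from the YAML files.
config = {"sin_pos_emb": True}  # abs_pos_emb / rel_pos_bias no longer set

# Old behaviour: pop() raises KeyError when the YAML omits the key.
try:
    abs_pos_emb = config.pop("abs_pos_emb")
except KeyError:
    abs_pos_emb = None  # the old code did not catch this; it crashed

# New behaviour: fall back to defaults when the key is absent.
use_abs_pos_emb = config.get("use_abs_pos_emb", False)
use_rel_pos_bias = config.get("use_rel_pos_bias", True)
print(use_abs_pos_emb, use_rel_pos_bias)  # False True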
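
Reviewer note (not part of the patch): the .size() -> .shape and .contiguous().view(-1, 1).to(device) -> .astype('float32').reshape([-1, 1]) edits replace torch tensor idioms left over from the original CAE release with their Paddle equivalents. A tiny self-contained check, with arbitrary tensor sizes:

import paddle

t = paddle.rand([196, 12])
# Paddle exposes .shape (a Python list); callable .size() is torch's API.
src_num_pos, num_attn_heads = t.shape          # 196, 12
# reshape([-1, 1]) plus astype('float32') stand in for torch's
# .contiguous().view(-1, 1) and .float()/.to(device).
flat = t.reshape([-1, 1]).astype("float32")
print(flat.shape)                              # [2352, 1]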