未验证 提交 ca5588c1 编写于 作者: Z zengshao0622 提交者: GitHub

Merge pull request #2286 from zengshao0622/merge_CAE

fix cae_large
...@@ -842,7 +842,7 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs): ...@@ -842,7 +842,7 @@ def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
qkv_bias=True, qkv_bias=True,
norm_layer=partial( norm_layer=partial(
nn.LayerNorm, epsilon=1e-6), nn.LayerNorm, epsilon=1e-6),
**kwargs) **config)
if enable_linear_eval: if enable_linear_eval:
_enable_linear_eval(model) _enable_linear_eval(model)
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 20
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: cae_large_patch16_224
class_num: 102
drop_rate: 0.0
drop_path_rate: 0.2
attn_drop_rate: 0.0
use_mean_pooling: True
init_scale: 0.001
use_rel_pos_bias: True
use_abs_pos_emb: False
init_values: 0.1
lin_probe: False
sin_pos_emb: True
abs_pos_emb: False
enable_linear_eval: False
model_key: model|module|state_dict
rel_pos_bias: True
model_ema:
enable_model_ema: False
model_ema_decay: 0.9999
model_ema_force_cpu: False
pretrained: True
# loss function config for traing/eval process
Loss:
Train:
- SoftTargetCrossEntropy:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamWDL
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.05
layerwise_decay: 0.75
lr:
name: Cosine
learning_rate: 0.001
eta_min: 1e-6
warmup_epoch: 10
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/train_list.txt
batch_transform_ops:
- MixupCutmixHybrid:
mixup_alpha: 0.8
cutmix_alpha: 1.0
switch_prob: 0.5
num_classes: 102
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bilinear
- RandFlipImage:
flip_code: 1
- RandAugment:
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.3
r1: 0.3
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册