diff --git a/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml b/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..827029b7dd24ff1a28c98db24b31b7e01ec15925 --- /dev/null +++ b/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml @@ -0,0 +1,70 @@ +mode: 'train' +ARCHITECTURE: + name: 'MobileNetV3_large_x1_0' +pretrained_model: "./pretrained/MobileNetV3_large_x1_0_pretrained" +model_save_dir: "./output/" +classes_num: 102 +total_images: 1020 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 20 +topk: 5 +image_shape: [3, 224, 224] + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.00375 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000001 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/flowers102/train_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 4 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml b/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f025437922a9a67bb71ed94fc25d40317d0f6a7 --- /dev/null +++ b/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml @@ -0,0 +1,75 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd_distill_MobileNetV3_large_x1_0' + +pretrained_model: + - "./pretrain/flowers102_R50_vd_final/ppcls" + - "./pretrained/MobileNetV3_large_x1_0_pretrained/" +model_save_dir: "./output/" +classes_num: 102 +total_images: 7169 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 20 +topk: 5 +image_shape: [3, 224, 224] + +use_distillation: True + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.0125 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.00007 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/flowers102/train_test_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 4 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/quick_start/ResNet50_vd.yaml b/configs/quick_start/ResNet50_vd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..913090921da7111d1e0625016158fec8af8c8bcf --- /dev/null +++ b/configs/quick_start/ResNet50_vd.yaml @@ -0,0 +1,70 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd' +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 102 +total_images: 1020 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 20 +topk: 5 +image_shape: [3, 224, 224] + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.0125 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.00001 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/flowers102/train_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 4 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/quick_start/ResNet50_vd_finetune.yaml b/configs/quick_start/ResNet50_vd_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..415e0e80d1d38da991c6431f406e5d39dc03bb6f --- /dev/null +++ b/configs/quick_start/ResNet50_vd_finetune.yaml @@ -0,0 +1,70 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd' +pretrained_model: "./pretrained/ResNet50_vd_pretrained" +model_save_dir: "./output/" +classes_num: 102 +total_images: 1020 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 20 +topk: 5 +image_shape: [3, 224, 224] + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.00375 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000001 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/flowers102/train_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 4 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/quick_start/ResNet50_vd_ssld_finetune.yaml b/configs/quick_start/ResNet50_vd_ssld_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6bf863414a23dd9de3e3c490e5c09a7e20e01c7 --- /dev/null +++ b/configs/quick_start/ResNet50_vd_ssld_finetune.yaml @@ -0,0 +1,72 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd' + params: + lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3] +pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained" +model_save_dir: "./output/" +classes_num: 102 +total_images: 1020 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 20 +topk: 5 +image_shape: [3, 224, 224] + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.00375 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000001 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/flowers102/train_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 4 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml b/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..629f050eda26edae9efe512980dc7faea28fc39f --- /dev/null +++ b/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd' + params: + lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3] +pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained" +model_save_dir: "./output/" +classes_num: 102 +total_images: 1020 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 20 +topk: 5 +image_shape: [3, 224, 224] + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.00375 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000001 + +TRAIN: + batch_size: 32 + num_workers: 4 + file_list: "./dataset/flowers102/train_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.5 + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 4 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py index ae2ac733536b72b3087bf9001e1408caab020723..34be3ce2addc0a8320e4b3606e1d87f0873d2a47 100644 --- a/ppcls/modeling/architectures/__init__.py +++ b/ppcls/modeling/architectures/__init__.py @@ -44,4 +44,4 @@ from .darts_gs import DARTS_GS_6M, DARTS_GS_4M from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet # distillation model -from .distillation_models import ResNet50_vd_distill_MobileNetV3_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd +from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd diff --git a/ppcls/modeling/architectures/distillation_models.py b/ppcls/modeling/architectures/distillation_models.py index 258627f8059eb9804f5b7cf15f6b44f621029b58..f5f24b36a260f7d816a164dd0a8e86266550b0dc 100644 --- a/ppcls/modeling/architectures/distillation_models.py +++ b/ppcls/modeling/architectures/distillation_models.py @@ -27,12 +27,12 @@ from .mobilenet_v3 import MobileNetV3_large_x1_0 from .resnext101_wsl import ResNeXt101_32x16d_wsl __all__ = [ - 'ResNet50_vd_distill_MobileNetV3_x1_0', + 'ResNet50_vd_distill_MobileNetV3_large_x1_0', 'ResNeXt101_32x16d_wsl_distill_ResNet50_vd' ] -class ResNet50_vd_distill_MobileNetV3_x1_0(): +class ResNet50_vd_distill_MobileNetV3_large_x1_0(): def net(self, input, class_dim=1000): # student student = MobileNetV3_large_x1_0() diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index 986d5ac787c75dc34340c82ce86b0c8e31530664..673e54304b84bd962486e5bdc61b4ddcc1fa511d 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -118,7 +118,10 @@ def init_model(config, program, exe): pretrained_model = config.get('pretrained_model') if pretrained_model: - load_params(exe, program, pretrained_model) + if not isinstance(pretrained_model, list): + pretrained_model = [pretrained_model] + for pretrain in pretrained_model: + load_params(exe, program, pretrain) logger.info("Finish initing model from {}".format(pretrained_model))