diff --git a/PaddleSlim/classification/pruning/README.md b/PaddleSlim/classification/pruning/README.md
index e8d29118968666a420ba7055fac7e2aaaa2ee134..6f8243b513e7123bdd5989dc42baa57e4f7c6c1e 100644
--- a/PaddleSlim/classification/pruning/README.md
+++ b/PaddleSlim/classification/pruning/README.md
@@ -130,7 +130,13 @@ fc10_weights (1280L, 1000L)
 |-30%|- |- |- |-|
 |-50%|- |- |- |-|
 
->训练超参:
+>训练超参
+batch size: 256
+lr_strategy: piecewise_decay
+step_epochs: 30, 60, 90
+num_epochs: 120
+l2_decay: 3e-5
+lr: 0.1
 
 ### MobileNetV2
 
@@ -142,6 +148,12 @@ fc10_weights (1280L, 1000L)
 |-50%|- |- |- |-|
 
 >训练超参:
+batch size: 500
+lr_strategy: cosine_decay
+num_epochs: 240
+l2_decay: 4e-5
+lr: 0.1
+
 
 ### ResNet50
 
@@ -152,6 +164,11 @@ fc10_weights (1280L, 1000L)
 |-30%|- |- |- |-|
 |-50%|- |- |- |-|
 
->训练超参:
+>训练超参
+batch size: 256
+lr_strategy: cosine_decay
+num_epochs: 120
+l2_decay: 1e-4
+lr: 0.1
 
 ## FAQ
diff --git a/PaddleSlim/classification/pruning/compress.py b/PaddleSlim/classification/pruning/compress.py
index b40d8a1c73d80588a063f27c6819043d543f14d7..77f4f83aaf39e2afe7c214d7c558f6992c0218b4 100644
--- a/PaddleSlim/classification/pruning/compress.py
+++ b/PaddleSlim/classification/pruning/compress.py
@@ -4,6 +4,7 @@ import logging
 import paddle
 import argparse
 import functools
+import math
 import paddle.fluid as fluid
 sys.path.append("..")
 import imagenet_reader as reader
@@ -24,12 +25,48 @@ add_arg('batch_size', int, 64*4, "Minibatch size.")
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
 add_arg('model', str, None, "The target model.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
-add_arg('config_file', str, None, "The config file for compression with yaml format.")
+add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
+add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
+add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
+add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
+add_arg('num_epochs', int, 120, "The number of total epochs.")
+add_arg('total_images', int, 1281167, "The number of total training images.")
+parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+add_arg('config_file', str, None, "The config file for compression with yaml format.")
 # yapf: enable
 
 model_list = [m for m in dir(models) if "__" not in m]
 
 
+def piecewise_decay(args):
+    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    bd = [step * e for e in args.step_epochs]
+    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
+    learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=args.momentum_rate,
+        regularization=fluid.regularizer.L2Decay(args.l2_decay))
+    return optimizer
+
+def cosine_decay(args):
+    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    learning_rate = fluid.layers.cosine_decay(
+        learning_rate=args.lr,
+        step_each_epoch=step,
+        epochs=args.num_epochs)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=args.momentum_rate,
+        regularization=fluid.regularizer.L2Decay(args.l2_decay))
+    return optimizer
+
+def create_optimizer(args):
+    if args.lr_strategy == "piecewise_decay":
+        return piecewise_decay(args)
+    elif args.lr_strategy == "cosine_decay":
+        return cosine_decay(args)
+
 def compress(args):
     class_dim=1000
     image_shape="3,224,224"
@@ -45,25 +82,14 @@
     acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
     acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
     val_program = fluid.default_main_program().clone()
-#    for param in fluid.default_main_program().global_block().all_parameters():
-#        print param.name, param.shape
-#    return
-    opt = fluid.optimizer.Momentum(
-        momentum=0.9,
-        learning_rate=fluid.layers.piecewise_decay(
-            boundaries=[5000 * 30, 5000 * 60, 5000 * 90],
-            values=[0.1, 0.01, 0.001, 0.0001]),
-        regularization=fluid.regularizer.L2Decay(4e-5))
-
+    opt = create_optimizer(args)
     place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
     if args.pretrained_model:
-
         def if_exist(var):
             return os.path.exists(os.path.join(args.pretrained_model, var.name))
-
         fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
 
     val_reader = paddle.batch(reader.val(), batch_size=args.batch_size)
diff --git a/PaddleSlim/classification/pruning/configs/mobilenet_v1.yaml b/PaddleSlim/classification/pruning/configs/mobilenet_v1.yaml
index de0b427f95d06d0f84ed958525e23ce82033174b..2aa857c7123b239e2896b2873fbd6adb21d355ac 100644
--- a/PaddleSlim/classification/pruning/configs/mobilenet_v1.yaml
+++ b/PaddleSlim/classification/pruning/configs/mobilenet_v1.yaml
@@ -14,7 +14,7 @@ strategies:
         target_ratio: 0.5
         pruned_params: '.*_sep_weights'
 compressor:
-    epoch: 3
+    epoch: 121
     checkpoint_path: './checkpoints/mobilenet_v1/'
     strategies:
         - uniform_pruning_strategy
diff --git a/PaddleSlim/classification/pruning/configs/mobilenet_v2.yaml b/PaddleSlim/classification/pruning/configs/mobilenet_v2.yaml
index 4df7edccab91e5be99f75e48f20cadd953533c8a..8fb8e16d75cfcccf5590f2d46f332726dc34376e 100644
--- a/PaddleSlim/classification/pruning/configs/mobilenet_v2.yaml
+++ b/PaddleSlim/classification/pruning/configs/mobilenet_v2.yaml
@@ -16,7 +16,7 @@ strategies:
 #        pruned_params: '.*linear_weights'
 #        pruned_params: '.*expand_weights'
 compressor:
-    epoch: 2
+    epoch: 241
     checkpoint_path: './checkpoints/'
     strategies:
         - uniform_pruning_strategy
diff --git a/PaddleSlim/classification/pruning/configs/resnet50.yaml b/PaddleSlim/classification/pruning/configs/resnet34.yaml
similarity index 87%
rename from PaddleSlim/classification/pruning/configs/resnet50.yaml
rename to PaddleSlim/classification/pruning/configs/resnet34.yaml
index 2519c4335611346e9f080b989ebcf503ca23b8b8..ba7d1a4f9d9df80cd47c96eca90b05ac1cd2754e 100644
--- a/PaddleSlim/classification/pruning/configs/resnet50.yaml
+++ b/PaddleSlim/classification/pruning/configs/resnet34.yaml
@@ -14,7 +14,7 @@ strategies:
         target_ratio: 0.5
         pruned_params: '.*branch.*_weights'
 compressor:
-    epoch: 4
-    checkpoint_path: './checkpoints/resnet50/'
+    epoch: 121
+    checkpoint_path: './checkpoints/resnet34/'
     strategies:
         - uniform_pruning_strategy
diff --git a/PaddleSlim/classification/pruning/run.sh b/PaddleSlim/classification/pruning/run.sh
index 9549ea8c17c9d303f9293d30c3c0ec46bf7685d5..db02e362bd8b959c36b8eb4ab30a1de5501c5c28 100644
--- a/PaddleSlim/classification/pruning/run.sh
+++ b/PaddleSlim/classification/pruning/run.sh
@@ -6,7 +6,7 @@ export CUDA_VISIBLE_DEVICES=0
 root_url="http://paddle-imagenet-models-name.bj.bcebos.com"
 MobileNetV1="MobileNetV1_pretrained.tar"
 MobileNetV2="MobileNetV2_pretrained.tar"
-ResNet50="ResNet50_pretrained.tar"
+ResNet34="ResNet34_pretrained.tar"
 pretrain_dir='../pretrain'
 
 if [ ! -d ${pretrain_dir} ]; then
@@ -25,9 +25,9 @@ if [ ! -f ${MobileNetV2} ]; then
     tar xf ${MobileNetV2}
 fi
 
-if [ ! -f ${ResNet50} ]; then
-    wget ${root_url}/${ResNet50}
-    tar xf ${ResNet50}
+if [ ! -f ${ResNet34} ]; then
+    wget ${root_url}/${ResNet34}
+    tar xf ${ResNet34}
 fi
 
 cd -
@@ -36,6 +36,11 @@ nohup python -u compress.py \
 --model "MobileNet" \
 --use_gpu 1 \
 --batch_size 256 \
+--total_images 1281167 \
+--lr_strategy "piecewise_decay" \
+--num_epochs 120 \
+--lr 0.1 \
+--l2_decay 3e-5 \
 --pretrained_model ../pretrain/MobileNetV1_pretrained \
 --config_file "./configs/mobilenet_v1.yaml" \
 > mobilenet_v1.log 2>&1 &
@@ -46,18 +51,28 @@ tailf mobilenet_v1.log
 #--model "MobileNetV2" \
 #--use_gpu 1 \
 #--batch_size 256 \
+#--total_images 1281167 \
+#--lr_strategy "cosine_decay" \
+#--num_epochs 240 \
+#--lr 0.1 \
+#--l2_decay 4e-5 \
 #--pretrained_model ../pretrain/MobileNetV2_pretrained \
 #--config_file "./configs/mobilenet_v2.yaml" \
 #> mobilenet_v2.log 2>&1 &
 #tailf mobilenet_v2.log
 
 
-## for compression of resnet50
+## for compression of resnet34
 #python -u compress.py \
-#--model "ResNet50" \
+#--model "ResNet34" \
 #--use_gpu 1 \
 #--batch_size 256 \
-#--pretrained_model ../pretrain/ResNet50_pretrained \
-#--config_file "./configs/resnet50.yaml" \
-#> resnet50.log 2>&1 &
-#tailf resnet50.log
+#--total_images 1281167 \
+#--lr_strategy "cosine_decay" \
+#--lr 0.1 \
+#--num_epochs 120 \
+#--l2_decay 1e-4 \
+#--pretrained_model ../pretrain/ResNet34_pretrained \
+#--config_file "./configs/resnet34.yaml" \
+#> resnet34.log 2>&1 &
+#tailf resnet34.log
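
Note: the patch replaces the previously hard-coded schedule (boundaries=[5000 * 30, 5000 * 60, 5000 * 90]) with one derived from --total_images, --batch_size and --step_epochs. A minimal sketch, not part of the patch, of the arithmetic the new piecewise_decay() helper performs with the default flag values used above:

import math

# Sketch only: mirrors the boundary/value computation in the piecewise_decay()
# helper added to compress.py, using the patch's default flag values.
total_images, batch_size = 1281167, 256
step_epochs, base_lr = [30, 60, 90], 0.1

step = int(math.ceil(float(total_images) / batch_size))  # 5005 iterations per epoch
bd = [step * e for e in step_epochs]                      # [150150, 300300, 450450]
lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]   # ~[0.1, 0.01, 0.001, 0.0001]
print(step, bd, lr)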