train.py 7.7 KB
Newer Older
W
whs 已提交
1 2 3 4 5 6 7 8 9 10 11
from __future__ import division
from __future__ import print_function
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
12 13
# Make the directory two levels above this script importable. The path is
# anchored on the script's own location (the real ``__file__``, not the string
# literal "__file__") so the lookup does not depend on the current working
# directory.
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)), os.path.pardir,
        os.path.pardir))
W
whs 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
import paddleslim
from paddleslim.common import get_logger
from paddleslim.analysis import dygraph_flops as flops
import paddle.vision.models as models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
from paddle.static import InputSpec as Input
from imagenet import ImageNetDataset
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
from paddle.distributed import ParallelEnv

_logger = get_logger(__name__, level=logging.INFO)

# Command-line options. `add_arg` forwards to the project helper
# `add_arguments` with `parser` bound as the target argparser.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64 * 4,                 "Minibatch size.")
# The defaults of `model` and `data` must be values the rest of this script
# accepts: get_pruned_params() handles mobilenet_v1 / mobilenet_v2 / resnet34
# and compress() handles cifar10 / imagenet.
add_arg('model',            str,  "mobilenet_v1",                "The target model. 'mobilenet_v1', 'mobilenet_v2' or 'resnet34'.")
add_arg('lr',               float,  0.1,               "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy.")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,  120,               "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data',             str, "cifar10",                 "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',       int, 10,                 "Log period in batches.")
add_arg('test_period',      int, 10,                 "Test period in epoches.")
add_arg('model_path',       str, "./models",         "The path to save model.")
add_arg('pruned_ratio',     float, None,         "The ratios to be pruned.")
# compress() builds a pruner only for 'l1_norm' and 'fpgm'.
add_arg('criterion',        str, "l1_norm",         "The prune criterion to be used, support l1_norm and fpgm.")
add_arg('use_gpu',   bool, True,                "Whether to use GPU.")
add_arg('checkpoint',   str, None,                "The path of checkpoint which is used for resume training.")
# yapf: enable

# Names exported by paddle.vision.models; compress() checks --model against it.
model_list = models.__all__


def get_pruned_params(args, model):
    """Collect the names of the parameters that should be pruned.

    Only parameters owned directly by ``paddle.nn.Conv2D`` sublayers are
    considered. For ``mobilenet_v1`` the selection is narrowed further to
    convolutions whose ``_groups`` is 1 and whose parameter name is not in
    the skip list; for ``mobilenet_v2`` and ``resnet34`` every Conv2D
    parameter is selected.

    Args:
        args: Namespace providing ``model``, the architecture name.
        model: Network exposing ``sublayers()``.

    Returns:
        list: Parameter names, in sublayer traversal order.

    Raises:
        NotImplementedError: If ``args.model`` is not one of the three
            supported architectures.
    """
    arch = args.model
    if arch not in ("mobilenet_v1", "mobilenet_v2", "resnet34"):
        raise NotImplementedError(
            "Current demo only support for mobilenet_v1, mobilenet_v2, resnet34")

    restrict = arch == "mobilenet_v1"
    # skip the first conv2d and last linear (mobilenet_v1 only)
    excluded = ('linear_0.b_0', 'conv2d_0.w_0') if restrict else ()

    selected = []
    for layer in model.sublayers():
        if not isinstance(layer, paddle.nn.Conv2D):
            continue
        if restrict and not layer._groups == 1:
            continue
        selected.extend(
            weight.name
            for weight in layer.parameters(include_sublayers=False)
            if weight.name not in excluded)
    return selected


def piecewise_decay(args, parameters, steps_per_epoch):
    """Create a Momentum optimizer driven by a piecewise-constant LR schedule.

    The schedule boundaries are ``steps_per_epoch * e`` for every ``e`` in
    ``args.step_epochs``; the learning rate starts at ``args.lr`` and is
    multiplied by 0.1 for each successive interval.

    Args:
        args: Namespace providing ``step_epochs``, ``lr``, ``momentum_rate``
            and ``l2_decay``.
        parameters: Parameters handed to the optimizer.
        steps_per_epoch: Multiplier that turns each entry of
            ``args.step_epochs`` into a schedule boundary.

    Returns:
        The configured ``paddle.optimizer.Momentum`` instance.
    """
    boundaries = []
    for epoch in args.step_epochs:
        boundaries.append(steps_per_epoch * epoch)

    # One more value than boundaries: the rate used before the first boundary.
    values = []
    for power in range(len(boundaries) + 1):
        values.append(args.lr * (0.1**power))

    schedule = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=values)
    return paddle.optimizer.Momentum(
        learning_rate=schedule,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
        parameters=parameters)


def cosine_decay(args, parameters, steps_per_epoch):
    """Create a Momentum optimizer driven by a cosine-annealing LR schedule.

    The schedule starts from ``args.lr`` and uses
    ``args.num_epochs * steps_per_epoch`` as ``T_max``.

    Args:
        args: Namespace providing ``lr``, ``num_epochs``, ``momentum_rate``
            and ``l2_decay``.
        parameters: Parameters handed to the optimizer.
        steps_per_epoch: Multiplied by ``args.num_epochs`` to obtain
            ``T_max``.

    Returns:
        The configured ``paddle.optimizer.Momentum`` instance.
    """
    total_steps = args.num_epochs * steps_per_epoch
    schedule = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=total_steps)
    regularizer = paddle.regularizer.L2Decay(args.l2_decay)
    return paddle.optimizer.Momentum(
        learning_rate=schedule,
        momentum=args.momentum_rate,
        weight_decay=regularizer,
        parameters=parameters)


def create_optimizer(args, parameters, steps_per_epoch):
    """Build the optimizer selected by ``args.lr_strategy``.

    Args:
        args: Namespace providing ``lr_strategy`` plus the fields read by the
            chosen builder (``piecewise_decay`` or ``cosine_decay``).
        parameters: Parameters handed to the optimizer.
        steps_per_epoch: Forwarded unchanged to the chosen builder.

    Returns:
        The optimizer returned by ``piecewise_decay`` or ``cosine_decay``.

    Raises:
        ValueError: If ``args.lr_strategy`` is neither ``"piecewise_decay"``
            nor ``"cosine_decay"``.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args, parameters, steps_per_epoch)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args, parameters, steps_per_epoch)
    # Fail here instead of returning None, which would only surface later as
    # an obscure error when the optimizer is first used.
    raise ValueError(
        "lr_strategy '{}' is not supported; use 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))


def compress(args):
    """Prune a ``paddle.vision`` model and fine-tune it with ``paddle.Model``.

    The function selects the device, builds the train/validation datasets for
    ``args.data``, instantiates ``args.model``, prunes the parameters returned
    by ``get_pruned_params`` with the pruner chosen by ``args.criterion``, and
    finally trains the pruned network with ``paddle.Model.fit``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``criterion``, ``pruned_ratio``, ``use_gpu``, ``data``, ``model``,
            ``batch_size``, ``checkpoint``, ``num_epochs`` and ``model_path``;
            ``get_pruned_params`` and ``create_optimizer`` read further
            fields.

    Raises:
        ValueError: If ``args.criterion`` is not ``'fpgm'`` or ``'l1_norm'``,
            if ``args.pruned_ratio`` is ``None``, or if ``args.data`` is not
            ``'cifar10'`` or ``'imagenet'``.
        AssertionError: If ``args.model`` is not in ``model_list``.
    """
    # Validate the pruning options before any expensive work (device setup,
    # dataset and model construction). Previously an unsupported criterion
    # left `pruner` unbound and only failed after the model had been built.
    if args.criterion not in ('fpgm', 'l1_norm'):
        raise ValueError(
            "criterion '{}' is not supported; use 'fpgm' or 'l1_norm'.".format(
                args.criterion))
    if args.pruned_ratio is None:
        raise ValueError(
            "pruned_ratio must be set to the ratio of filters to prune.")

    paddle.set_device('gpu' if args.use_gpu else 'cpu')
    if args.data == "cifar10":

        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])

        train_dataset = paddle.vision.datasets.Cifar10(
            mode="train", backend="cv2", transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(
            mode="test", backend="cv2", transform=transform)
        class_dim = 10
        image_shape = [3, 32, 32]
        pretrain = False
    elif args.data == "imagenet":

        train_dataset = ImageNetDataset(
            "data/ILSVRC2012",
            mode='train',
            image_size=224,
            resize_short_size=256)

        val_dataset = ImageNetDataset(
            "data/ILSVRC2012",
            mode='val',
            image_size=224,
            resize_short_size=256)

        class_dim = 1000
        image_shape = [3, 224, 224]
        pretrain = True
    else:
        raise ValueError("{} is not supported.".format(args.data))
    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    inputs = [Input([None] + image_shape, 'float32', name='image')]
    labels = [Input([None, 1], 'int64', name='label')]

    # model definition
    net = models.__dict__[args.model](pretrained=pretrain,
                                      num_classes=class_dim)

    # NOTE(review): the logged value is the flops() result divided by 1000 but
    # is labelled GFLOPs -- confirm the unit returned by dygraph_flops.
    _logger.info("FLOPs before pruning: {}GFLOPs".format(
        flops(net, [1] + image_shape) / 1000))
    net.eval()
    if args.criterion == 'fpgm':
        pruner = paddleslim.dygraph.FPGMFilterPruner(net, [1] + image_shape)
    else:  # 'l1_norm', guaranteed by the validation at the top
        pruner = paddleslim.dygraph.L1NormFilterPruner(net, [1] + image_shape)

    # Apply the same ratio to every selected parameter.
    params = get_pruned_params(args, net)
    ratios = {}
    for param in params:
        ratios[param] = args.pruned_ratio
    # The second argument is presumably the pruning axis -- confirm against
    # the PaddleSlim pruner API.
    plan = pruner.prune_vars(ratios, [0])

    _logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
        flops(net, [1] + image_shape) / 1000, plan.pruned_flops))

    # Print the shapes of the convolution parameters after pruning.
    for param in net.parameters():
        if "conv2d" in param.name:
            print(f"{param.name}\t{param.shape}")

    net.train()
    model = paddle.Model(net, inputs, labels)
    steps_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
    opt = create_optimizer(args, net.parameters(), steps_per_epoch)
    model.prepare(
        opt, paddle.nn.CrossEntropyLoss(), paddle.metric.Accuracy(topk=(1, 5)))
    # Resume from a checkpoint when one is given.
    if args.checkpoint is not None:
        model.load(args.checkpoint)
    # args.batch_size is split evenly across the ranks reported by ParallelEnv.
    model.fit(train_data=train_dataset,
              eval_data=val_dataset,
              epochs=args.num_epochs,
              batch_size=args.batch_size // ParallelEnv().nranks,
              verbose=1,
              save_dir=args.model_path,
              num_workers=8)


def main():
    """Parse the command-line options, echo them, and run ``compress``."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()