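# Train image classification models (MNIST / ImageNet) with PaddleSlim
# unstructured (sparse) pruning, using either one-shot 'base' pruning or
# Gradual Magnitude Pruning ('gmp'), optionally data-parallel via Fleet.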
import os
import sys
import logging
import paddle
import argparse
import functools
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
from paddleslim.common import get_logger
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
import models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64,                 "Minibatch size. Default: 64")
add_arg('batch_size_for_validation',       int,  64,                 "Minibatch size for validation. Default: 64")
add_arg('model',            str,  "MobileNet",                "The target model.")
add_arg('pretrained_model', str,  None,                "The path of the pretrained model to load. Default: None")
add_arg('checkpoint',       str, None, "The checkpoint to load for resuming training. Default: None")
add_arg('lr',               float,  0.1,               "The learning rate used to fine-tune pruned model. Default: 0.1")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy. Default: piecewise_decay")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter. Default: 3e-5")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate. Default: 0.9")
add_arg('pruning_strategy', str,    'base',            "The pruning strategy; currently we support 'base' and 'gmp'. Default: base")
add_arg('threshold',        float,  0.01,               "The magnitude threshold for zeroing weights: weights whose absolute value is below it are set to zero. Default: 0.01")
add_arg('pruning_mode',            str,  'ratio',               "The pruning mode: prune by 'ratio' or by 'threshold'. Default: ratio")
add_arg('ratio',            float,  0.55,               "The ratio of weights to zero out: the smallest-magnitude portion is set to zero. Default: 0.55")
add_arg('num_epochs',       int,  120,               "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data',             str, "imagenet",                 "Which dataset to use: 'mnist' or 'imagenet'. Default: imagenet")
add_arg('log_period',       int, 100,                 "Log period in batches. Default: 100")
add_arg('test_period',      int, 5,                 "Test period in epochs. Default: 5")
add_arg('model_path',       str, "./models",         "The path to save model. Default: ./models")
add_arg('model_period',     int, 10,             "The period to save model in epochs. Default: 10")
add_arg('last_epoch',     int, -1,             "The last finished epoch to resume training from. Default: -1")
add_arg('stable_epochs',    int, 0,              "The number of epochs used to stabilize the model before pruning. Default: 0")
add_arg('pruning_epochs',   int, 60,             "The number of epochs during which the pruning ratio is increased step by step. Default: 60")
add_arg('tunning_epochs',   int, 60,             "The number of epochs used to fine-tune the pruned model. Default: 60")
add_arg('pruning_steps',    int, 120,        "The number of times the pruning ratio is increased during the pruning stage. Default: 120")
add_arg('initial_ratio',    float, 0.15,         "The initial pruning ratio used at the start of the pruning stage. Default: 0.15")
add_arg('prune_params_type', str, None,           "Which kind of params should be pruned; only None (all except norm layers) and 'conv1x1_only' are supported for now. Default: None")
# yapf: enable

model_list = models.__all__


def piecewise_decay(args, step_per_epoch):
    bd = [step_per_epoch * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
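    # When resuming (last_epoch > -1), advance the schedule past the
    # iterations already trained; the schedule is stepped once per batch.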
    last_iter = (1 + args.last_epoch) * step_per_epoch
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, last_epoch=last_iter)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return optimizer, learning_rate


def cosine_decay(args, step_per_epoch):
    last_iter = (1 + args.last_epoch) * step_per_epoch
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr,
        T_max=args.num_epochs * step_per_epoch,
        last_epoch=last_iter)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return optimizer, learning_rate


def create_optimizer(args, step_per_epoch):
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args, step_per_epoch)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args, step_per_epoch)
    else:
        raise ValueError(
            "Unsupported lr_strategy: {}".format(args.lr_strategy))


def create_unstructured_pruner(train_program, args, place, configs):
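    """Return a pruner: one-shot when `configs` is None, otherwise GMP.

    The base UnstructuredPruner prunes by a fixed ratio or magnitude
    threshold; GMPUnstructuredPruner gradually raises the ratio following
    the stable/pruning/tuning schedule described in `configs`.
    """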
    if configs is None:
        return UnstructuredPruner(
            train_program,
            mode=args.pruning_mode,
            ratio=args.ratio,
            threshold=args.threshold,
            prune_params_type=args.prune_params_type,
            place=place)
    else:
        return GMPUnstructuredPruner(
            train_program,
            ratio=args.ratio,
            prune_params_type=args.prune_params_type,
            place=place,
            configs=configs)


def compress(args):
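    """Build the static-graph program, attach the pruner, then train/eval."""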
    env = os.environ
    num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
    use_data_parallel = num_trainers > 1

    if use_data_parallel:
        # Fleet step 1: initialize the distributed environment
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)

    train_reader = None
    test_reader = None
    if args.data == "mnist":
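        # Transpose HWC -> CHW, then scale pixel values to roughly [-1, 1].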
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.MNIST(
            mode='train', backend="cv2", transform=transform)
        val_dataset = paddle.vision.datasets.MNIST(
            mode='test', backend="cv2", transform=transform)
        class_dim = 10
        image_shape = "1,28,28"
        args.pretrained_model = False
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    places = paddle.static.cuda_places()
    place = places[0]
    exe = paddle.static.Executor(place)

    image = paddle.static.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

    batch_size_per_card = args.batch_size
    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset,
        batch_size=batch_size_per_card,
        shuffle=True,
        drop_last=True)

    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=place,
        batch_sampler=batch_sampler,
        feed_list=[image, label],
        return_list=False,
        use_shared_memory=True,
        num_workers=32)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        feed_list=[image, label],
        drop_last=False,
        return_list=False,
        use_shared_memory=True,
        batch_size=args.batch_size_for_validation,
        shuffle=False)

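    # Steps per epoch as seen by one trainer: the DistributedBatchSampler
    # shards the dataset evenly across num_trainers cards.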
    step_per_epoch = int(
        np.ceil(len(train_dataset) * 1. / args.batch_size / num_trainers))

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
    avg_cost = paddle.mean(x=cost)
    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

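    # Clone for evaluation before optimizer ops are appended, so val_program
    # contains only the forward graph.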
    val_program = paddle.static.default_main_program().clone(for_test=True)

    opt, learning_rate = create_optimizer(args, step_per_epoch)

    # Fleet step 2: distributed strategy
    if use_data_parallel:
        dist_strategy = DistributedStrategy()
        dist_strategy.sync_batch_norm = False
        dist_strategy.exec_strategy = paddle.static.ExecutionStrategy()
        dist_strategy.fuse_all_reduce_ops = False

    train_program = paddle.static.default_main_program()

    if args.pruning_strategy == 'gmp':
        # GMP pruner step 0: define configs for GMP, no need to define configs for the base training.
        configs = {
            'stable_iterations': args.stable_epochs * step_per_epoch,
            'pruning_iterations': args.pruning_epochs * step_per_epoch,
            'tunning_iterations': args.tunning_epochs * step_per_epoch,
            'resume_iteration': (args.last_epoch + 1) * step_per_epoch,
            'pruning_steps': args.pruning_steps,
            'initial_ratio': args.initial_ratio,
        }
    elif args.pruning_strategy == 'base':
        configs = None
    else:
        raise ValueError("Unsupported pruning strategy: {}".format(
            args.pruning_strategy))

    # GMP pruner step 1: initialize a pruner object by calling entry function.
    pruner = create_unstructured_pruner(
        train_program, args, place, configs=configs)

    if use_data_parallel:
        # Fleet step 3: decorate the original optimizer and minimize it
        opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
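    # no_grad_set excludes the pruner's auxiliary (mask) variables from the
    # optimizer update, so only the model weights receive gradients.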
    opt.minimize(avg_cost, no_grad_set=pruner.no_grad_set)

    exe.run(paddle.static.default_startup_program())
    if args.last_epoch > -1:
        assert args.checkpoint is not None and os.path.exists(
            args.checkpoint), "Please specify a valid checkpoint path."
        paddle.fluid.io.load_persistables(
            executor=exe, dirname=args.checkpoint, main_program=train_program)

    elif args.pretrained_model:
        assert os.path.exists(
            args.
            pretrained_model), "Pretrained model path {} doesn't exist".format(
                args.pretrained_model)

        def if_exist(var):
            return os.path.exists(os.path.join(args.pretrained_model, var.name))

        _logger.info("Load pretrained model from {}".format(
            args.pretrained_model))
        # NOTE: We are using fluid.io.load_vars() because the pretrained model
        # is from an older version which requires this API.
        # Please consider using paddle.static.load(program, model_path) when possible.
        paddle.fluid.io.load_vars(
            exe, args.pretrained_model, predicate=if_exist)

    def test(epoch, program):
        acc_top1_ns = []
        acc_top5_ns = []

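        # total_sparse() reports the fraction of zero-valued weights in the
        # program, i.e. the current sparsity.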
        _logger.info(
            "The current sparsity of the inference model is {}%".format(
                round(100 * UnstructuredPruner.total_sparse(
                    paddle.static.default_main_program()), 2)))
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch,
            np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))

    def train(epoch, program):
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()
        for batch_id, data in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            # GMP pruner step 2: step() to update ratios and other internal states of the pruner.
            pruner.step()

            train_run_cost += time.time() - train_start
            total_samples += args.batch_size
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch, batch_id,
                           learning_rate.get_lr(), loss_n, acc_top1_n,
                           acc_top5_n, train_reader_cost / args.log_period, (
                               train_reader_cost + train_run_cost
                           ) / args.log_period, total_samples / args.log_period,
                           total_samples / (train_reader_cost + train_run_cost
                                            )))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
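            # The LR schedule is stepped per batch; epoch boundaries were
            # converted to iteration counts when the schedule was built.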
            learning_rate.step()
            reader_start = time.time()

    if use_data_parallel:
        # Fleet step 4: get the compiled program from fleet
        compiled_train_program = fleet.main_program
    else:
        compiled_train_program = paddle.static.CompiledProgram(
            paddle.static.default_main_program())

    for i in range(args.last_epoch + 1, args.num_epochs):
        train(i, compiled_train_program)
        # GMP pruner step 3: update params before summarizing sparsity, saving model or evaluation.
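        # update_params() writes the pruning masks back into the weight
        # tensors, so the zeros are materialized before being measured,
        # evaluated or saved.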
        pruner.update_params()

        _logger.info("The current sparsity of the pruned model is: {}%".format(
            round(100 * UnstructuredPruner.total_sparse(
                paddle.static.default_main_program()), 2)))

        if (i + 1) % args.test_period == 0:
            test(i, val_program)
        if (i + 1) % args.model_period == 0:
            if use_data_parallel:
                fleet.save_persistables(executor=exe, dirname=args.model_path)
            else:
                paddle.fluid.io.save_persistables(
                    executor=exe, dirname=args.model_path)


def main():
    paddle.enable_static()
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()
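
# Typical invocations (illustrative only; adjust flags and paths to your setup):
#   Single card:
#     python train.py --data mnist --lr 0.05 --pruning_mode ratio --ratio 0.55
#   Multi card with Fleet:
#     python -m paddle.distributed.launch --gpus="0,1,2,3" train.py \
#         --data imagenet --lr 0.05 --pruning_strategy gmp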