train.py 11.2 KB
Newer Older
1 2 3 4 5 6 7
import os
import sys
import logging
import paddle
import argparse
import functools
import math
W
whs 已提交
8
import random
9 10
import time
import numpy as np
W
whs 已提交
11
# Point the first module-search entry at the directory above this script,
# presumably so the shared demo modules imported below (``models``,
# ``utility``, ``imagenet_reader``) can be found. The path is anchored on the
# script's real location (``__file__``) so it does not depend on the current
# working directory.
sys.path[0] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), os.path.pardir)
W
whs 已提交
12
from paddleslim.prune import Pruner, save_model
13 14 15 16
from paddleslim.common import get_logger
from paddleslim.analysis import flops
import models
from utility import add_arguments, print_arguments
W
whs 已提交
17
import paddle.vision.transforms as T
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34

_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface. ``add_arg(name, type, default, help)`` registers
# ``--name`` on ``parser`` through the project helper ``add_arguments``.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,   64 * 4,            "Minibatch size.")
add_arg('use_gpu',          bool,  True,              "Whether to use GPU or not.")
add_arg('model',            str,   "MobileNet",       "The target model.")
add_arg('pretrained_model', str,   "../pretrained_model/MobileNetV1_pretained", "The path of the pretrained model to load (only used with --data imagenet).")
add_arg('lr',               float, 0.1,               "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy',      str,   "piecewise_decay", "The learning rate decay strategy: 'piecewise_decay' or 'cosine_decay'.")
add_arg('l2_decay',         float, 3e-5,              "The l2_decay parameter.")
add_arg('momentum_rate',    float, 0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,   120,               "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="Epochs at which piecewise_decay multiplies the learning rate by 0.1.")
add_arg('config_file',      str,   None,              "The config file for compression with yaml format.")
add_arg('data',             str,   "cifar10",         "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',       int,   10,                "Log period in batches.")
add_arg('test_period',      int,   10,                "Test period in epochs.")
add_arg('model_path',       str,   "./models",        "The path to save model.")
add_arg('pruned_ratio',     float, None,              "The ratio to prune from each selected parameter.")
add_arg('criterion',        str,   "l1_norm",         "The prune criterion to be used, support l1_norm and batch_norm_scale.")
add_arg('save_inference',   bool,  False,             "Whether to save inference model.")
add_arg('ce_test',          bool,  False,             "Whether to CE test.")
# yapf: enable

# Model names exported by the ``models`` package; ``--model`` must be one of
# these.
model_list = models.__all__
46 47


W
whs 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
def get_pruned_params(args, program):
    """Return the names of the parameters to prune for ``args.model``.

    A parameter is selected when its name matches the model-specific rule
    below. Models without a rule yield an empty list.

    Args:
        args: namespace with a ``model`` attribute (the model name).
        program: object exposing ``global_block().all_parameters()``, whose
            items carry a ``name`` attribute.

    Returns:
        list of str: matching parameter names, in program order.
    """
    name_rules = {
        "MobileNet": lambda name: "_sep_weights" in name,
        "MobileNetV2": lambda name: ("linear_weights" in name or
                                     "expand_weights" in name),
        "ResNet34": lambda name: "weights" in name and "branch" in name,
        "PVANet": lambda name: "conv_weights" in name,
    }
    selects = name_rules.get(args.model)
    if selects is None:
        # Unknown model: nothing is selected and the program is not inspected.
        return []
    return [
        param.name
        for param in program.global_block().all_parameters()
        if selects(param.name)
    ]


Y
yukavio 已提交
69 70
def piecewise_decay(args, step_per_epoch):
    """Create a Momentum optimizer driven by a step-wise learning-rate schedule.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at each epoch in
    ``args.step_epochs``. Those epochs are converted to iteration counts via
    ``step_per_epoch``. Weight decay is L2 with coefficient ``args.l2_decay``.

    Returns:
        tuple: ``(optimizer, learning_rate)``, where ``learning_rate`` is the
        ``PiecewiseDecay`` scheduler attached to ``optimizer``.
    """
    boundaries = [step_per_epoch * epoch for epoch in args.step_epochs]
    # One value per interval: len(boundaries) + 1 stages, each 10x smaller
    # than the previous one.
    values = []
    for stage in range(len(boundaries) + 1):
        values.append(args.lr * (0.1**stage))
    scheduler = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=values)
    regularizer = paddle.regularizer.L2Decay(args.l2_decay)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=regularizer)
    return optimizer, scheduler
79 80


Y
yukavio 已提交
81
def cosine_decay(args, step_per_epoch):
    """Create a Momentum optimizer driven by a cosine-annealing learning rate.

    The schedule starts at ``args.lr`` with ``T_max`` equal to the total number
    of training iterations (``args.num_epochs * step_per_epoch``). Weight decay
    is L2 with coefficient ``args.l2_decay``.

    Returns:
        tuple: ``(optimizer, learning_rate)``, where ``learning_rate`` is the
        ``CosineAnnealingDecay`` scheduler attached to ``optimizer``.
    """
    total_steps = args.num_epochs * step_per_epoch
    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=total_steps)
    regularizer = paddle.regularizer.L2Decay(args.l2_decay)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=regularizer)
    return optimizer, scheduler
89 90


Y
yukavio 已提交
91
def create_optimizer(args, step_per_epoch):
    """Build the fine-tuning optimizer and LR scheduler for ``args.lr_strategy``.

    Args:
        args: namespace providing ``lr_strategy`` plus the fields read by the
            chosen schedule builder (``piecewise_decay`` or ``cosine_decay``).
        step_per_epoch: number of training iterations per epoch, used to turn
            epoch-based settings into iteration counts.

    Returns:
        tuple: ``(optimizer, learning_rate)`` from the selected builder.

    Raises:
        ValueError: if ``args.lr_strategy`` is neither ``"piecewise_decay"``
            nor ``"cosine_decay"``.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args, step_per_epoch)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args, step_per_epoch)
    # Fail here with a clear message; falling through would return None and
    # break the caller's ``opt, learning_rate = ...`` unpacking obscurely.
    raise ValueError(
        "Unsupported lr_strategy '{}'; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))
96 97 98


def compress(args):
    """Evaluate, prune and fine-tune a classification model.

    Steps, in order:
      1. Build the train/val datasets for ``args.data`` ("cifar10" or
         "imagenet").
      2. Build the static-graph network ``args.model`` with loss, top-1/top-5
         accuracy and the optimizer from ``create_optimizer``.
      3. Load ``args.pretrained_model`` (ImageNet only) and evaluate the
         unpruned model.
      4. Prune the parameters chosen by ``get_pruned_params`` by
         ``args.pruned_ratio`` using ``args.criterion``.
      5. Fine-tune for ``args.num_epochs`` epochs.

    During fine-tuning, every ``args.test_period`` epochs the pruned model is
    evaluated and saved under ``args.model_path``. When ``args.save_inference``
    is set, an inference model is also saved after every epoch.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.data`` is unsupported, or if there are
            parameters to prune but ``args.pruned_ratio`` is not set.
        AssertionError: if ``args.model`` is not in ``model_list``.
    """
    num_workers = 4
    shuffle = True
    if args.ce_test:
        # Reproducible continuous-evaluation run: fix every RNG seed and load
        # data in the main process in a fixed order.
        seed = 111
        paddle.seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        num_workers = 0
        shuffle = False

    need_pretrain = True
    if args.data == "cifar10":
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.Cifar10(
            mode="train", backend="cv2", transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(
            mode="test", backend="cv2", transform=transform)
        class_dim = 10
        image_shape = "3, 32, 32"
        # CIFAR-10 runs never load pretrained weights.
        need_pretrain = False
    elif args.data == "imagenet":
        # Imported lazily so CIFAR-10 runs do not need the ImageNet reader.
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]

    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    place = places[0]
    exe = paddle.static.Executor(place)
    image = paddle.static.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    # args.batch_size is split evenly across the available devices.
    batch_size_per_card = args.batch_size // len(places)
    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=places,
        feed_list=[image, label],
        drop_last=True,
        batch_size=batch_size_per_card,
        shuffle=shuffle,
        return_list=False,
        use_shared_memory=True,
        num_workers=num_workers)
    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        feed_list=[image, label],
        drop_last=False,
        return_list=False,
        use_shared_memory=True,
        batch_size=batch_size_per_card,
        shuffle=False)
    step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    label = paddle.reshape(label, [-1, 1])
    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
    avg_cost = paddle.mean(x=cost)
    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
    # The evaluation program is cloned before minimize() adds the backward
    # and optimizer ops to the main program.
    val_program = paddle.static.default_main_program().clone(for_test=True)
    opt, learning_rate = create_optimizer(args, step_per_epoch)
    opt.minimize(avg_cost)

    exe.run(paddle.static.default_startup_program())

    if need_pretrain and args.pretrained_model:
        _logger.info("Load pretrained model from {}".format(
            args.pretrained_model))
        paddle.static.load(paddle.static.default_main_program(),
                           args.pretrained_model, exe)

    def test(epoch, program):
        """Run ``program`` over the validation set and log its accuracy.

        The final figures are the mean of the per-batch accuracies.
        """
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch,
            np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))

    def train(epoch, program):
        """Run one training epoch of ``program``, stepping the LR every batch."""
        for batch_id, data in enumerate(train_loader):
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           learning_rate.get_lr(), loss_n, acc_top1_n,
                           acc_top5_n, end_time - start_time))
            learning_rate.step()

    # Choose the parameters to prune before the baseline evaluation, so a
    # missing --pruned_ratio is reported before that (possibly long) pass.
    params = get_pruned_params(args, paddle.static.default_main_program())
    if not params:
        _logger.warning(
            "No prunable parameters found for model {}; the program will be "
            "fine-tuned without pruning.".format(args.model))
    elif args.pruned_ratio is None:
        raise ValueError(
            "--pruned_ratio must be set to prune model {}.".format(args.model))

    test(0, val_program)
    _logger.info("FLOPs before pruning: {}".format(
        flops(paddle.static.default_main_program())))
    pruner = Pruner(args.criterion)
    ratios = [args.pruned_ratio] * len(params)
    # Prune the evaluation graph's structure only (only_graph=True). The
    # training program prune below is run without that flag.
    pruned_val_program, _, _ = pruner.prune(
        val_program,
        paddle.static.global_scope(),
        params=params,
        ratios=ratios,
        place=place,
        only_graph=True)

    pruned_program, _, _ = pruner.prune(
        paddle.static.default_main_program(),
        paddle.static.global_scope(),
        params=params,
        ratios=ratios,
        place=place)
    _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))

    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()
    train_program = paddle.static.CompiledProgram(
        pruned_program).with_data_parallel(
            loss_name=avg_cost.name,
            build_strategy=build_strategy,
            exec_strategy=exec_strategy)

    for i in range(args.num_epochs):
        train(i, train_program)
        if (i + 1) % args.test_period == 0:
            test(i, pruned_val_program)
            save_model(exe, pruned_val_program,
                       os.path.join(args.model_path, str(i)))
        if args.save_inference:
            infer_model_path = os.path.join(args.model_path, "infer_models",
                                            str(i))
            paddle.static.save_inference_model(
                infer_model_path, [image], [out],
                exe,
                program=pruned_val_program)
            _logger.info("Saved inference model into [{}]".format(
                infer_model_path))
273 274 275


def main():
    """Script entry point: enable static-graph mode, parse flags, run compress()."""
    paddle.enable_static()
    parsed = parser.parse_args()
    print_arguments(parsed)
    compress(parsed)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()