import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import random
from collections import defaultdict

sys.path.append(os.path.dirname("__file__"))
sys.path.append(
    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from paddleslim.common import get_logger, VarCollector
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments
from paddle.fluid.layer_helper import LayerHelper
# Root directory under which the converted inference model is exported.
quantization_model_save_dir = './quantization_models/'

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,   128,
        "Minibatch size.")
add_arg('use_gpu',          bool,  True,
        "Whether to use GPU or not.")
add_arg('model',            str,   "MobileNetV3_large_x1_0",
        "The target model.")
add_arg('pretrained_model', str,   "./pretrain/MobileNetV3_large_x1_0_ssld_pretrained",
        "Path (prefix) of the pretrained model to load; empty to skip.")
add_arg('lr',               float, 0.001,
        "The learning rate used to fine-tune the quantized model.")
add_arg('lr_strategy',      str,   "piecewise_decay",
        "The learning rate decay strategy: 'piecewise_decay' or 'cosine_decay'.")
add_arg('l2_decay',         float, 1e-5,
        "The l2_decay parameter.")
add_arg('momentum_rate',    float, 0.9,
        "The value of momentum_rate.")
add_arg('num_epochs',       int,   30,
        "The number of total epochs.")
add_arg('total_images',     int,   1281167,
        "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int,
        default=[20],
        help="piecewise decay step")
add_arg('config_file',      str,   None,
        "The config file for compression with yaml format.")
add_arg('data',             str,   "imagenet",
        "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period',       int,   10,
        "Log period in batches.")
add_arg('checkpoint_dir',   str,   None,
        "checkpoint dir")
add_arg('checkpoint_epoch', int,   None,
        "checkpoint epoch")
add_arg('output_dir',       str,   "output/MobileNetV3_large_x1_0",
        "model save dir")
add_arg('use_pact',         bool,  True,
        "Whether to use PACT or not.")
add_arg('analysis',         bool,  False,
        "Whether to analyze the activation distribution and save PACT thresholds.")
add_arg('ce_test',          bool,  False,
        "Whether to CE test.")

# yapf: enable

# Public attribute names of the local ``models`` package (anything containing
# "__" is dropped); ``compress`` validates ``--model`` against this list.
model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
    """Build a Momentum optimizer driven by a piecewise-constant LR schedule.

    The learning rate starts at ``args.lr`` and is multiplied by 0.1 at the
    end of every epoch listed in ``args.step_epochs``.  Boundaries are
    expressed in iterations (epoch * iterations-per-epoch), so the returned
    scheduler is meant to be stepped once per training batch.

    Args:
        args: parsed command-line namespace; reads ``use_gpu``,
            ``total_images``, ``batch_size``, ``step_epochs``, ``lr``,
            ``momentum_rate`` and ``l2_decay``.

    Returns:
        tuple: ``(learning_rate, optimizer)`` -- the ``PiecewiseDecay``
        scheduler and the ``Momentum`` optimizer that consumes it.
    """
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    # Iterations per epoch when every device consumes one batch per step.
    step = int(
        math.ceil(float(args.total_images) / (args.batch_size * len(places))))
    # PiecewiseDecay scans boundaries in order, so they must be ascending;
    # an unsorted --step_epochs would otherwise silently skip decay steps.
    bd = [step * e for e in sorted(args.step_epochs)]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer


def cosine_decay(args):
    """Build a Momentum optimizer driven by a cosine-annealing LR schedule.

    The learning rate anneals from ``args.lr`` over
    ``T_max = iterations_per_epoch * args.num_epochs`` scheduler steps, so the
    returned scheduler is meant to be stepped once per training batch.

    Args:
        args: parsed command-line namespace; reads ``use_gpu``,
            ``total_images``, ``batch_size``, ``num_epochs``, ``lr``,
            ``momentum_rate`` and ``l2_decay``.

    Returns:
        tuple: ``(learning_rate, optimizer)`` -- the ``CosineAnnealingDecay``
        scheduler and the ``Momentum`` optimizer that consumes it.
    """
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    # Iterations per epoch when every device consumes one batch per step.
    step = int(
        math.ceil(float(args.total_images) / (args.batch_size * len(places))))
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer


def create_optimizer(args):
    """Return ``(lr_scheduler, optimizer)`` for ``args.lr_strategy``.

    Args:
        args: parsed command-line namespace; ``args.lr_strategy`` selects
            ``piecewise_decay`` or ``cosine_decay``.

    Raises:
        ValueError: if ``args.lr_strategy`` is not one of the supported
            strategies.  Previously ``None`` was returned, which only failed
            later with an opaque tuple-unpacking error at the call site.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "Unsupported lr_strategy '{}'; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))


def _prepare_envs():
    """Create a static-graph executor on the device Paddle currently targets.

    The device kind is taken from ``paddle.device.get_device()`` with any
    ``":<index>"`` suffix removed (e.g. ``"gpu:0"`` -> ``"gpu"``).

    NOTE(review): this does not consult ``args.use_gpu``; it follows whatever
    device Paddle reports as current.

    Returns:
        tuple: ``(exe, places)`` -- a ``paddle.static.Executor`` and the place
        object it was created on.
    """
    device_kind = paddle.device.get_device().partition(':')[0]
    place = paddle.device._convert_to_place(device_kind)
    _logger.info(f"devices: {device_kind}")
    return paddle.static.Executor(place), place


# Activation variables whose abs-max ranges are collected in ``--analysis``
# mode to seed the PACT thresholds.
# NOTE(review): these are auto-generated static-graph names and presumably
# match only the default MobileNetV3_large_x1_0 network -- confirm before
# running the analysis with another --model.
_PACT_ANALYSIS_VARS = (
    'pool2d_1.tmp_0', 'tmp_35', 'batch_norm_21.tmp_2', 'tmp_26',
    'elementwise_mul_5.tmp_0', 'pool2d_5.tmp_0',
    'elementwise_add_5.tmp_0', 'relu_2.tmp_0', 'pool2d_3.tmp_0',
    'conv2d_40.tmp_2', 'elementwise_mul_0.tmp_0', 'tmp_62',
    'elementwise_add_8.tmp_0', 'batch_norm_39.tmp_2', 'conv2d_32.tmp_2',
    'tmp_17', 'tmp_5', 'elementwise_add_9.tmp_0', 'pool2d_4.tmp_0',
    'relu_0.tmp_0', 'tmp_53', 'relu_3.tmp_0', 'elementwise_add_4.tmp_0',
    'elementwise_add_6.tmp_0', 'tmp_11', 'conv2d_36.tmp_2',
    'relu_8.tmp_0', 'relu_5.tmp_0', 'pool2d_7.tmp_0',
    'elementwise_add_2.tmp_0', 'elementwise_add_7.tmp_0',
    'pool2d_2.tmp_0', 'tmp_47', 'batch_norm_12.tmp_2',
    'elementwise_mul_6.tmp_0', 'elementwise_mul_7.tmp_0',
    'pool2d_6.tmp_0', 'relu_6.tmp_0', 'elementwise_add_0.tmp_0',
    'elementwise_mul_3.tmp_0', 'conv2d_12.tmp_2',
    'elementwise_mul_2.tmp_0', 'tmp_8', 'tmp_2', 'conv2d_8.tmp_2',
    'elementwise_add_3.tmp_0', 'elementwise_mul_1.tmp_0',
    'pool2d_8.tmp_0', 'conv2d_28.tmp_2', 'image', 'conv2d_16.tmp_2',
    'batch_norm_33.tmp_2', 'relu_1.tmp_0', 'pool2d_0.tmp_0', 'tmp_20',
    'conv2d_44.tmp_2', 'relu_10.tmp_0', 'tmp_41', 'relu_4.tmp_0',
    'elementwise_add_1.tmp_0', 'tmp_23', 'batch_norm_6.tmp_2', 'tmp_29',
    'elementwise_mul_4.tmp_0', 'tmp_14',
)


def _load_pact_thresholds(path="pact_thres.npy", default=20):
    """Return initial PACT thresholds keyed by activation name.

    The result is a ``defaultdict`` that yields ``default`` for any
    activation missing from the file (or for every activation when the file
    does not exist).  Thresholds from the file are merged *into* that
    defaultdict, so unknown names never raise ``KeyError``.

    Args:
        path: ``.npy`` file written by ``np.save`` in ``--analysis`` mode.
        default: initial threshold for activations without a stored value.
    """
    thresholds = defaultdict(lambda: default)
    try:
        # allow_pickle is required because the file stores a dict object.
        # Only load files produced locally by the --analysis run: unpickling
        # untrusted data can execute arbitrary code.
        loaded = np.load(path, allow_pickle=True).item()
    except FileNotFoundError:
        _logger.info("cannot find {}. Set init PACT threshold as {}.".format(
            path, default))
        return thresholds
    thresholds.update(loaded)
    _logger.info("{} info loaded.".format(path))
    return thresholds


def compress(args):
    """Run PACT-enhanced quantization-aware training and export the model.

    Overview:
      1. Build ``args.model`` on the dataset chosen by ``args.data``.
      2. With ``--analysis``: collect abs-max statistics for
         ``_PACT_ANALYSIS_VARS``, save them to ``pact_thres.npy`` and return.
      3. Otherwise: insert fake-quant ops (optionally preceded by PACT
         clipping) into the train/eval programs, fine-tune for
         ``args.num_epochs``, keep the best epoch by top-1 accuracy, convert
         it, and save an inference model under
         ``quantization_model_save_dir``.

    Args:
        args: parsed command-line namespace (see the ``add_arg`` calls).

    Raises:
        ValueError: if ``args.data`` is neither ``"mnist"`` nor ``"imagenet"``.
        AssertionError: if ``args.model`` is unknown, if
            ``args.pretrained_model`` is set but does not exist, or if
            ``--checkpoint_dir`` is given without ``--checkpoint_epoch``.
    """
    num_workers = 4
    shuffle = True
    if args.ce_test:
        # Fix every RNG and disable worker processes / shuffling so that
        # continuous-evaluation runs are reproducible.
        seed = 111
        paddle.seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        num_workers = 0
        shuffle = False

    if args.data == "mnist":
        train_dataset = paddle.vision.datasets.MNIST(mode='train')
        val_dataset = paddle.vision.datasets.MNIST(mode='test')
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    image = paddle.static.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    if args.use_pact:
        # NOTE(review): presumably required so the graph keeps a gradient
        # path through the input for the PACT ops -- confirm.
        image.stop_gradient = False
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

    # Model definition, loss and metrics.
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
    avg_cost = paddle.mean(x=cost)
    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    train_prog = paddle.static.default_main_program()
    # Cloned before the optimizer is attached, so the eval program contains
    # no backward/optimizer ops.
    val_program = paddle.static.default_main_program().clone(for_test=True)

    if not args.analysis:
        learning_rate, opt = create_optimizer(args)
        opt.minimize(avg_cost)

    exe, places = _prepare_envs()
    exe.run(paddle.static.default_startup_program())

    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=places,
        feed_list=[image, label],
        drop_last=True,
        return_list=False,
        batch_size=args.batch_size,
        use_shared_memory=True,
        shuffle=shuffle,
        num_workers=num_workers)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=places,
        feed_list=[image, label],
        drop_last=False,
        return_list=False,
        batch_size=args.batch_size,
        use_shared_memory=True,
        shuffle=False)

    if args.analysis:
        var_collector = VarCollector(
            train_prog, list(_PACT_ANALYSIS_VARS), use_ema=True)
        values = var_collector.abs_max_run(
            train_loader, exe, step=None, loss_name=avg_cost.name)
        np.save('pact_thres.npy', values)
        _logger.info(values)
        _logger.info("PACT threshold have been saved as pact_thres.npy")
        # To also draw histograms into 'dist_pdf/result.pdf', call
        # var_collector.pdf(values) here.
        return

    values = _load_pact_thresholds()
    _logger.info(values)

    # 1. quantization configs
    quant_config = {
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops of name_scope in not_quant_pattern list, will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of type in quantize_op_types, will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
    }

    # 2. Quantization transform programs (training aware).
    #    Fake quantize / dequantize operators are inserted into the graphs
    #    according to the weight and activation quantization types.

    def pact(x):
        """Clip ``x`` to ``[-u, u]`` with a learnable threshold ``u`` (PACT).

        ``u`` is a 1-element float32 parameter named ``<x.name>_pact``,
        initialised from ``values`` looked up by ``x.name`` with any
        ``'_tmp_input'`` suffix stripped (falling back to the default).
        """
        helper = LayerHelper("pact", **locals())
        dtype = 'float32'
        init_thres = values[x.name.split('_tmp_input')[0]]
        u_param_attr = paddle.ParamAttr(
            name=x.name + '_pact',
            initializer=paddle.nn.initializer.Constant(value=init_thres),
            regularizer=paddle.regularizer.L2Decay(0.0001),
            learning_rate=1)
        u_param = helper.create_parameter(
            attr=u_param_attr, shape=[1], dtype=dtype)

        # x - relu(x - u) + relu(-u - x) equals clip(x, -u, u) while staying
        # differentiable with respect to u.
        part_a = paddle.nn.functional.relu(x - u_param)
        part_b = paddle.nn.functional.relu(-u_param - x)
        x = x - part_a + part_b
        return x

    def get_optimizer():
        # Uses the CLI momentum (default 0.9) for consistency with the main
        # optimizer instead of a hard-coded value.
        return paddle.optimizer.Momentum(args.lr, args.momentum_rate)

    if args.use_pact:
        act_preprocess_func = pact
        optimizer_func = get_optimizer
        executor = exe
    else:
        act_preprocess_func = None
        optimizer_func = None
        executor = None

    val_program = quant_aware(
        val_program,
        places,
        quant_config,
        scope=None,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=executor,
        for_test=True)
    compiled_train_prog = quant_aware(
        train_prog,
        places,
        quant_config,
        scope=None,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=executor,
        for_test=False)

    # Only validate the path when one was given; os.path.exists(None) would
    # raise TypeError before the emptiness check could run.
    if args.pretrained_model:
        assert os.path.exists(
            args.pretrained_model), "pretrained_model doesn't exist"
        paddle.static.load(train_prog, args.pretrained_model, exe)

    def test(epoch, program):
        """Evaluate ``program`` on the validation set; return mean top-1."""
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader()):
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
                    format(epoch, batch_id, acc_top1_n, acc_top5_n,
                           end_time - start_time))
            acc_top1_ns.append(acc_top1_n)
            acc_top5_ns.append(acc_top5_n)

        mean_top1 = np.mean(np.array(acc_top1_ns))
        mean_top5 = np.mean(np.array(acc_top5_ns))
        _logger.info(
            "Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
                epoch, mean_top1, mean_top5))
        return mean_top1

    def train(epoch, program, lr):
        """Run one training epoch, stepping the LR scheduler ``lr`` per batch."""
        for batch_id, data in enumerate(train_loader()):
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
                    format(epoch, batch_id,
                           lr.get_lr(), loss_n, acc_top1_n,
                           acc_top5_n, end_time - start_time))

            if args.use_pact and batch_id % 1000 == 0:
                # Periodically log the current learned PACT thresholds.
                scope = paddle.static.global_scope()
                threshold = {
                    var.name:
                    np.array(scope.find_var(var.name).get_tensor())[0]
                    for var in val_program.list_vars() if 'pact' in var.name
                }
                _logger.info(threshold)
            lr.step()

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.enable_inplace = False
    build_strategy.fuse_all_reduce_ops = False
    exec_strategy = paddle.static.ExecutionStrategy()
    compiled_train_prog = compiled_train_prog.with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    start_epoch = 0
    if args.checkpoint_dir is not None:
        assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set"
        start_epoch = args.checkpoint_epoch
        # NOTE(review): the LR scheduler is not fast-forwarded to
        # start_epoch, so a resumed run restarts the schedule from args.lr.
        paddle.static.load(
            executor=exe, model_path=args.checkpoint_dir, program=val_program)

    # Train loop: checkpoint every epoch and keep the best one by top-1.
    best_model_prefix = os.path.join(args.output_dir, 'best_model')
    best_acc1 = 0.0
    best_epoch = 0
    for i in range(start_epoch, args.num_epochs):
        train(i, compiled_train_prog, learning_rate)
        acc1 = test(i, val_program)
        is_best = acc1 > best_acc1
        if is_best:
            best_acc1 = acc1
            best_epoch = i
        _logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format(
            best_acc1, best_epoch))
        paddle.static.save(
            model_path=os.path.join(args.output_dir, str(i)),
            program=val_program)
        if is_best:
            paddle.static.save(
                model_path=best_model_prefix, program=val_program)

    if os.path.exists(best_model_prefix + '.pdparams'):
        paddle.static.load(
            executor=exe, model_path=best_model_prefix, program=val_program)

    # 3. Freeze the graph after training by adjusting the quantize
    #    operators' order for inference.  The float program's weights are
    #    float32 but lie in the int8 range.
    float_program, _ = convert(
        val_program, places, quant_config, scope=None, save_int8=True)
    _logger.info("eval best_model after convert")
    final_acc1 = test(best_epoch, float_program)
    _logger.info("final acc:{}".format(final_acc1))

    # 4. Save inference model.
    model_path = os.path.join(quantization_model_save_dir, args.model,
                              'act_' + quant_config['activation_quantize_type']
                              + '_w_' + quant_config['weight_quantize_type'])
    float_path = os.path.join(model_path, 'float')
    os.makedirs(model_path, exist_ok=True)

    paddle.fluid.io.save_inference_model(
        dirname=float_path,
        feeded_var_names=[image.name],
        target_vars=[out],
        executor=exe,
        main_program=float_program,
        model_filename=float_path + '/model',
        params_filename=float_path + '/params')


def main():
    """Command-line entry point: enable static-graph mode, parse args, run."""
    paddle.enable_static()
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()