train.py 16.9 KB
Newer Older
1 2 3 4 5 6 7 8 9
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
G
Guanghua Yu 已提交
10
import random
11 12
from collections import defaultdict

B
Bai Yifan 已提交
13 14 15
# Make modules next to this script, and modules two directories above it,
# importable regardless of the current working directory. The module attribute
# `__file__` is required here: the string literal "__file__" has an empty
# dirname, which would make both entries depend on the working directory.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_THIS_DIR)
sys.path.append(os.path.join(_THIS_DIR, os.path.pardir, os.path.pardir))
16
from paddleslim.common import get_logger, VarCollector
17 18 19 20
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments
21
from paddle.fluid.layer_helper import LayerHelper
22 23 24 25 26 27 28
# Root directory under which the converted inference model is written
# (joined with the model name inside `compress`).
quantization_model_save_dir = './quantization_models/'

_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface. `add_arg(name, type, default, help)` forwards to the
# project's `add_arguments` helper with this parser already bound.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  128,
        "Minibatch size.")
add_arg('use_gpu',          bool, True,
        "Whether to use GPU or not.")
add_arg('model',            str,  "MobileNetV3_large_x1_0",
        "The target model.")
add_arg('pretrained_model', str,  "./pretrain/MobileNetV3_large_x1_0_ssld_pretrained",
        "Path to the pretrained model to load before training.")
add_arg('lr',               float,  0.001,
        "The learning rate used to fine-tune the quantized model.")
add_arg('lr_strategy',      str,  "piecewise_decay",
        "The learning rate decay strategy.")
add_arg('l2_decay',         float,  1e-5,
        "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,
        "The value of momentum_rate.")
add_arg('num_epochs',       int,  30,
        "The number of total epochs.")
add_arg('total_images',     int,  1281167,
        "The number of total training images.")
# Registered directly on the parser because it takes a list of values.
parser.add_argument('--step_epochs', nargs='+', type=int,
        default=[20],
        help="piecewise decay step")
add_arg('config_file',      str, None,
        "The config file for compression with yaml format.")
add_arg('data',             str, "imagenet",
        "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period',       int, 10,
        "Log period in batches.")
add_arg('checkpoint_dir',         str, None,
        "checkpoint dir")
add_arg('checkpoint_epoch',         int, None,
        "checkpoint epoch")
add_arg('output_dir',         str, "output/MobileNetV3_large_x1_0",
        "model save dir")
add_arg('use_pact',          bool, True,
        "Whether to use PACT or not.")
add_arg('analysis',          bool, False,
        "Whether analysis variables distribution.")
add_arg('onnx_format',          bool, False,
        "Whether use onnx format or not.")
add_arg('ce_test',                 bool,   False,       "Whether to CE test.")

# yapf: enable

# Every attribute of the `models` package whose name contains no double
# underscore; `compress` checks the requested model against this list.
model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
    """Build a Momentum optimizer with a piecewise-constant learning rate.

    The rate starts at ``args.lr`` and is multiplied by 0.1 after each
    boundary derived from ``args.step_epochs``.

    Args:
        args: Parsed command-line namespace. Reads ``use_gpu``,
            ``total_images``, ``batch_size``, ``step_epochs``, ``lr``,
            ``momentum_rate`` and ``l2_decay``.

    Returns:
        tuple: ``(learning_rate, optimizer)`` -- the ``PiecewiseDecay``
        scheduler and the ``Momentum`` optimizer that uses it.
    """
    places = (paddle.static.cuda_places()
              if args.use_gpu else paddle.static.cpu_places())
    # Iterations per epoch: total images divided by the batch size times the
    # number of places, rounded up.
    step = int(
        math.ceil(float(args.total_images) / (args.batch_size * len(places))))
    # Convert the epoch milestones into iteration counts.
    bd = [step * e for e in args.step_epochs]
    # One value per interval, hence one more value than boundaries.
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer
91 92 93


def cosine_decay(args):
    """Build a Momentum optimizer with a cosine-annealing learning rate.

    The annealing period ``T_max`` is the number of iterations in
    ``args.num_epochs`` epochs.

    Args:
        args: Parsed command-line namespace. Reads ``use_gpu``,
            ``total_images``, ``batch_size``, ``num_epochs``, ``lr``,
            ``momentum_rate`` and ``l2_decay``.

    Returns:
        tuple: ``(learning_rate, optimizer)`` -- the ``CosineAnnealingDecay``
        scheduler and the ``Momentum`` optimizer that uses it.
    """
    places = (paddle.static.cuda_places()
              if args.use_gpu else paddle.static.cpu_places())
    # Iterations per epoch: total images divided by the batch size times the
    # number of places, rounded up.
    step = int(
        math.ceil(float(args.total_images) / (args.batch_size * len(places))))
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer
105 106 107 108 109 110 111 112 113


def create_optimizer(args):
    """Create the learning-rate scheduler and optimizer chosen on the CLI.

    Args:
        args: Parsed command-line namespace. ``args.lr_strategy`` selects the
            schedule: ``"piecewise_decay"`` or ``"cosine_decay"``.

    Returns:
        tuple: ``(learning_rate, optimizer)`` as returned by the selected
        builder.

    Raises:
        ValueError: If ``args.lr_strategy`` is not a supported strategy.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    # Fail loudly here; falling through would return None and make the caller
    # crash with an unrelated tuple-unpacking error.
    raise ValueError(
        "lr_strategy {!r} is not supported; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))


114 115 116 117 118 119 120 121
def _prepare_envs():
    """Create a static-graph executor for the currently selected device.

    The device string reported by ``paddle.device.get_device()`` is cut at the
    first ``':'`` and only the leading part is converted to a place.

    Returns:
        tuple: ``(exe, places)`` -- the ``paddle.static.Executor`` and the
        place object it was created with.
    """
    device_type, *_ = paddle.device.get_device().split(':')
    place = paddle.device._convert_to_place(device_type)
    _logger.info(f"devices: {device_type}")
    return paddle.static.Executor(place), place


122
def compress(args):
    """Run (PACT) quantization-aware training and export the quantized model.

    Steps, in order:
      1. Build the static classification graph for ``args.model`` on the
         dataset selected by ``args.data``.
      2. If ``args.analysis`` is set, collect the abs-max of a fixed list of
         activations, save them to ``pact_thres.npy`` and return.
      3. Otherwise insert fake-quant ops with ``quant_aware`` (optionally with
         the PACT activation pre-processing), load the pretrained weights and
         train for ``args.num_epochs`` epochs, saving every epoch and the best
         one under ``args.output_dir``.
      4. Convert the best evaluation program, evaluate it once more and save
         it as an inference model under ``quantization_model_save_dir``.

    Args:
        args: Parsed command-line namespace produced by the module-level
            ``parser``.

    Returns:
        None.

    Raises:
        ValueError: If ``args.data`` is neither ``"mnist"`` nor
            ``"imagenet"``.
        AssertionError: If ``args.model`` is not in ``model_list``, if the
            given ``args.pretrained_model`` path does not exist, or if
            ``args.checkpoint_dir`` is given without ``args.checkpoint_epoch``.
    """
    num_workers = 4
    shuffle = True
    if args.ce_test:
        # Seed every RNG and drop loader workers/shuffling for this mode.
        seed = 111
        paddle.seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        num_workers = 0
        shuffle = False

    if args.data == "mnist":
        train_dataset = paddle.vision.datasets.MNIST(mode='train')
        val_dataset = paddle.vision.datasets.MNIST(mode='test')
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    image = paddle.static.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    if args.use_pact:
        image.stop_gradient = False
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
    avg_cost = paddle.mean(x=cost)
    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    train_prog = paddle.static.default_main_program()
    # Clone for evaluation before the optimizer ops are appended below.
    val_program = paddle.static.default_main_program().clone(for_test=True)

    if not args.analysis:
        learning_rate, opt = create_optimizer(args)
        opt.minimize(avg_cost)

    exe, places = _prepare_envs()
    exe.run(paddle.static.default_startup_program())

    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=places,
        feed_list=[image, label],
        drop_last=True,
        return_list=False,
        batch_size=args.batch_size,
        use_shared_memory=True,
        shuffle=shuffle,
        num_workers=num_workers)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=places,
        feed_list=[image, label],
        drop_last=False,
        return_list=False,
        batch_size=args.batch_size,
        use_shared_memory=True,
        shuffle=False)

    if args.analysis:
        # get all activations names
        activates = [
            'pool2d_1.tmp_0', 'tmp_35', 'batch_norm_21.tmp_2', 'tmp_26',
            'elementwise_mul_5.tmp_0', 'pool2d_5.tmp_0',
            'elementwise_add_5.tmp_0', 'relu_2.tmp_0', 'pool2d_3.tmp_0',
            'conv2d_40.tmp_2', 'elementwise_mul_0.tmp_0', 'tmp_62',
            'elementwise_add_8.tmp_0', 'batch_norm_39.tmp_2', 'conv2d_32.tmp_2',
            'tmp_17', 'tmp_5', 'elementwise_add_9.tmp_0', 'pool2d_4.tmp_0',
            'relu_0.tmp_0', 'tmp_53', 'relu_3.tmp_0', 'elementwise_add_4.tmp_0',
            'elementwise_add_6.tmp_0', 'tmp_11', 'conv2d_36.tmp_2',
            'relu_8.tmp_0', 'relu_5.tmp_0', 'pool2d_7.tmp_0',
            'elementwise_add_2.tmp_0', 'elementwise_add_7.tmp_0',
            'pool2d_2.tmp_0', 'tmp_47', 'batch_norm_12.tmp_2',
            'elementwise_mul_6.tmp_0', 'elementwise_mul_7.tmp_0',
            'pool2d_6.tmp_0', 'relu_6.tmp_0', 'elementwise_add_0.tmp_0',
            'elementwise_mul_3.tmp_0', 'conv2d_12.tmp_2',
            'elementwise_mul_2.tmp_0', 'tmp_8', 'tmp_2', 'conv2d_8.tmp_2',
            'elementwise_add_3.tmp_0', 'elementwise_mul_1.tmp_0',
            'pool2d_8.tmp_0', 'conv2d_28.tmp_2', 'image', 'conv2d_16.tmp_2',
            'batch_norm_33.tmp_2', 'relu_1.tmp_0', 'pool2d_0.tmp_0', 'tmp_20',
            'conv2d_44.tmp_2', 'relu_10.tmp_0', 'tmp_41', 'relu_4.tmp_0',
            'elementwise_add_1.tmp_0', 'tmp_23', 'batch_norm_6.tmp_2', 'tmp_29',
            'elementwise_mul_4.tmp_0', 'tmp_14'
        ]
        var_collector = VarCollector(train_prog, activates, use_ema=True)
        values = var_collector.abs_max_run(
            train_loader, exe, step=None, loss_name=avg_cost.name)
        np.save('pact_thres.npy', values)
        _logger.info(values)
        _logger.info("PACT threshold have been saved as pact_thres.npy")

        # Draw Histogram in 'dist_pdf/result.pdf'
        # var_collector.pdf(values)

        return

    # Initial PACT thresholds: 20 for every variable unless a previous
    # `--analysis` run left per-variable values in pact_thres.npy.
    values = defaultdict(lambda: 20)
    try:
        # NOTE: allow_pickle=True unpickles the file content, so
        # pact_thres.npy must come from a trusted run of this script.
        loaded = np.load("pact_thres.npy", allow_pickle=True).item()
        # Merge into the defaultdict so variables missing from the file still
        # fall back to 20 instead of raising KeyError inside `pact`.
        values.update(loaded)
        _logger.info("pact_thres.npy info loaded.")
    except FileNotFoundError:
        _logger.info(
            "cannot find pact_thres.npy. Set init PACT threshold as 20.")
    except Exception as exc:
        # Loading is best-effort: an unreadable file must not abort training.
        _logger.warning(
            "failed to load pact_thres.npy (%s). "
            "Set init PACT threshold as 20.", exc)
    _logger.info(values)

    # 1. quantization configs
    quant_config = {
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops of name_scope in not_quant_pattern list, will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of type in quantize_op_types, will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
        # Whether use onnx format or not
        'onnx_format': args.onnx_format,
    }

    # 2. quantization transform programs (training aware)
    #    Make some quantization transforms in the graph before training and testing.
    #    According to the weight and activation quantization type, the graph will be added
    #    some fake quantize operators and fake dequantize operators.

    def pact(x):
        """Clip ``x`` to ``[-u, u]`` where ``u`` is a learnable threshold."""
        # Must stay the first statement: locals() is forwarded as kwargs.
        helper = LayerHelper("pact", **locals())
        dtype = 'float32'
        # Look up the initial threshold by the variable name with any
        # '_tmp_input' suffix cut off.
        init_thres = values[x.name.split('_tmp_input')[0]]
        u_param_attr = paddle.ParamAttr(
            name=x.name + '_pact',
            initializer=paddle.nn.initializer.Constant(value=init_thres),
            regularizer=paddle.regularizer.L2Decay(0.0001),
            learning_rate=1)
        u_param = helper.create_parameter(
            attr=u_param_attr, shape=[1], dtype=dtype)

        # part_a is the amount by which x exceeds u; part_b the amount by
        # which x falls below -u.
        part_a = paddle.nn.functional.relu(x - u_param)
        part_b = paddle.nn.functional.relu(-u_param - x)
        x = x - part_a + part_b
        return x

    def get_optimizer():
        """Return the optimizer passed to ``quant_aware`` for PACT."""
        return paddle.optimizer.Momentum(args.lr, 0.9)

    if args.use_pact:
        act_preprocess_func = pact
        optimizer_func = get_optimizer
        executor = exe
    else:
        act_preprocess_func = None
        optimizer_func = None
        executor = None

    val_program = quant_aware(
        val_program,
        places,
        quant_config,
        scope=None,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        for_test=True)
    compiled_train_prog = quant_aware(
        train_prog,
        places,
        quant_config,
        scope=None,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=executor,
        for_test=False)

    # Only validate the path when one was actually given; the existence check
    # must not run on an empty or missing value.
    if args.pretrained_model:
        assert os.path.exists(
            args.pretrained_model), "pretrained_model doesn't exist"
        paddle.static.load(train_prog, args.pretrained_model, exe)

    def test(epoch, program):
        """Evaluate ``program`` on the validation set; return mean top-1."""
        batch_id = 0
        acc_top1_ns = []
        acc_top5_ns = []
        for data in valid_loader():
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))
            batch_id += 1

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
                epoch,
                np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def train(epoch, compiled_train_prog, lr):
        """Train one epoch, stepping the scheduler ``lr`` after each batch."""
        batch_id = 0
        for data in train_loader():
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                compiled_train_prog,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])

            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
                    format(epoch, batch_id,
                           lr.get_lr(), loss_n, acc_top1_n,
                           acc_top5_n, end_time - start_time))

            # Periodically log the current value of every PACT threshold.
            if args.use_pact and batch_id % 1000 == 0:
                threshold = {}
                for var in val_program.list_vars():
                    if 'pact' in var.name:
                        array = np.array(paddle.static.global_scope().find_var(
                            var.name).get_tensor())
                        threshold[var.name] = array[0]
                _logger.info(threshold)
            batch_id += 1
            lr.step()

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.enable_inplace = False
    build_strategy.fuse_all_reduce_ops = False
    exec_strategy = paddle.static.ExecutionStrategy()
    compiled_train_prog = compiled_train_prog.with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # Optionally resume from a checkpoint.
    start_epoch = 0
    if args.checkpoint_dir is not None:
        assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set"
        start_epoch = args.checkpoint_epoch
        paddle.static.load(
            executor=exe, model_path=args.checkpoint_dir, program=val_program)

    # train loop
    best_acc1 = 0.0
    best_epoch = 0
    for i in range(start_epoch, args.num_epochs):
        train(i, compiled_train_prog, learning_rate)
        acc1 = test(i, val_program)
        is_best = acc1 > best_acc1
        if is_best:
            best_acc1 = acc1
            best_epoch = i
        _logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format(
            best_acc1, best_epoch))
        # Save every epoch, and additionally keep a copy of the best one.
        paddle.static.save(
            model_path=os.path.join(args.output_dir, str(i)),
            program=val_program)
        if is_best:
            paddle.static.save(
                model_path=os.path.join(args.output_dir, 'best_model'),
                program=val_program)

    # Restore the best weights (if any were saved) before converting.
    if os.path.exists(os.path.join(args.output_dir, 'best_model.pdparams')):
        paddle.static.load(
            executor=exe,
            model_path=os.path.join(args.output_dir, 'best_model'),
            program=val_program)

    # 3. Freeze the graph after training by adjusting the quantize
    #    operators' order for the inference.
    #    The dtype of float_program's weights is float32, but in int8 range.
    model_path = os.path.join(quantization_model_save_dir, args.model)
    os.makedirs(model_path, exist_ok=True)
    float_program = convert(val_program, places, quant_config)
    _logger.info("eval best_model after convert")
    final_acc1 = test(best_epoch, float_program)
    _logger.info("final acc:{}".format(final_acc1))

    # 4. Save inference model
    paddle.static.save_inference_model(
        os.path.join(model_path, 'model'), [image], [out],
        exe,
        program=float_program)
445 446 447


def main():
    """Parse the command-line arguments and run the compression pipeline."""
    # Static-graph mode is enabled before `compress` builds its programs.
    paddle.enable_static()
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()