train.py 12.7 KB
Newer Older
I
itminner 已提交
1 2 3 4 5 6 7 8 9 10
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
W
whs 已提交
11 12
# Replace the first import search path with "<dirname>/../..".
# NOTE(review): "__file__" is a quoted string literal here, not the module's
# __file__ attribute, so os.path.dirname() returns "" and the resulting
# relative path ("../..") is resolved against the current working directory.
# Presumably the script is expected to be launched from its own directory --
# confirm before changing.
sys.path[0] = os.path.join(
    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
I
itminner 已提交
13 14
from paddleslim.common import get_logger
from paddleslim.analysis import flops
15
from paddleslim.quant import quant_aware, convert
I
itminner 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28
import models
from utility import add_arguments, print_arguments

# Root directory under which the exported float / int8 inference models are
# written by compress().
quantization_model_save_dir = './quantization_models/'

_logger = get_logger(__name__, level=logging.INFO)

# Command-line interface. `add_arg` is `add_arguments` bound to this parser;
# each call passes (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
# -- runtime / model selection
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.")
# -- optimizer / learning-rate schedule
add_arg('lr', float, 0.0001, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 1, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# -- data, logging and checkpoints
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('checkpoint_dir', str, "output", "checkpoint save dir")
# yapf: enable

# Names exposed by the `models` package, excluding any name containing "__".
model_list = [name for name in dir(models) if "__" not in name]


def piecewise_decay(args):
    """Build a Momentum optimizer driven by a piecewise-constant learning rate.

    The schedule boundaries are ``args.step_epochs`` converted from epochs to
    iterations, and the i-th learning-rate value is ``args.lr * 0.1**i``.

    Args:
        args: namespace providing ``total_images``, ``batch_size``,
            ``step_epochs``, ``lr``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        The ``fluid.optimizer.Momentum`` instance (with L2 regularization).
    """
    # Iterations per epoch, rounded up so a partial last batch still counts.
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    # One more value than boundaries: the rate used before the first boundary.
    values = [args.lr * (0.1**stage) for stage in range(len(boundaries) + 1)]
    lr_schedule = fluid.layers.piecewise_decay(
        boundaries=boundaries, values=values)
    l2_regularizer = fluid.regularizer.L2Decay(args.l2_decay)
    return fluid.optimizer.Momentum(
        learning_rate=lr_schedule,
        momentum=args.momentum_rate,
        regularization=l2_regularizer)


def cosine_decay(args):
    """Build a Momentum optimizer driven by a cosine learning-rate schedule.

    Args:
        args: namespace providing ``total_images``, ``batch_size``, ``lr``,
            ``num_epochs``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        The ``fluid.optimizer.Momentum`` instance (with L2 regularization).
    """
    # Iterations per epoch, rounded up so a partial last batch still counts.
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    lr_schedule = fluid.layers.cosine_decay(
        learning_rate=args.lr,
        step_each_epoch=steps_per_epoch,
        epochs=args.num_epochs)
    l2_regularizer = fluid.regularizer.L2Decay(args.l2_decay)
    return fluid.optimizer.Momentum(
        learning_rate=lr_schedule,
        momentum=args.momentum_rate,
        regularization=l2_regularizer)


def create_optimizer(args):
    """Create the optimizer selected by ``args.lr_strategy``.

    Args:
        args: namespace whose ``lr_strategy`` is ``"piecewise_decay"`` or
            ``"cosine_decay"``; it is forwarded unchanged to the matching
            builder function.

    Returns:
        The optimizer returned by ``piecewise_decay`` or ``cosine_decay``.

    Raises:
        ValueError: if ``args.lr_strategy`` names neither strategy. Without
            this check the function would return ``None`` and the failure
            would only surface later when the caller uses the optimizer.
    """
    strategy = args.lr_strategy
    if strategy == "piecewise_decay":
        return piecewise_decay(args)
    if strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "Unsupported lr_strategy: {!r}; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(strategy))

def _build_readers(data):
    """Return ``(train_reader, val_reader, class_dim, image_shape)`` for ``data``.

    Raises:
        ValueError: if ``data`` is neither ``"mnist"`` nor ``"imagenet"``.
    """
    if data == "mnist":
        import paddle.dataset.mnist as reader
        return reader.train(), reader.test(), 10, [1, 28, 28]
    if data == "imagenet":
        import imagenet_reader as reader
        return reader.train(), reader.val(), 1000, [3, 224, 224]
    raise ValueError("{} is not supported.".format(data))


def _save_inference_model(exe, program, dirname, image, out):
    """Save ``program`` as an inference model (input ``image``, output ``out``) in ``dirname``."""
    fluid.io.save_inference_model(
        dirname=dirname,
        feeded_var_names=[image.name],
        target_vars=[out],
        executor=exe,
        main_program=program,
        model_filename=dirname + '/model',
        params_filename=dirname + '/params')


def compress(args):
    """Run quantization-aware training on a classifier and export the result.

    The function
      1. defines the quantization config,
      2. builds the network, clones it for evaluation and rewrites both
         programs with ``quant_aware``,
      3. loads the pretrained weights and trains for ``args.num_epochs``
         epochs, saving a checkpoint per epoch plus the best top-1 epoch,
      4. converts the evaluation program with ``convert`` and evaluates it,
      5. saves the float and int8 inference models under
         ``quantization_model_save_dir``.

    Args:
        args: parsed command-line namespace. Reads ``data``, ``model``,
            ``use_gpu``, ``pretrained_model``, ``batch_size``,
            ``log_period``, ``num_epochs`` and ``checkpoint_dir`` here and is
            also passed to ``create_optimizer``.

    Raises:
        ValueError: if ``args.data`` or ``args.model`` is not supported, or if
            ``args.pretrained_model`` is unset or does not exist on disk.
    """
    ############################################################################################################
    # 1. quantization configs
    ############################################################################################################
    quant_config = {
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops of name_scope in not_quant_pattern list, will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of type in quantize_op_types, will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
    }

    train_reader, val_reader, class_dim, image_shape = _build_readers(
        args.data)

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if args.model not in model_list:
        raise ValueError("{} is not in lists: {}".format(args.model,
                                                         model_list))

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

    train_prog = fluid.default_main_program()
    # The evaluation program is cloned before the optimizer is attached.
    val_program = fluid.default_main_program().clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    ############################################################################################################
    # 2. quantization transform programs (training aware)
    #    Make some quantization transforms in the graph before training and testing.
    #    According to the weight and activation quantization type, the graph will be added
    #    some fake quantize operators and fake dequantize operators.
    ############################################################################################################
    # NOTE(review): quant_aware is applied before opt.minimize(); keep this
    # call order unless the PaddleSlim documentation says otherwise.
    val_program = quant_aware(
        val_program, place, quant_config, scope=None, for_test=True)
    compiled_train_prog = quant_aware(
        train_prog, place, quant_config, scope=None, for_test=False)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # A pretrained model is mandatory; fail with an explicit error rather
    # than an `assert` (or a TypeError when the option is None).
    if not args.pretrained_model or not os.path.exists(args.pretrained_model):
        raise ValueError("pretrained_model doesn't exist: {}".format(
            args.pretrained_model))

    def if_exist(var):
        # Only load variables that have a matching file in the model dir.
        return os.path.exists(os.path.join(args.pretrained_model, var.name))

    fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.fluid.io.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    def test(epoch, program):
        """Evaluate ``program`` on the validation reader; return mean top-1 accuracy."""
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(val_reader()):
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            batch_top1 = np.mean(acc_top1_n)
            batch_top5 = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id, batch_top1, batch_top5,
                           end_time - start_time))
            acc_top1_ns.append(batch_top1)
            acc_top5_ns.append(batch_top5)

        mean_top1 = np.mean(np.array(acc_top1_ns))
        mean_top5 = np.mean(np.array(acc_top5_ns))
        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch, mean_top1, mean_top5))
        return mean_top1

    def train(epoch, compiled_train_prog):
        """Run one pass over the training reader, logging every ``log_period`` batches."""
        for batch_id, data in enumerate(train_reader()):
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                compiled_train_prog,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))

    # Data-parallel wrapper with memory optimization, in-place reuse, fused
    # all-reduce and synchronized batch norm all switched off.
    build_strategy = fluid.BuildStrategy()
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = False
    build_strategy.fuse_all_reduce_ops = False
    build_strategy.sync_batch_norm = False
    exec_strategy = fluid.ExecutionStrategy()
    compiled_train_prog = compiled_train_prog.with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    ############################################################################################################
    # train loop
    ############################################################################################################
    best_model_dir = os.path.join(args.checkpoint_dir, 'best_model')
    best_acc1 = 0.0
    best_epoch = 0
    for i in range(args.num_epochs):
        train(i, compiled_train_prog)
        acc1 = test(i, val_program)
        # Per-epoch checkpoint.
        fluid.io.save_persistables(
            exe,
            dirname=os.path.join(args.checkpoint_dir, str(i)),
            main_program=val_program)
        if acc1 > best_acc1:
            best_acc1 = acc1
            best_epoch = i
            fluid.io.save_persistables(
                exe, dirname=best_model_dir, main_program=val_program)

    # Restore the best checkpoint (when one exists on disk) before converting.
    if os.path.exists(best_model_dir):
        fluid.io.load_persistables(
            exe, dirname=best_model_dir, main_program=val_program)
    ############################################################################################################
    # 3. Freeze the graph after training by adjusting the quantize
    #    operators' order for the inference.
    #    The dtype of float_program's weights is float32, but in int8 range.
    ############################################################################################################
    float_program, int8_program = convert(
        val_program, place, quant_config, scope=None, save_int8=True)
    print("eval best_model after convert")
    test(best_epoch, float_program)
    ############################################################################################################
    # 4. Save inference model
    ############################################################################################################
    model_path = os.path.join(quantization_model_save_dir, args.model,
                              'act_' + quant_config['activation_quantize_type']
                              + '_w_' + quant_config['weight_quantize_type'])
    float_path = os.path.join(model_path, 'float')
    int8_path = os.path.join(model_path, 'int8')
    if not os.path.isdir(model_path):
        os.makedirs(model_path)

    _save_inference_model(exe, float_program, float_path, image, out)
    _save_inference_model(exe, int8_program, int8_path, image, out)


def main():
    """Script entry point.

    Enables Paddle static mode, parses the command line with the module-level
    ``parser``, prints the parsed arguments and runs ``compress`` on them.
    """
    paddle.enable_static()
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


if __name__ == '__main__':
    main()