"""Quantization-aware training demo for image classification with PaddleSlim.

Builds a classifier from the local ``models`` package, inserts fake
quantize/dequantize ops with ``quant_aware``, fine-tunes it, keeps the best
checkpoint on the validation set, then ``convert``s it and exports a float
inference model.
"""
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid

# Replace the script directory on sys.path with the directory two levels
# above this file, so that the local modules imported below resolve from
# there. Using the real ``__file__`` (not the string "__file__") keeps this
# independent of the current working directory.
sys.path[0] = os.path.join(
    os.path.dirname(__file__), os.path.pardir, os.path.pardir)

from paddleslim.common import get_logger
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, convert
import models
from utility import add_arguments, print_arguments

# Root directory under which compress() writes the exported inference model.
quantization_model_save_dir = './quantization_models/'

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64 * 4,                 "Minibatch size.")
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('model',            str,  "MobileNet",                "The target model.")
add_arg('pretrained_model', str,  "../pretrained_model/MobileNetV1_pretrained",                "Path of the pretrained model to load.")
add_arg('lr',               float,  0.0001,               "The learning rate used to fine-tune the quantized model.")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy.")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,  1,               "The number of total epochs.")
add_arg('total_images',     int,  1281167,               "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# NOTE(review): config_file is parsed but not read anywhere in this script.
add_arg('config_file',      str, None,                 "The config file for compression with yaml format.")
add_arg('data',             str, "imagenet",             "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period',       int, 10,                 "Log period in batches.")
add_arg('checkpoint_dir',         str, "output",           "checkpoint save dir")
# yapf: enable

# Public names exported by the local ``models`` package; valid --model values.
model_list = [m for m in dir(models) if "__" not in m]


def _steps_per_epoch(args):
    """Return ceil(total_images / (batch_size * number_of_places)).

    The places are all CUDA places when ``args.use_gpu`` is set, otherwise the
    CPU places. Both LR schedules below use this as the iteration count of one
    epoch.
    """
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    return int(
        math.ceil(float(args.total_images) / (args.batch_size * len(places))))


def piecewise_decay(args):
    """Build a Momentum optimizer with a piecewise-decayed learning rate.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at each boundary,
    where boundaries are ``args.step_epochs`` converted to iteration counts.
    L2 weight decay uses ``args.l2_decay``.
    """
    step = _steps_per_epoch(args)
    bd = [step * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, verbose=False)
    return paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))


def cosine_decay(args):
    """Build a Momentum optimizer with a cosine-annealed learning rate.

    The rate anneals from ``args.lr`` over ``T_max = steps_per_epoch *
    args.num_epochs`` iterations. L2 weight decay uses ``args.l2_decay``.
    """
    step = _steps_per_epoch(args)
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
    return paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))


def create_optimizer(args):
    """Return the optimizer selected by ``args.lr_strategy``.

    Supported strategies: ``"piecewise_decay"`` and ``"cosine_decay"``.

    Raises:
        ValueError: if ``args.lr_strategy`` is not a supported strategy.
            Previously ``None`` was returned, which only failed later at
            ``opt.minimize``.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError("Unsupported lr_strategy: {}".format(args.lr_strategy))


def compress(args):
    """Run quantization-aware training and export the quantized model.

    Steps:
      1. Build the quantization config, datasets and classification network.
      2. Insert fake quantize/dequantize ops into the eval and train programs
         with ``quant_aware``.
      3. Fine-tune for ``args.num_epochs`` epochs. After each epoch, evaluate,
         checkpoint to ``args.checkpoint_dir/<epoch>`` and keep the best
         top-1 checkpoint as ``args.checkpoint_dir/best_model``.
      4. Reload the best checkpoint, ``convert`` the eval program, evaluate it
         and save it as a float inference model under
         ``quantization_model_save_dir``.

    Args:
        args: namespace produced by ``parser.parse_args()``.

    Raises:
        ValueError: if ``args.model`` or ``args.data`` is unsupported, or if
            ``args.pretrained_model`` is set but does not exist.
    """
    # Validate the model name before any (possibly slow) dataset loading.
    if args.model not in model_list:
        raise ValueError("{} is not in lists: {}".format(args.model,
                                                         model_list))

    ############################################################################################################
    # 1. quantization configs
    ############################################################################################################
    quant_config = {
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # ops of name_scope in not_quant_pattern list, will not be quantized
        'not_quant_pattern': ['skip_quant'],
        # ops of type in quantize_op_types, will be quantized
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
    }

    if args.data == "mnist":
        train_dataset = paddle.vision.datasets.MNIST(mode='train')
        val_dataset = paddle.vision.datasets.MNIST(mode='test')
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]

    image = paddle.static.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
    avg_cost = paddle.mean(x=cost)
    acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
    acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    train_prog = paddle.static.default_main_program()
    # Cloned before opt.minimize() below, so the eval program carries no
    # optimizer ops.
    val_program = paddle.static.default_main_program().clone(for_test=True)

    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    ############################################################################################################
    # 2. quantization transform programs (training aware)
    #    Make some quantization transforms in the graph before training and testing.
    #    According to the weight and activation quantization type, the graph will be added
    #    some fake quantize operators and fake dequantize operators.
    ############################################################################################################
    val_program = quant_aware(
        val_program, place, quant_config, scope=None, for_test=True)
    compiled_train_prog = quant_aware(
        train_prog, place, quant_config, scope=None, for_test=False)
    # NOTE(review): minimize() runs after quant_aware() has already produced
    # compiled_train_prog from train_prog. Confirm that the backward/optimizer
    # ops reach the compiled program with this PaddleSlim version.
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    # Paddle LR schedulers are not advanced automatically in static-graph
    # mode. Keep a handle so train() can call step() once per iteration.
    lr_scheduler = getattr(opt, '_learning_rate', None)
    if not isinstance(lr_scheduler, paddle.optimizer.lr.LRScheduler):
        lr_scheduler = None

    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())

    if args.pretrained_model:
        # Accept either an existing path (e.g. a parameter directory) or a
        # checkpoint prefix whose '<prefix>.pdparams' file exists.
        if not (os.path.exists(args.pretrained_model) or
                os.path.exists(args.pretrained_model + '.pdparams')):
            raise ValueError("pretrained_model doesn't exist")
        paddle.static.load(train_prog, args.pretrained_model, exe)

    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()

    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=places,
        feed_list=[image, label],
        drop_last=True,
        batch_size=args.batch_size,
        use_shared_memory=False,
        shuffle=True,
        num_workers=1)
    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        feed_list=[image, label],
        drop_last=False,
        batch_size=args.batch_size,
        use_shared_memory=False,
        shuffle=False)

    def test(epoch, program):
        """Evaluate ``program`` on the validation loader.

        Returns the mean of the per-batch top-1 accuracies.
        """
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader()):
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch,
            np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def train(epoch, compiled_train_prog):
        """Run one training epoch over ``train_loader``."""
        for batch_id, data in enumerate(train_loader()):
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                compiled_train_prog,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            if lr_scheduler is not None:
                lr_scheduler.step()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = False
    build_strategy.fuse_all_reduce_ops = False
    build_strategy.sync_batch_norm = False
    exec_strategy = paddle.static.ExecutionStrategy()
    compiled_train_prog = compiled_train_prog.with_data_parallel(
        loss_name=avg_cost.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    ############################################################################################################
    # train loop
    ############################################################################################################
    best_acc1 = 0.0
    best_epoch = 0
    best_model_path = os.path.join(args.checkpoint_dir, 'best_model')
    for i in range(args.num_epochs):
        train(i, compiled_train_prog)
        acc1 = test(i, val_program)
        paddle.static.save(
            program=val_program,
            model_path=os.path.join(args.checkpoint_dir, str(i)))
        if acc1 > best_acc1:
            best_acc1 = acc1
            best_epoch = i
            paddle.static.save(program=val_program, model_path=best_model_path)
    # paddle.static.save treats model_path as a prefix and writes
    # '<prefix>.pdparams' (among other files). A 'best_model' path itself is
    # never created, so probe the .pdparams file instead.
    if os.path.exists(best_model_path + '.pdparams'):
        paddle.static.load(val_program, best_model_path, exe)
    ############################################################################################################
    # 3. Freeze the graph after training by adjusting the quantize
    #    operators' order for the inference.
    #    The dtype of float_program's weights is float32, but in int8 range.
    ############################################################################################################
    float_program, int8_program = convert(
        val_program, place, quant_config, scope=None, save_int8=True)
    _logger.info("eval best_model after convert")
    test(best_epoch, float_program)
    ############################################################################################################
    # 4. Save inference model
    ############################################################################################################
    model_path = os.path.join(quantization_model_save_dir, args.model,
                              'act_' + quant_config['activation_quantize_type']
                              + '_w_' + quant_config['weight_quantize_type'])
    float_path = os.path.join(model_path, 'float')
    os.makedirs(model_path, exist_ok=True)

    # These keyword arguments (dirname / feeded_var_names / target_vars /
    # main_program) follow the fluid.io.save_inference_model signature.
    # paddle.static.save_inference_model uses path_prefix / feed_vars /
    # fetch_vars instead and rejects them. File names are given relative to
    # dirname.
    fluid.io.save_inference_model(
        dirname=float_path,
        feeded_var_names=[image.name],
        target_vars=[out],
        executor=exe,
        main_program=float_program,
        model_filename='model',
        params_filename='params')


def main():
    """Command-line entry point: parse flags, print them and run compress()."""
    # compress() builds its network with paddle.static APIs, so static-graph
    # mode is enabled before anything else.
    paddle.enable_static()
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()