distill.py 9.9 KB
Newer Older
B
baiyfbupt 已提交
1 2 3 4 5 6 7 8 9 10 11 12
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import logging
import paddle
import argparse
import functools
import numpy as np
B
Bai Yifan 已提交
13
# Make the parent directory importable (presumably where `models` and
# `utility` live). Resolve it from this file's real location; the previous
# literal string "__file__" always produced ".", i.e. a CWD-relative "..".
sys.path[0] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), os.path.pardir)
B
baiyfbupt 已提交
14
import models
15
from utility import add_arguments, print_arguments, _download, _decompress
16
from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
B
baiyfbupt 已提交
17 18 19 20 21 22 23 24

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

# Command-line flags. `add_arg` registers one flag on `parser` through the
# project's `add_arguments` helper as (name, type, default, help).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64,                 "Minibatch size.")
add_arg('use_gpu',          bool, True,               "Whether to use GPU or not.")
add_arg('save_inference',   bool, False,              "Whether to save inference model.")
add_arg('total_images',     int,  1281167,            "Number of training images; used to compute iterations per epoch for the LR schedule.")
add_arg('image_shape',      str,  "3,224,224",        "Input image size (not used: the shape is derived from --data).")
add_arg('lr',               float,  0.1,              "The initial learning rate of the student model.")
add_arg('lr_strategy',      str,  "piecewise_decay",  "The learning rate decay strategy: 'piecewise_decay' or 'cosine_decay'.")
add_arg('l2_decay',         float,  3e-5,             "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,              "The value of momentum_rate.")
add_arg('num_epochs',       int,  120,                "The number of total epochs.")
add_arg('data',             str, "imagenet",          "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',       int,  20,                 "Log period in batches.")
add_arg('model',            str,  "MobileNet",        "Set the network to use.")
add_arg('pretrained_model', str,  None,               "Path to a pretrained student model (not loaded by this script).")
add_arg('teacher_model',    str,  "ResNet50_vd",      "Set the teacher network to use.")
add_arg('teacher_pretrained_model', str,  "./ResNet50_vd_pretrained", "Directory holding the teacher's pretrained parameters.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="Epochs at which piecewise_decay multiplies the learning rate by 0.1.")
# yapf: enable

# Public names exported by the `models` package; valid values for --model.
model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
    """Create a piecewise-constant LR schedule and its Momentum optimizer.

    The learning rate begins at ``args.lr`` and is scaled by 0.1 at every
    epoch listed in ``args.step_epochs``. Epoch boundaries are converted to
    iteration counts from ``args.total_images``, ``args.batch_size`` and the
    device count (CUDA devices when ``args.use_gpu``, otherwise the
    ``CPU_NUM`` environment variable, defaulting to 1).

    Args:
        args: parsed command-line namespace.

    Returns:
        tuple: ``(learning_rate, optimizer)`` -- the
        ``paddle.optimizer.lr.PiecewiseDecay`` scheduler and a
        ``paddle.optimizer.Momentum`` optimizer driven by it, with L2 weight
        decay ``args.l2_decay``.
    """
    if not args.use_gpu:
        devices_num = int(os.environ.get('CPU_NUM', 1))
    else:
        devices_num = paddle.fluid.core.get_cuda_device_count()

    # Divided by the device count, presumably because each iteration feeds
    # one batch to every device.
    batches_per_epoch = math.ceil(float(args.total_images) / args.batch_size)
    iters_per_epoch = int(batches_per_epoch / devices_num)

    boundaries = [iters_per_epoch * epoch for epoch in args.step_epochs]
    values = [args.lr * 0.1**k for k in range(len(boundaries) + 1)]

    scheduler = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=values, verbose=False)
    momentum_opt = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return scheduler, momentum_opt
B
baiyfbupt 已提交
63 64 65


def cosine_decay(args):
    """Create a cosine-annealing LR schedule and its Momentum optimizer.

    The learning rate follows ``CosineAnnealingDecay`` from ``args.lr`` with
    ``T_max = iterations_per_epoch * args.num_epochs``. Iterations per epoch
    are derived from ``args.total_images``, ``args.batch_size`` and the device
    count (CUDA devices when ``args.use_gpu``, otherwise the ``CPU_NUM``
    environment variable, defaulting to 1).

    Args:
        args: parsed command-line namespace.

    Returns:
        tuple: ``(learning_rate, optimizer)`` -- the
        ``paddle.optimizer.lr.CosineAnnealingDecay`` scheduler and a
        ``paddle.optimizer.Momentum`` optimizer driven by it, with L2 weight
        decay ``args.l2_decay``.
    """
    devices_num = (paddle.fluid.core.get_cuda_device_count()
                   if args.use_gpu else int(os.environ.get('CPU_NUM', 1)))
    iters_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size) / devices_num)
    total_iters = iters_per_epoch * args.num_epochs

    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=total_iters, verbose=False)
    return scheduler, paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
B
baiyfbupt 已提交
79 80 81 82 83 84 85 86 87 88


def create_optimizer(args):
    """Build the LR scheduler and optimizer selected by ``args.lr_strategy``.

    Args:
        args: parsed command-line namespace; ``args.lr_strategy`` must be
            "piecewise_decay" or "cosine_decay".

    Returns:
        tuple: ``(learning_rate, optimizer)`` as produced by
        ``piecewise_decay`` or ``cosine_decay``.

    Raises:
        ValueError: if ``args.lr_strategy`` is not one of the supported
            strategies. Raising here replaces an implicit ``None`` return
            that only failed later, when the caller unpacked the result.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "Unsupported lr_strategy: {!r}; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))


def _select_dataset(args):
    """Pick readers, class count and input shape for ``args.data``.

    Returns:
        tuple: ``(train_reader, val_reader, class_dim, image_shape)``. The
        readers are not yet batched; ``image_shape`` is a list of ints parsed
        from a fixed per-dataset string such as "3,32,32" (``args.image_shape``
        is not consulted).

    Raises:
        ValueError: if ``args.data`` is neither "cifar10" nor "imagenet".
    """
    if args.data == "cifar10":
        import paddle.dataset.cifar as reader
        train_reader = reader.train10()
        val_reader = reader.test10()
        class_dim = 10
        image_shape = "3,32,32"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]
    return train_reader, val_reader, class_dim, image_shape


def _load_teacher_weights(args, teacher_program, exe):
    """Load pretrained parameters into ``teacher_program``'s variables.

    If ``args.teacher_pretrained_model`` does not exist, the ResNet50_vd
    checkpoint tarball is downloaded and unpacked into the current directory
    first. For cifar10 the ``fc_0.w_0`` / ``fc_0.b_0`` files are skipped, so the
    teacher's 10-class head keeps its startup initialisation.

    Raises:
        AssertionError: if the checkpoint path is still missing.
    """
    if not os.path.exists(args.teacher_pretrained_model):
        # NOTE(review): this always fetches ResNet50_vd, regardless of
        # --teacher_model; it only matches the default teacher.
        _download(
            'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
            '.')
        _decompress('./ResNet50_vd_pretrained.tar')
    assert args.teacher_pretrained_model and os.path.exists(
        args.teacher_pretrained_model
    ), "teacher_pretrained_model should be set when teacher_model is not None."

    def if_exist(var):
        # Load only variables that have a saved file in the checkpoint dir.
        exist = os.path.exists(
            os.path.join(args.teacher_pretrained_model, var.name))
        if args.data == "cifar10" and (var.name == 'fc_0.w_0' or
                                       var.name == 'fc_0.b_0'):
            # Presumably the checkpoint's classifier has a different class
            # count than the 10-class cifar10 head, so it is not loaded.
            exist = False
        return exist

    # Pass the filtered list explicitly. Per the paddle.static.load docs,
    # without var_list every program variable found in the directory is
    # loaded, which would ignore the cifar10 exclusion above.
    var_list = [var for var in teacher_program.list_vars() if if_exist(var)]
    paddle.static.load(
        teacher_program, args.teacher_pretrained_model, exe, var_list=var_list)


def compress(args):
    """Train ``args.model`` by distilling from ``args.teacher_model``.

    Steps:
      1. Build the student program: cross-entropy on labels and top-1/top-5
         accuracy.
      2. Build the teacher program and load its pretrained weights.
      3. Merge the teacher into the student program.
      4. Minimise ``class loss + soft-label distillation loss`` for
         ``args.num_epochs`` epochs, evaluating on the validation set after
         each epoch.

    Args:
        args: parsed command-line namespace (see the ``add_arg`` calls).

    Raises:
        ValueError: for an unsupported ``args.data``.
        AssertionError: if ``args.model`` is not in ``model_list`` or the
            teacher checkpoint path is missing.
    """
    train_reader, val_reader, class_dim, image_shape = _select_dataset(args)

    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    student_program = paddle.static.Program()
    s_startup = paddle.static.Program()

    with paddle.static.program_guard(student_program, s_startup):
        with paddle.fluid.unique_name.guard():
            image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            label = paddle.static.data(
                name='label', shape=[None, 1], dtype='int64')
            train_loader = paddle.io.DataLoader.from_generator(
                feed_list=[image, label],
                capacity=64,
                use_double_buffer=True,
                iterable=True)
            valid_loader = paddle.io.DataLoader.from_generator(
                feed_list=[image, label],
                capacity=64,
                use_double_buffer=True,
                iterable=True)
            # student model definition
            model = models.__dict__[args.model]()
            out = model.net(input=image, class_dim=class_dim)
            cost = paddle.nn.functional.loss.cross_entropy(
                input=out, label=label)
            avg_cost = paddle.mean(x=cost)
            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)
    val_reader = paddle.batch(
        val_reader, batch_size=args.batch_size, drop_last=True)
    # Cloned before merge() and minimize(), so evaluation runs only the
    # student's forward pass in test mode.
    val_program = student_program.clone(for_test=True)

    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    place = places[0]
    exe = paddle.static.Executor(place)

    # Training feeds every place; validation runs on the first one only.
    train_loader.set_sample_list_generator(train_reader, places)
    valid_loader.set_sample_list_generator(val_reader, place)

    # Teacher program, built under its own unique-name scope. Its input keeps
    # the variable name 'image' so data_name_map can pair it with the
    # student's input.
    teacher_model = models.__dict__[args.teacher_model]()
    teacher_program = paddle.static.Program()
    t_startup = paddle.static.Program()
    with paddle.static.program_guard(teacher_program, t_startup):
        with paddle.fluid.unique_name.guard():
            teacher_image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            teacher_model.net(teacher_image, class_dim=class_dim)

    exe.run(t_startup)
    _load_teacher_weights(args, teacher_program, exe)

    data_name_map = {'image': 'image'}
    merge(teacher_program, student_program, data_name_map, place)

    with paddle.static.program_guard(student_program, s_startup):
        # After merge() the teacher's variables are referenced with a
        # 'teacher_' prefix.
        distill_loss = soft_label_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                       student_program)
        loss = avg_cost + distill_loss
        lr, opt = create_optimizer(args)
        opt.minimize(loss)
    exe.run(s_startup)
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.fuse_all_reduce_ops = False
    parallel_main = paddle.static.CompiledProgram(
        student_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)

    for epoch_id in range(args.num_epochs):
        for step_id, data in enumerate(train_loader):
            loss_1, loss_2, loss_3 = exe.run(
                parallel_main,
                feed=data,
                fetch_list=[loss.name, avg_cost.name, distill_loss.name])
            if step_id % args.log_period == 0:
                _logger.info(
                    "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
                    format(epoch_id, step_id,
                           lr.get_lr(), loss_1[0], loss_2[0], loss_3[0]))
            # The scheduler is advanced once per training iteration.
            lr.step()
        val_acc1s = []
        val_acc5s = []
        for step_id, data in enumerate(valid_loader):
            val_loss, val_acc1, val_acc5 = exe.run(
                val_program,
                data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            val_acc1s.append(val_acc1)
            val_acc5s.append(val_acc5)
            if step_id % args.log_period == 0:
                _logger.info(
                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
                    format(epoch_id, step_id, val_loss[0], val_acc1[0],
                           val_acc5[0]))
        if args.save_inference:
            # Export the student's test-mode graph (image -> out) as
            # ./saved_models/<epoch>/model.*. The feed/fetch Variables are
            # taken from val_program so they belong to the program being
            # saved. The 2.x API needs Variables, not names, and takes the
            # program as a keyword.
            infer_block = val_program.global_block()
            paddle.static.save_inference_model(
                os.path.join("./saved_models", str(epoch_id), "model"),
                [infer_block.var(image.name)], [infer_block.var(out.name)],
                exe,
                program=val_program)
        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))


def main():
    """Entry point: parse the CLI, echo the settings, then run distillation."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


if __name__ == '__main__':
    # Everything in this script builds paddle.static Programs, so static-graph
    # mode is switched on before any of them is constructed.
    paddle.enable_static()
    main()