from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import logging
import paddle
import argparse
import functools
import numpy as np
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)  # make the parent demo directory (models, utility) importable
import models
from utility import add_arguments, print_arguments, _download, _decompress
from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64,                 "Minibatch size.")
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('save_inference',   bool, False,                "Whether to save inference model.")
add_arg('total_images',     int,  1281167,              "Training image number.")
add_arg('image_shape',      str,  "3,224,224",         "Input image size")
add_arg('lr',               float,  0.1,               "The initial learning rate for training the student model.")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy.")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,  120,               "The number of total epochs.")
add_arg('data',             str, "imagenet",                 "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',       int,  20,                 "Log period in batches.")
add_arg('model',            str,  "MobileNet",          "Set the network to use.")
add_arg('pretrained_model', str,  None,                "The path of the student's pretrained model.")
add_arg('teacher_model',    str,  "ResNet50_vd",          "Set the teacher network to use.")
add_arg('teacher_pretrained_model', str,  "./ResNet50_vd_pretrained",                "The path of the teacher's pretrained model.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# yapf: enable

model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
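    # Piecewise decay: divide the LR by 10 at each epoch listed in --step_epochs; steps per epoch are scaled by device count.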
    if args.use_gpu:
        devices_num = paddle.fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))
    step = int(
        math.ceil(float(args.total_images) / args.batch_size) / devices_num)
    bd = [step * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer


def cosine_decay(args):
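    # Cosine annealing of the LR over the full number of training steps, with momentum and L2 weight decay.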
    if args.use_gpu:
        devices_num = paddle.fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))
    step = int(
        math.ceil(float(args.total_images) / args.batch_size) / devices_num)
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer


def create_optimizer(args):
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)


def compress(args):
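    # Select the dataset: CIFAR-10 (10 classes, 32x32 images) or ImageNet (1000 classes, 224x224 images).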
    if args.data == "cifar10":
        train_dataset = paddle.vision.datasets.Cifar10(mode='train')
        val_dataset = paddle.vision.datasets.Cifar10(mode='test')
        class_dim = 10
        image_shape = "3,32,32"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]

    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
    student_program = paddle.static.Program()
    s_startup = paddle.static.Program()
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    place = places[0]

    with paddle.static.program_guard(student_program, s_startup):
        with paddle.fluid.unique_name.guard():
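            # Student graph: input placeholders, data loaders, the student network, and its classification loss/metrics.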
            image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            label = paddle.static.data(
                name='label', shape=[None, 1], dtype='int64')
            train_loader = paddle.io.DataLoader(
                train_dataset,
                places=places,
                feed_list=[image, label],
                drop_last=True,
                batch_size=args.batch_size,
                shuffle=True,
                use_shared_memory=False,
                num_workers=1)
            valid_loader = paddle.io.DataLoader(
                val_dataset,
                places=place,
                feed_list=[image, label],
                drop_last=False,
                use_shared_memory=False,
                batch_size=args.batch_size,
                shuffle=False)
            # model definition
            model = models.__dict__[args.model]()
            out = model.net(input=image, class_dim=class_dim)
            cost = paddle.nn.functional.loss.cross_entropy(
                input=out, label=label)
            avg_cost = paddle.mean(x=cost)
            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    val_program = student_program.clone(for_test=True)
    exe = paddle.static.Executor(place)

    teacher_model = models.__dict__[args.teacher_model]()
    # define teacher program
    teacher_program = paddle.static.Program()
    t_startup = paddle.static.Program()
    with paddle.static.program_guard(teacher_program, t_startup):
        with paddle.fluid.unique_name.guard():
            image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            predict = teacher_model.net(image, class_dim=class_dim)

    exe.run(t_startup)
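
    # Download and extract the ImageNet-pretrained ResNet50_vd teacher weights if they are not present locally.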
    if not os.path.exists(args.teacher_pretrained_model):
        _download(
            'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
            '.')
        _decompress('./ResNet50_vd_pretrained.tar')
    assert args.teacher_pretrained_model and os.path.exists(
        args.teacher_pretrained_model
    ), "teacher_pretrained_model should be set when teacher_model is not None."

    def if_exist(var):
        exist = os.path.exists(
            os.path.join(args.teacher_pretrained_model, var.name))
        if args.data == "cifar10" and (var.name == 'fc_0.w_0' or
                                       var.name == 'fc_0.b_0'):
            exist = False
        return exist

    # Load the teacher weights per-variable, using if_exist to skip the mismatched FC head on CIFAR-10.
    paddle.fluid.io.load_vars(
        exe, args.teacher_pretrained_model,
        main_program=teacher_program, predicate=if_exist)

    data_name_map = {'image': 'image'}
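    # merge() fuses the teacher program into student_program, prefixing teacher variables with "teacher_" and feeding both nets from the same 'image' input.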
    merge(teacher_program, student_program, data_name_map, place)

    with paddle.static.program_guard(student_program, s_startup):
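        # Distillation objective: student cross-entropy plus a soft-label loss between the teacher and student logits.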
        distill_loss = soft_label_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                       student_program)
        loss = avg_cost + distill_loss
        lr, opt = create_optimizer(args)
        opt.minimize(loss)
    exe.run(s_startup)
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.fuse_all_reduce_ops = False
    parallel_main = paddle.static.CompiledProgram(
        student_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)

    for epoch_id in range(args.num_epochs):
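        # Train for one epoch, fetching the total, classification, and distillation losses each step.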
        for step_id, data in enumerate(train_loader):
            loss_1, loss_2, loss_3 = exe.run(
                parallel_main,
                feed=data,
                fetch_list=[loss.name, avg_cost.name, distill_loss.name])
            if step_id % args.log_period == 0:
                _logger.info(
                    "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
                    format(epoch_id, step_id,
                           lr.get_lr(), loss_1[0], loss_2[0], loss_3[0]))
            lr.step()
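        # Evaluate the student on the validation set at the end of every epoch.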
        val_acc1s = []
        val_acc5s = []
        for step_id, data in enumerate(valid_loader):
            val_loss, val_acc1, val_acc5 = exe.run(
                val_program,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            val_acc1s.append(val_acc1)
            val_acc5s.append(val_acc5)
            if step_id % args.log_period == 0:
                _logger.info(
                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
                    format(epoch_id, step_id, val_loss[0], val_acc1[0],
                           val_acc5[0]))
        if args.save_inference:
            paddle.static.save_inference_model(
                os.path.join("./saved_models", str(epoch_id)), [image], [out],
                exe,
                program=student_program)
        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
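    # This demo uses Paddle's static-graph API, so switch out of the default dynamic mode first.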
    paddle.enable_static()
    main()