from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import logging
import paddle
import argparse
import functools
import numpy as np
# Make the parent demo directory importable; it provides `models` and `utility`.
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
import models
from utility import add_arguments, print_arguments, _download, _decompress
from paddleslim.dist import merge, l2, soft_label

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  256,                 "Minibatch size.")
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('save_inference',   bool, False,                "Whether to save inference model.")
add_arg('total_images',     int,  1281167,              "Training image number.")
add_arg('image_shape',      str,  "3,224,224",         "Input image size.")
add_arg('lr',               float,  0.1,               "The learning rate used to train the student model.")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy.")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,  120,               "The number of total epochs.")
add_arg('data',             str, "imagenet",                 "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',       int,  20,                 "Log period in batches.")
add_arg('model',            str,  "MobileNet",          "Set the network to use.")
add_arg('pretrained_model', str,  None,                "Path of the pretrained student model, if any.")
add_arg('teacher_model',    str,  "ResNet50_vd",          "Set the teacher network to use.")
add_arg('teacher_pretrained_model', str,  "./ResNet50_vd_pretrained",                "Path of the pretrained teacher model.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# yapf: enable
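
# A typical invocation (all flags are defined above; the default ResNet50_vd
# teacher checkpoint is downloaded automatically when the path given by
# --teacher_pretrained_model does not exist yet):
#   python distill.py --data imagenet --model MobileNet --teacher_model ResNet50_vd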

model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
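    # Drop the learning rate by 10x at every epoch listed in --step_epochs,
    # converting epochs to iteration boundaries for PiecewiseDecay.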
    step = int(math.ceil(float(args.total_images) / args.batch_size))
    bd = [step * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer


def cosine_decay(args):
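    # Anneal the learning rate from --lr towards zero over the whole run
    # with a cosine schedule.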
    step = int(math.ceil(float(args.total_images) / args.batch_size))
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return learning_rate, optimizer


def create_optimizer(args):
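    # Return the (lr scheduler, optimizer) pair selected by --lr_strategy.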
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)


def compress(args):
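    # Pick the datasets and class count for the chosen --data source.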
    if args.data == "cifar10":
        train_dataset = paddle.vision.datasets.Cifar10(mode='train')
        val_dataset = paddle.vision.datasets.Cifar10(mode='test')
        class_dim = 10
        image_shape = "3,32,32"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]

    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
                                                                     model_list)
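    # Build the student program: input placeholders, data loaders, the student
    # network and its classification loss/metrics.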
    student_program = paddle.static.Program()
    s_startup = paddle.static.Program()
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    place = places[0]
    if args.use_gpu:
        devices_num = paddle.framework.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))
    with paddle.static.program_guard(student_program, s_startup):
        with paddle.utils.unique_name.guard():
            image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            label = paddle.static.data(
                name='label', shape=[None, 1], dtype='int64')
            train_loader = paddle.io.DataLoader(
                train_dataset,
                places=places,
                feed_list=[image, label],
                drop_last=True,
                batch_size=int(args.batch_size / devices_num),
                return_list=False,
                shuffle=True,
                use_shared_memory=True,
                num_workers=4)
            valid_loader = paddle.io.DataLoader(
                val_dataset,
                places=place,
                feed_list=[image, label],
                drop_last=False,
                return_list=False,
                use_shared_memory=True,
                batch_size=args.batch_size,
                shuffle=False)
            # model definition
            model = models.__dict__[args.model]()
            out = model.net(input=image, class_dim=class_dim)
            cost = paddle.nn.functional.loss.cross_entropy(
                input=out, label=label)
            avg_cost = paddle.mean(x=cost)
            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    val_program = student_program.clone(for_test=True)
    exe = paddle.static.Executor(place)

    teacher_model = models.__dict__[args.teacher_model]()
    # define teacher program
    teacher_program = paddle.static.Program()
    t_startup = paddle.static.Program()
    with paddle.static.program_guard(teacher_program, t_startup):
        with paddle.utils.unique_name.guard():
            image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            predict = teacher_model.net(image, class_dim=class_dim)

    exe.run(t_startup)
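    # Download the default ResNet50_vd teacher checkpoint if it is not present.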
    if not os.path.exists(args.teacher_pretrained_model):
        _download(
            'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
            '.')
        _decompress('./ResNet50_vd_pretrained.tar')
    assert args.teacher_pretrained_model and os.path.exists(
        args.teacher_pretrained_model
    ), "teacher_pretrained_model should be set when teacher_model is not None."

    def if_exist(var):
        exist = os.path.exists(
            os.path.join(args.teacher_pretrained_model, var.name))
        if args.data == "cifar10" and (var.name == 'fc_0.w_0' or
                                       var.name == 'fc_0.b_0'):
            exist = False
        return exist

    paddle.static.load(teacher_program, args.teacher_pretrained_model, exe)
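
    # Merge the teacher graph into the student program; teacher variables are
    # renamed with a "teacher_" prefix and the two graphs share the 'image' input.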
    data_name_map = {'image': 'image'}
    merge(teacher_program, student_program, data_name_map, place)

    with paddle.static.program_guard(student_program, s_startup):
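        # Distillation loss: soft_label() matches the student's predictions
        # ("fc_0.tmp_0") to the teacher's ("teacher_fc_0.tmp_0") in the merged program.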
        distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                  student_program)
        loss = avg_cost + distill_loss
        lr, opt = create_optimizer(args)
        opt.minimize(loss)
    exe.run(s_startup)
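    # Compile the student program for data-parallel execution across the visible devices.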
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.fuse_all_reduce_ops = False
    parallel_main = paddle.static.CompiledProgram(
        student_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)

    for epoch_id in range(args.num_epochs):
        for step_id, data in enumerate(train_loader):
            loss_1, loss_2, loss_3 = exe.run(
                parallel_main,
                feed=data,
                fetch_list=[loss.name, avg_cost.name, distill_loss.name])
            if step_id % args.log_period == 0:
                _logger.info(
                    "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
                    format(epoch_id, step_id,
                           lr.get_lr(), loss_1[0], loss_2[0], loss_3[0]))
            lr.step()
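
        # Evaluate the student on the validation set at the end of each epoch.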
        val_acc1s = []
        val_acc5s = []
        for step_id, data in enumerate(valid_loader):
            val_loss, val_acc1, val_acc5 = exe.run(
                val_program,
                feed=data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            val_acc1s.append(val_acc1)
            val_acc5s.append(val_acc5)
            if step_id % args.log_period == 0:
                _logger.info(
                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
                    format(epoch_id, step_id, val_loss[0], val_acc1[0],
                           val_acc5[0]))
        if args.save_inference:
            paddle.static.save_inference_model(
                os.path.join("./saved_models", str(epoch_id)), [image], [out],
                exe,
                program=student_program)
        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    paddle.enable_static()
    main()