# distill.py — knowledge-distillation training script (PaddlePaddle static graph).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import logging
import paddle
import argparse
import functools
import numpy as np

# Prepend the parent directory so sibling modules (models, utility,
# imagenet_reader) are importable when the script runs as __main__.
# NOTE(review): os.path.dirname("__file__") operates on the literal
# string "__file__" and returns '', so this resolves relative to the
# current working directory; likely meant os.path.dirname(__file__).
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)

import models
from utility import add_arguments, print_arguments, _download, _decompress
from paddleslim.dist import merge, l2, soft_label

from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy
21 22 23 24 25 26 27
# Timestamped INFO-level console logging for training/eval progress.
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

# Command-line interface; add_arg is a shorthand over utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  256,                 "Minibatch size.")
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('save_inference',   bool, False,                "Whether to save inference model.")
add_arg('total_images',     int,  1281167,              "Training image number.")
add_arg('image_shape',      str,  "3,224,224",         "Input image size")
add_arg('lr',               float,  0.1,               "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy.")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate.")
add_arg('num_epochs',       int,  120,               "The number of total epochs.")
add_arg('data',             str, "imagenet",                 "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period',       int,  20,                 "Log period in batches.")
add_arg('model',            str,  "MobileNet",          "Set the network to use.")
add_arg('pretrained_model', str,  None,                "Whether to use pretrained model.")
add_arg('teacher_model',    str,  "ResNet50_vd",          "Set the teacher network to use.")
add_arg('teacher_pretrained_model', str,  "./ResNet50_vd_pretrained",                "Whether to use pretrained model.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# yapf: enable

# Every dunder-free name exported by `models` is a selectable architecture.
model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
    """Build a piecewise-decay LR schedule plus a Momentum optimizer.

    The learning rate starts at ``args.lr`` and is divided by 10 at each
    epoch listed in ``args.step_epochs``.

    Args:
        args: parsed CLI namespace (total_images, batch_size, step_epochs,
            lr, momentum_rate, l2_decay).

    Returns:
        Tuple of (lr scheduler, Momentum optimizer).
    """
    # Optimizer steps per epoch; a trailing partial batch rounds up.
    steps_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    # One LR value per interval: lr, lr/10, lr/100, ...
    values = [args.lr * (0.1**stage) for stage in range(len(boundaries) + 1)]
    scheduler = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=values, verbose=False)
    momentum_opt = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return scheduler, momentum_opt


def cosine_decay(args):
    """Build a cosine-annealing LR schedule plus a Momentum optimizer.

    The schedule anneals from ``args.lr`` over the full training run
    (steps per epoch times ``args.num_epochs``).

    Args:
        args: parsed CLI namespace (total_images, batch_size, num_epochs,
            lr, momentum_rate, l2_decay).

    Returns:
        Tuple of (lr scheduler, Momentum optimizer).
    """
    # Optimizer steps per epoch; a trailing partial batch rounds up.
    per_epoch_steps = int(math.ceil(float(args.total_images) / args.batch_size))
    total_steps = per_epoch_steps * args.num_epochs
    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=total_steps, verbose=False)
    momentum_opt = paddle.optimizer.Momentum(
        learning_rate=scheduler,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
    return scheduler, momentum_opt


def create_optimizer(args):
    """Dispatch to the LR-strategy factory named by ``args.lr_strategy``.

    Supported strategies: ``"piecewise_decay"`` and ``"cosine_decay"``.

    Returns:
        Tuple of (lr scheduler, optimizer) from the chosen factory.

    Raises:
        ValueError: if ``args.lr_strategy`` is not a supported strategy.
            Previously an unknown strategy fell through and implicitly
            returned ``None``, surfacing later as an opaque unpacking
            TypeError at the call site.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "Unsupported lr_strategy: {}. Expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))


def compress(args):
    """Train a student network with soft-label distillation from a teacher.

    Builds separate static-graph programs for the student and the teacher,
    merges the teacher into the student program (PaddleSlim ``merge``),
    adds a soft-label distillation loss on the final FC outputs, and runs
    distributed (fleet) training with per-epoch validation.

    Args:
        args: parsed CLI namespace (see the add_arg declarations above).

    Raises:
        ValueError: if ``args.data`` is neither 'cifar10' nor 'imagenet'.
    """
    # Collective (multi-GPU) distributed training via fleet.
    fleet.init(is_collective=True)

    # Dataset selection fixes both the class count and the input shape.
    if args.data == "cifar10":
        train_dataset = paddle.vision.datasets.Cifar10(mode='train')
        val_dataset = paddle.vision.datasets.Cifar10(mode='test')
        class_dim = 10
        image_shape = "3,32,32"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    # "C,H,W" string -> [C, H, W] ints.
    image_shape = [int(m) for m in image_shape.split(",")]

    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    # Student program + its parameter-initialization startup program.
    student_program = paddle.static.Program()
    s_startup = paddle.static.Program()
    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    place = places[0]
    # NOTE(review): devices_num is computed but never used below.
    if args.use_gpu:
        devices_num = paddle.framework.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))
    with paddle.static.program_guard(student_program, s_startup):
        image = paddle.static.data(
            name='image', shape=[None] + image_shape, dtype='float32')
        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
        # Distributed sampler shards the training set across ranks.
        sampler = paddle.io.DistributedBatchSampler(
            train_dataset,
            shuffle=False,
            drop_last=True,
            batch_size=args.batch_size)
        train_loader = paddle.io.DataLoader(
            train_dataset,
            places=places,
            feed_list=[image, label],
            batch_sampler=sampler,
            return_list=False,
            use_shared_memory=False,
            num_workers=4)
        # Validation runs on a single place, unsharded and unshuffled.
        valid_loader = paddle.io.DataLoader(
            val_dataset,
            places=place,
            feed_list=[image, label],
            drop_last=False,
            return_list=False,
            use_shared_memory=False,
            batch_size=args.batch_size,
            shuffle=False)
        # Student model definition: classification loss + top-1/top-5 metrics.
        model = models.__dict__[args.model]()
        out = model.net(input=image, class_dim=class_dim)
        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
        avg_cost = paddle.mean(x=cost)
        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
        acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

    # Clone for evaluation BEFORE the distillation loss/optimizer are added.
    val_program = student_program.clone(for_test=True)
    exe = paddle.static.Executor(place)

    teacher_model = models.__dict__[args.teacher_model]()
    # Teacher program is built separately; unique_name.guard keeps its
    # variable names from colliding with the student's.
    teacher_program = paddle.static.Program()
    t_startup = paddle.static.Program()
    with paddle.static.program_guard(teacher_program, t_startup):
        with paddle.utils.unique_name.guard():
            image = paddle.static.data(
                name='image', shape=[None] + image_shape, dtype='float32')
            predict = teacher_model.net(image, class_dim=class_dim)

    exe.run(t_startup)
    # Fetch the pretrained teacher weights on first run.
    if not os.path.exists(args.teacher_pretrained_model):
        _download(
            'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
            '.')
        _decompress('./ResNet50_vd_pretrained.tar')
    assert args.teacher_pretrained_model and os.path.exists(
        args.teacher_pretrained_model
    ), "teacher_pretrained_model should be set when teacher_model is not None."

    def if_exist(var):
        # Keep only variables with a checkpoint file; the ImageNet FC head
        # is skipped for cifar10 since its class count differs.
        exist = os.path.exists(
            os.path.join(args.teacher_pretrained_model, var.name))
        if args.data == "cifar10" and (var.name == 'fc_0.w_0' or
                                       var.name == 'fc_0.b_0'):
            exist = False
        return exist

    # NOTE(review): if_exist is defined but not passed to the loader here.
    paddle.static.load(teacher_program, args.teacher_pretrained_model, exe)

    # Map the teacher's input variable onto the student's; merge() splices
    # the (frozen) teacher graph into student_program.
    data_name_map = {'image': 'image'}
    merge(teacher_program, student_program, data_name_map, place)

    build_strategy = paddle.static.BuildStrategy()
    dist_strategy = DistributedStrategy()
    dist_strategy.build_strategy = build_strategy

    with paddle.static.program_guard(student_program, s_startup):
        # Soft-label distillation between teacher and student FC outputs;
        # merge() prefixed the teacher's tensor names with "teacher_".
        distill_loss = soft_label("teacher_fc_0.tmp_0", "fc_0.tmp_0",
                                  student_program)
        loss = avg_cost + distill_loss
        lr, opt = create_optimizer(args)
        opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
        opt.minimize(loss)
    exe.run(s_startup)
    parallel_main = student_program

    for epoch_id in range(args.num_epochs):
        # ---- training ----
        for step_id, data in enumerate(train_loader):
            loss_1, loss_2, loss_3 = exe.run(
                parallel_main,
                feed=data,
                fetch_list=[loss.name, avg_cost.name, distill_loss.name])
            if step_id % args.log_period == 0:
                _logger.info(
                    "train_epoch {} step {} lr {:.6f}, loss {:.6f}, class loss {:.6f}, distill loss {:.6f}".
                    format(epoch_id, step_id,
                           lr.get_lr(), loss_1, loss_2, loss_3))
            # Advance the LR schedule once per optimizer step.
            lr.step()
        # ---- validation ----
        val_acc1s = []
        val_acc5s = []
        for step_id, data in enumerate(valid_loader):
            val_loss, val_acc1, val_acc5 = exe.run(
                val_program,
                data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            val_acc1s.append(val_acc1)
            val_acc5s.append(val_acc5)
            if step_id % args.log_period == 0:
                _logger.info(
                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
                    format(epoch_id, step_id, val_loss, val_acc1, val_acc5))
        # Optionally export an inference model snapshot each epoch.
        if args.save_inference:
            paddle.static.save_inference_model(
                os.path.join("./saved_models", str(epoch_id)), [image], [out],
                exe,
                program=student_program)
        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))


def main():
    """Parse CLI arguments, echo them, and run distillation training."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


if __name__ == '__main__':
    # This script uses the static-graph API throughout; switch modes
    # before any program construction happens.
    paddle.enable_static()
    main()