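"""Fine-tune a MobileNetV1 model with PaddleSlim's UnstructuredPruner
(magnitude-based unstructured pruning) on CIFAR-10 or ImageNet."""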
import paddle
import os
import sys
import argparse
import numpy as np
from paddleslim import UnstructuredPruner
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
import paddle.nn.functional as F
import functools
from paddle.vision.models import mobilenet_v1
import time
import logging
from paddleslim.common import get_logger
import paddle.distributed as dist
from paddle.distributed import ParallelEnv

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,  64 * 4,                 "Minibatch size.")
add_arg('batch_size_for_validation',       int,  64,                 "Minibatch size for validation.")
add_arg('use_gpu',          bool, True,                "Whether to use GPU or not.")
add_arg('lr',               float,  0.1,               "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy.")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter.")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate.")
add_arg('ratio',            float,  0.3,               "The ratio of parameters to set to zero: the smallest-magnitude weights, up to this ratio, are pruned.")
add_arg('pruning_mode',     str,  'ratio',             "The pruning mode: prune by 'ratio' or by 'threshold'.")
add_arg('threshold',        float,  0.001,             "The magnitude threshold below which weights are set to zero.")
add_arg('num_epochs',       int,  120,               "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data',             str, "cifar10",                 "Which data to use. 'cifar10' or 'imagenet'.")
add_arg('log_period',       int, 100,                 "Log period in batches.")
add_arg('test_period',      int, 1,                 "Test period in epochs.")
add_arg('pretrained_model', str, None,              "The pretrained model to load. Default: None.")
add_arg('model_path',       str, "./models",         "The path to save model.")
add_arg('model_period',     int, 10,             "The period to save the model, in epochs.")
add_arg('resume_epoch',     int, -1,             "The epoch to resume training.")
add_arg('num_workers',      int, 16,             "The number of workers used when loading the dataset.")
# yapf: enable


def piecewise_decay(args, step_per_epoch, model):
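    # Drop the learning rate by 10x at each epoch in args.step_epochs;
    # boundaries are converted from epochs to optimizer steps.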
    bd = [step_per_epoch * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
        parameters=model.parameters())
    return optimizer, learning_rate


def cosine_decay(args, step_per_epoch, model):
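    # Anneal the learning rate from args.lr towards zero along a cosine
    # curve spanning the entire training run.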
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
        parameters=model.parameters())
    return optimizer, learning_rate


def create_optimizer(args, step_per_epoch, model):
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args, step_per_epoch, model)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args, step_per_epoch, model)
    else:
        raise ValueError(
            "lr_strategy {} is not supported.".format(args.lr_strategy))


def compress(args):
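    # Pick the device; with more than one trainer, switch to data-parallel
    # training.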
    if args.use_gpu:
        place = paddle.set_device('gpu')
    else:
        place = paddle.set_device('cpu')

    trainer_num = paddle.distributed.get_world_size()
    use_data_parallel = trainer_num != 1
    if use_data_parallel:
        dist.init_parallel_env()

    if args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
    elif args.data == "cifar10":
        normalize = T.Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW')
        transform = T.Compose([T.Transpose(), normalize])
        train_dataset = paddle.vision.datasets.Cifar10(
            mode='train', backend='cv2', transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(
            mode='test', backend='cv2', transform=transform)
        class_dim = 10
    else:
        raise ValueError("{} is not supported.".format(args.data))

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=place,
        batch_sampler=batch_sampler,
        return_list=True,
        num_workers=args.num_workers,
        use_shared_memory=True)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        drop_last=False,
        return_list=True,
        batch_size=args.batch_size_for_validation,
        shuffle=False,
        use_shared_memory=True)
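    # Optimizer steps per epoch on each trainer; the batch sampler shards
    # the dataset across ParallelEnv().nranks trainers.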
    step_per_epoch = int(
        np.ceil(len(train_dataset) / args.batch_size / ParallelEnv().nranks))
    # model definition
    model = mobilenet_v1(num_classes=class_dim, pretrained=True)
    if ParallelEnv().nranks > 1:
        model = paddle.DataParallel(model)

    if args.pretrained_model is not None:
        model.set_state_dict(paddle.load(args.pretrained_model))

    opt, learning_rate = create_optimizer(args, step_per_epoch, model)

    def test(epoch):
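        # Evaluate top-1/top-5 accuracy over the validation set.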
        model.eval()
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1.numpy()),
                           np.mean(acc_top5.numpy()), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1.numpy()))
            acc_top5_ns.append(np.mean(acc_top5.numpy()))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch, np.mean(acc_top1_ns), np.mean(acc_top5_ns)))

    def train(epoch):
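        # Train one epoch, timing data loading (reader) and computation
        # separately for the throughput (ips) log below.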
        model.train()
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()

        for batch_id, data in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            train_start = time.time()
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)

            loss.backward()
            opt.step()
            learning_rate.step()
            opt.clear_grad()
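            # Update the pruning masks after the weight update so that
            # pruned weights stay at zero during training.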
            pruner.step()
            train_run_cost += time.time() - train_start
            total_samples += args.batch_size

            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch, batch_id,
                           opt.get_lr(),
                           np.mean(loss.numpy()),
                           np.mean(acc_top1.numpy()),
                           np.mean(acc_top5.numpy()), train_reader_cost /
                           args.log_period, (train_reader_cost + train_run_cost
                                             ) / args.log_period, total_samples
                           / args.log_period, total_samples / (
                               train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            reader_start = time.time()

    pruner = UnstructuredPruner(
        model,
        mode=args.pruning_mode,
        ratio=args.ratio,
        threshold=args.threshold)
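    # The pruner zeroes the smallest-magnitude weights: the bottom `ratio`
    # fraction in 'ratio' mode, or every weight below `threshold` in
    # 'threshold' mode. This is the `pruner` used by train() above.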

    for i in range(args.resume_epoch + 1, args.num_epochs):
        train(i)
        if (i + 1) % args.test_period == 0:
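            # update_params() writes the masks into the parameters, so the
            # density report and the evaluation below see the pruned weights.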
            pruner.update_params()
            _logger.info(
                "The current density of the pruned model is: {}%".format(
                    round(100 * UnstructuredPruner.total_sparse(model), 2)))
            test(i)
        if (i + 1) % args.model_period == 0:
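            # Apply the masks once more so the saved checkpoint is sparse.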
            pruner.update_params()
            paddle.save(model.state_dict(),
                        os.path.join(args.model_path, "model-pruned.pdparams"))
            paddle.save(opt.state_dict(),
                        os.path.join(args.model_path, "opt-pruned.pdopt"))


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()
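
# Example usage (illustrative flags, adjust to your setup):
#   python train.py --data cifar10 --pruning_mode ratio --ratio 0.55
# or, for multi-GPU data-parallel training:
#   python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --data imagenet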