import paddle
import os
import sys
import argparse
import numpy as np
from paddleslim import UnstructuredPruner, GMPUnstructuredPruner
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
import paddle.nn.functional as F
import functools
from paddle.vision.models import mobilenet_v1
import time
import logging
from paddleslim.common import get_logger
import paddle.distributed as dist
from paddle.distributed import ParallelEnv

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu',          bool, True,               "Whether to use GPU for training or not. Default: True")
add_arg('batch_size',       int,  64,                 "Minibatch size. Default: 64")
add_arg('batch_size_for_validation',       int,  64,                 "Minibatch size for validation. Default: 64")
add_arg('lr',               float,  0.05,               "The learning rate used to fine-tune the pruned model. Default: 0.05")
add_arg('lr_strategy',      str,  "piecewise_decay",   "The learning rate decay strategy. Default: piecewise_decay")
add_arg('l2_decay',         float,  3e-5,               "The l2_decay parameter. Default: 3e-5")
add_arg('momentum_rate',    float,  0.9,               "The value of momentum_rate. Default: 0.9")
add_arg('ratio',            float,  0.55,               "The ratio of weights to set to zero; the smallest-magnitude weights within this ratio are zeroed. Default: 0.55")
add_arg('pruning_mode',     str,  'ratio',              "The pruning mode: prune by 'ratio' or by 'threshold'. Default: ratio")
add_arg('threshold',        float,  0.01,               "The magnitude threshold below which weights are set to zero. Default: 0.01")
add_arg('num_epochs',       int,  120,               "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="The epochs at which the piecewise learning rate decays. Default: [30, 60, 90]")
add_arg('data',             str, "imagenet",                 "Which data to use. 'cifar10' or 'imagenet'. Default: imagenet")
add_arg('log_period',       int, 100,                 "Log period in batches. Default: 100")
add_arg('test_period',      int, 5,                 "Test period in epochs. Default: 5")
add_arg('pretrained_model', str, None,              "The pretrained model to load. Default: None.")
add_arg('checkpoint',       str, None,              "The checkpoint path to resume training. Default: None.")
add_arg('model_path',       str, "./models",         "The path to save model. Default: ./models")
add_arg('model_period',     int, 10,             "The period, in epochs, at which to save the model. Default: 10")
add_arg('last_epoch',       int, -1,             "The last finished epoch; training resumes from last_epoch + 1. Default: -1")
add_arg('num_workers',      int, 16,             "The number of workers used to load the dataset. Default: 16")
add_arg('stable_epochs',    int, 0,              "The number of epochs used to stabilize the model before pruning. Default: 0")
add_arg('pruning_epochs',   int, 60,             "The number of epochs during which the pruning ratio is increased step by step. Default: 60")
add_arg('tunning_epochs',   int, 60,             "The number of epochs used to fine-tune the model after pruning. Default: 60")
add_arg('pruning_steps',    int, 100,            "How many times the pruning ratio is increased during the pruning stage. Default: 100")
add_arg('initial_ratio',    float, 0.15,         "The initial pruning ratio used at the start of the pruning stage. Default: 0.15")
add_arg('pruning_strategy', str, 'base',         "The pruning strategy; only 'base' and 'gmp' are supported for now. Default: base")
add_arg('prune_params_type', str, None,          "Which kind of parameters to prune; only None (all but norm layers) and 'conv1x1_only' are supported for now. Default: None")
# yapf: enable
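
# A minimal example invocation (illustrative values only):
#   python train.py --data cifar10 --lr 0.05 --pruning_mode ratio --ratio 0.55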


def piecewise_decay(args, step_per_epoch, model):
    bd = [step_per_epoch * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
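    # Offset the scheduler so that it resumes from the correct iteration when
    # training restarts from --last_epoch.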
    last_iter = (1 + args.last_epoch) * step_per_epoch
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=bd, values=lr, last_epoch=last_iter)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
        parameters=model.parameters())
    return optimizer, learning_rate


def cosine_decay(args, step_per_epoch, model):
    last_iter = (1 + args.last_epoch) * step_per_epoch
    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=args.lr,
        T_max=args.num_epochs * step_per_epoch,
        last_epoch=last_iter)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        weight_decay=paddle.regularizer.L2Decay(args.l2_decay),
        parameters=model.parameters())
    return optimizer, learning_rate


def create_optimizer(args, step_per_epoch, model):
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args, step_per_epoch, model)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args, step_per_epoch, model)


def create_unstructured_pruner(model, args, configs=None):
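    # Without GMP configs, build a one-shot UnstructuredPruner that prunes by a fixed
    # ratio or magnitude threshold; with configs, build a GMPUnstructuredPruner that
    # raises the pruning ratio gradually following the given schedule.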
    if configs is None:
        return UnstructuredPruner(
            model,
            mode=args.pruning_mode,
            ratio=args.ratio,
            threshold=args.threshold,
            prune_params_type=args.prune_params_type)
    else:
        return GMPUnstructuredPruner(
            model,
            ratio=args.ratio,
            prune_params_type=args.prune_params_type,
            configs=configs)


def compress(args):
    if args.use_gpu:
        place = paddle.set_device('gpu')
    else:
        place = paddle.set_device('cpu')

    trainer_num = paddle.distributed.get_world_size()
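    # Initialize the Paddle distributed environment only when launched with more than one trainer.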
    use_data_parallel = trainer_num != 1
    if use_data_parallel:
        dist.init_parallel_env()

    train_reader = None
    test_reader = None
    if args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
    elif args.data == "cifar10":
        normalize = T.Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW')
        transform = T.Compose([T.Transpose(), normalize])
        train_dataset = paddle.vision.datasets.Cifar10(
            mode='train', backend='cv2', transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(
            mode='test', backend='cv2', transform=transform)
        class_dim = 10
    else:
        raise ValueError("{} is not supported.".format(args.data))

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    train_loader = paddle.io.DataLoader(
        train_dataset,
        places=place,
        batch_sampler=batch_sampler,
        return_list=True,
        num_workers=args.num_workers,
        use_shared_memory=True)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        drop_last=False,
        return_list=True,
        batch_size=args.batch_size_for_validation,
        shuffle=False,
        use_shared_memory=True)
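
    # Iterations per epoch on each trainer; the DistributedBatchSampler already shards the data across ranks.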
    step_per_epoch = int(
        np.ceil(len(train_dataset) / args.batch_size / ParallelEnv().nranks))
    # model definition
    model = mobilenet_v1(num_classes=class_dim, pretrained=True)
    if ParallelEnv().nranks > 1:
        model = paddle.DataParallel(model)

    opt, learning_rate = create_optimizer(args, step_per_epoch, model)

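    # Resume both model and optimizer states from a checkpoint when --last_epoch is set;
    # otherwise optionally warm-start from pretrained weights only.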
    if args.checkpoint is not None and args.last_epoch > -1:
        if args.checkpoint.endswith('pdparams'):
            args.checkpoint = args.checkpoint[:-9]
        if args.checkpoint.endswith('pdopt'):
            args.checkpoint = args.checkpoint[:-6]
        model.set_state_dict(paddle.load(args.checkpoint + ".pdparams"))
        opt.set_state_dict(paddle.load(args.checkpoint + ".pdopt"))
    elif args.pretrained_model is not None:
        if args.pretrained_model.endswith('pdparams'):
            args.pretrained_model = args.pretrained_model[:-9]
        if args.pretrained_model.endswith('pdopt'):
            args.pretrained_model = args.pretrained_model[:-6]
        model.set_state_dict(paddle.load(args.pretrained_model + ".pdparams"))

    if args.pruning_strategy == 'gmp':
        # GMP pruner step 0: define configs. No need to do this if you are not using 'gmp'
        configs = {
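            # GMP schedule, expressed in iterations: keep the dense model stable first,
            # then raise the sparsity from initial_ratio towards ratio over pruning_steps
            # increments, and finally fine-tune the pruned model.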
            'stable_iterations': args.stable_epochs * step_per_epoch,
            'pruning_iterations': args.pruning_epochs * step_per_epoch,
            'tunning_iterations': args.tunning_epochs * step_per_epoch,
            'resume_iteration': (args.last_epoch + 1) * step_per_epoch,
            'pruning_steps': args.pruning_steps,
            'initial_ratio': args.initial_ratio,
        }
    else:
        configs = None

    # GMP pruner step 1: initialize a pruner object
    pruner = create_unstructured_pruner(model, args, configs=configs)

    def test(epoch):
        model.eval()
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
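            # CIFAR-10 labels come as a 1-D tensor; add a trailing dimension to match
            # the label layout used by the loss and accuracy calls below.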
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1.numpy()),
                           np.mean(acc_top5.numpy()), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1.numpy()))
            acc_top5_ns.append(np.mean(acc_top5.numpy()))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch,
            np.mean(np.array(
                acc_top1_ns, dtype="object")),
            np.mean(np.array(
                acc_top5_ns, dtype="object"))))

    def train(epoch):
        model.train()
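        # Accumulators used for logging data-loading cost, batch cost and throughput (ips).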
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()

        for batch_id, data in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            train_start = time.time()
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)

            loss.backward()
            opt.step()
            learning_rate.step()
            opt.clear_grad()
            # GMP pruner step 2: step() to update ratios and other internal states of the pruner.
            pruner.step()

            train_run_cost += time.time() - train_start
            total_samples += args.batch_size

            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch, batch_id,
                           opt.get_lr(),
                           np.mean(loss.numpy()),
                           np.mean(acc_top1.numpy()),
                           np.mean(acc_top5.numpy()), train_reader_cost /
                           args.log_period, (train_reader_cost + train_run_cost
                                             ) / args.log_period, total_samples
                           / args.log_period, total_samples / (
                               train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            reader_start = time.time()

    for i in range(args.last_epoch + 1, args.num_epochs):
        train(i)
        # GMP pruner step 3: update params before summarizing sparsity, saving the model, or evaluating.
        pruner.update_params()

        if (i + 1) % args.test_period == 0:
            _logger.info(
                "The current sparsity of the pruned model is: {}%".format(
                    round(100 * UnstructuredPruner.total_sparse(model), 2)))
            test(i)

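        # Periodically persist the pruned weights and optimizer state so training can be
        # resumed later via --checkpoint and --last_epoch.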
        if (i + 1) % args.model_period == 0:
            pruner.update_params()
            paddle.save(model.state_dict(),
                        os.path.join(args.model_path, "model.pdparams"))
            paddle.save(opt.state_dict(),
                        os.path.join(args.model_path, "model.pdopt"))


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()