prune.py 15.5 KB
Newer Older
W
whs 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Q
qingqing01 已提交
19 20 21 22 23 24 25
import os, sys

# Put the directory two levels above this script's folder on sys.path so the
# ``ppdet`` package imported below can be found from the source tree
# (PaddleDetection) without installing it. Existing entries are not duplicated.
parent_path = os.path.abspath(
    os.path.join(__file__, os.pardir, os.pardir, os.pardir))
if parent_path not in sys.path:
    sys.path.append(parent_path)

W
whs 已提交
26 27 28 29 30 31
import time
import numpy as np
import datetime
from collections import deque
from paddleslim.prune import Pruner
from paddleslim.analysis import flops
32
import paddle
W
whs 已提交
33
from paddle import fluid
Q
qingqing01 已提交
34

W
whs 已提交
35 36 37 38 39
import logging
# Root logging goes to INFO with a "<time>-<LEVEL>: <message>" layout; this
# module emits its own messages through ``logger``.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)

K
Kaipeng Deng 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
# Import the ppdet (static-graph) toolkit. When the import fails while a
# script whose path contains "static" is running, print a hint about the
# usual causes and exit; in every other case the ImportError propagates.
try:
    from ppdet.experimental import mixed_precision_context
    from ppdet.core.workspace import load_config, merge_config, create
    from ppdet.data.reader import create_reader
    from ppdet.utils import dist_utils
    from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
    from ppdet.utils.stats import TrainingStats
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode
    import ppdet.utils.checkpoint as checkpoint
except ImportError as import_err:
    if 'static' not in sys.argv[0]:
        raise
    logger.error("Importing ppdet failed when running static model "
                 "with error: {}\n"
                 "please try:\n"
                 "\t1. run static model under PaddleDetection/static "
                 "directory\n"
                 "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
                 "dynamic version firstly.".format(import_err))
    sys.exit(-1)

W
whs 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75

def _parse_prune_config(pruned_params_str, pruned_ratios_str):
    """Parse and validate the ``--pruned_params`` / ``--pruned_ratios`` flags.

    Args:
        pruned_params_str (str|None): comma-separated parameter names.
        pruned_ratios_str (str|None): comma-separated ratios, one per name.

    Returns:
        tuple(list[str], list[float]): parameter names and their pruning
        ratios, in matching order.

    Raises:
        AssertionError: if a flag is missing/empty, the two lists differ in
            length, or any ratio lies outside the open interval (0, 1).
        ValueError: if a ratio is not a valid float.
    """
    assert pruned_params_str is not None, \
        "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
    # Tolerate whitespace around names and stray/trailing commas.
    pruned_params = [
        name.strip() for name in pruned_params_str.split(",") if name.strip()
    ]
    assert pruned_params, \
        "FLAGS.pruned_params is empty!!! Please set it by '--pruned_params' option."
    logger.info("pruned params: {}".format(pruned_params))

    assert pruned_ratios_str is not None, \
        "FLAGS.pruned_ratios is empty!!! Please set it by '--pruned_ratios' option."
    pruned_ratios = [
        float(ratio) for ratio in pruned_ratios_str.split(",") if ratio.strip()
    ]
    logger.info("pruned ratios: {}".format(pruned_ratios))

    assert len(pruned_params) == len(pruned_ratios), \
        "The length of pruned params and pruned ratios should be equal."
    # Check every ratio individually: comparing whole lists would be
    # lexicographic and only inspect the first differing element.
    assert all(0 < ratio < 1 for ratio in pruned_ratios), \
        "The elements of pruned ratios should be in range (0, 1)."
    return pruned_params, pruned_ratios


def main():
    """Prune a detection model with PaddleSlim and fine-tune the result.

    Reads the global ``FLAGS`` (parsed in ``__main__``) and the config file it
    names. It builds the train program (and, with ``--eval``, an eval
    program), loads ``cfg.pretrain_weights`` if set, and prunes the parameters
    given by ``--pruned_params`` with ``--pruned_ratios`` using the
    ``--prune_criterion`` Pruner. It then trains for ``cfg.max_iters``
    iterations. A checkpoint is saved every ``cfg.snapshot_iter`` iterations
    and at the last iteration (as ``model_final``).

    With ``--eval``, the pruned eval program is evaluated once before training
    and again at every snapshot. The best-scoring weights are saved as
    ``best_model``. With ``--print_params``, only the train program's
    parameter names and shapes are printed. ``--resume_checkpoint`` restores a
    checkpoint and continues from its global step.
    """
    env = os.environ
    FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
    trainer_id = 0
    if FLAGS.dist:
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        import random
        # distinct but reproducible seed per trainer
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)
    # Only trainer 0 (or the single process) logs progress and saves models.
    is_master = not FLAGS.dist or trainer_id == 0

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    main_arch = cfg.architecture

    if cfg.use_gpu:
        devices_num = fluid.core.get_cuda_device_count()
    else:
        devices_num = int(os.environ.get('CPU_NUM', 1))

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # build program
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            model = create(main_arch)
            if FLAGS.fp16:
                assert (getattr(model.backbone, 'norm_type', None)
                        != 'affine_channel'), \
                    '--fp16 currently does not support affine channel, ' \
                    ' please modify backbone settings to use batch norm'

            with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
                inputs_def = cfg['TrainReader']['inputs_def']
                feed_vars, train_loader = model.build_inputs(**inputs_def)
                train_fetches = model.train(feed_vars)
                loss = train_fetches['loss']
                if FLAGS.fp16:
                    loss *= ctx.get_loss_scale_var()
                lr = lr_builder()
                optimizer = optim_builder(lr)
                optimizer.minimize(loss)
                if FLAGS.fp16:
                    loss /= ctx.get_loss_scale_var()

    # parse train fetches; the learning rate is fetched last (read as outs[-1])
    train_keys, train_values, _ = parse_fetches(train_fetches)
    train_values.append(lr)

    if FLAGS.print_params:
        param_delimit_str = '-' * 20 + "All parameters in current graph" + '-' * 20
        print(param_delimit_str)
        for block in train_prog.blocks:
            for param in block.all_parameters():
                print("parameter name: {}\tshape: {}".format(param.name,
                                                             param.shape))
        print('-' * len(param_delimit_str))
        return

    if FLAGS.eval:
        eval_prog = fluid.Program()
        with fluid.program_guard(eval_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                inputs_def = cfg['EvalReader']['inputs_def']
                feed_vars, eval_loader = model.build_inputs(**inputs_def)
                fetches = model.eval(feed_vars)
        eval_prog = eval_prog.clone(True)

        eval_reader = create_reader(cfg.EvalReader)
        # When iterable mode, set set_sample_list_generator(eval_reader, place)
        eval_loader.set_sample_list_generator(eval_reader)

        # parse eval fetches; metrics not listed here need no extra keys
        extra_keys = {
            'COCO': ['im_info', 'im_id', 'im_shape'],
            'VOC': ['gt_bbox', 'gt_class', 'is_difficult'],
            'WIDERFACE': ['im_id', 'im_shape', 'gt_bbox'],
        }.get(cfg.metric, [])
        eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                         extra_keys)

    # compile program for multi-devices
    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_optimizer_ops = False
    build_strategy.fuse_elewise_add_act_ops = True
    # only enable sync_bn in multi GPU devices
    sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
    build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
        and cfg.use_gpu

    exec_strategy = fluid.ExecutionStrategy()
    # iteration number when CompiledProgram tries to drop local execution scopes.
    # Set it to be 1 to save memory usages, so that unused variables in
    # local execution scopes can be deleted after each iteration.
    exec_strategy.num_iteration_per_drop_scope = 1
    if FLAGS.dist:
        dist_utils.prepare_for_multi_process(exe, build_strategy, startup_prog,
                                             train_prog)
        exec_strategy.num_threads = 1

    exe.run(startup_prog)

    start_iter = 0
    if cfg.pretrain_weights:
        checkpoint.load_params(exe, train_prog, cfg.pretrain_weights)

    pruned_params, pruned_ratios = _parse_prune_config(FLAGS.pruned_params,
                                                       FLAGS.pruned_ratios)

    assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \
            "unsupported prune criterion {}".format(FLAGS.prune_criterion)
    pruner = Pruner(criterion=FLAGS.prune_criterion)
    # Prune both the graph and the weights held in the global scope.
    train_prog = pruner.prune(
        train_prog,
        fluid.global_scope(),
        params=pruned_params,
        ratios=pruned_ratios,
        place=place,
        only_graph=False)[0]

    compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    if FLAGS.eval:
        base_flops = flops(eval_prog)
        # The weights were already pruned above; only reshape the eval graph.
        eval_prog = pruner.prune(
            eval_prog,
            fluid.global_scope(),
            params=pruned_params,
            ratios=pruned_ratios,
            place=place,
            only_graph=True)[0]
        pruned_flops = flops(eval_prog)
        logger.info("FLOPs -{}; total FLOPs: {}; pruned FLOPs: {}".format(
            float(base_flops - pruned_flops) / base_flops, base_flops,
            pruned_flops))
        compiled_eval_prog = fluid.CompiledProgram(eval_prog)

    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()

    train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
                                 devices_num, cfg)
    train_loader.set_sample_list_generator(train_reader, place)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # if map_type not set, use default 11point, only use in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_iter, train_keys)
    train_loader.start()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    time_stat = deque(maxlen=cfg.log_iter)
    best_box_ap_list = [0.0, 0]  #[map, iter]

    # use VisualDL to log data
    if FLAGS.use_vdl:
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_loss_step = 0
        vdl_mAP_step = 0

    if FLAGS.eval:
        # Mask resolution and eval dataset do not change during training.
        resolution = None
        if 'Mask' in cfg.architecture:
            resolution = model.mask_head.resolution
        dataset = cfg['EvalReader']['dataset']

        def run_eval():
            """Run the pruned eval program and return the metric stats."""
            results = eval_run(
                exe,
                compiled_eval_prog,
                eval_loader,
                eval_keys,
                eval_values,
                eval_cls,
                cfg=cfg,
                resolution=resolution)
            return eval_results(
                results,
                cfg.metric,
                cfg.num_classes,
                resolution,
                is_bbox_normalized,
                FLAGS.output_eval,
                map_type,
                dataset=dataset)

        # evaluate the pruned model before fine-tuning
        run_eval()

    # Start timing here so the initial evaluation does not inflate the ETA.
    end_time = time.time()
    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        # use VisualDL to log loss
        if FLAGS.use_vdl and it % cfg.log_iter == 0:
            for loss_name, loss_value in stats.items():
                vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step)
            vdl_loss_step += 1

        train_stats.update(stats)
        logs = train_stats.log()
        if it % cfg.log_iter == 0 and is_master:
            strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
                it, np.mean(outs[-1]), logs, time_cost, eta)
            logger.info(strs)

        is_last_iter = it == cfg.max_iters - 1
        if (it > 0 and it % cfg.snapshot_iter == 0 or is_last_iter) \
           and is_master:
            save_name = "model_final" if is_last_iter else str(it)
            checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

            if FLAGS.eval:
                box_ap_stats = run_eval()

                # use VisualDL to log mAP
                if FLAGS.use_vdl:
                    vdl_writer.add_scalar("mAP", box_ap_stats[0], vdl_mAP_step)
                    vdl_mAP_step += 1

                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog,
                                    os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))

    train_loader.reset()


def str2bool(value):
    """Convert a command-line string such as ``"True"``/``"false"`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (including ``"False"``). This converter parses the
    text instead.

    Args:
        value (str|bool): raw option value.

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling;
            argparse reports this as an invalid option value.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise ValueError("invalid boolean value: {!r}".format(value))


if __name__ == '__main__':
    enable_static_mode()
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    parser.add_argument(
        "--use_vdl",
        # type=bool would turn any non-empty string (even "False") into True.
        type=str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/scalar",
        help='VisualDL logging directory for scalar.')

    parser.add_argument(
        "-p",
        "--pruned_params",
        default=None,
        type=str,
        help="Comma-separated names of the parameters to be pruned.")
    parser.add_argument(
        "--pruned_ratios",
        default=None,
        type=str,
        help="Comma-separated pruning ratios, one for each parameter in "
        "--pruned_params; each must be in range (0, 1).")
    parser.add_argument(
        "-P",
        "--print_params",
        default=False,
        action='store_true',
        help="Whether to only print the parameters' names and shapes.")
    parser.add_argument(
        "--prune_criterion",
        default='l1_norm',
        type=str,
        help="criterion function type for channels sorting in pruning, can be set " \
             "as 'l1_norm' or 'geometry_median' currently, default 'l1_norm'")
    FLAGS = parser.parse_args()
    main()