# create_compressed_program.py
# Copyright (c) 2022  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model

_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
    'build_distill_program', 'build_quant_program', 'build_prune_program',
    'remove_unused_var_nodes'
]


def _create_lr_scheduler(train_config):
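    """Create the learning rate from the `learning_rate` field of train_config.

    A float is returned as-is; otherwise the `type` key selects a scheduler
    class from paddle.optimizer.lr and the remaining keys are its kwargs.
    """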
    if 'learning_rate' not in train_config:
        raise RuntimeError(
            'No `learning_rate` specified in the configuration file.')
    if isinstance(train_config.get('learning_rate'), float):
        return train_config.get('learning_rate')

    params = train_config.get('learning_rate')
    lr_type = params.pop('type')
    return getattr(optimizer.lr, lr_type)(**params)


def _create_optimizer(train_config):
    """create optimizer"""
    if 'optimizer_builder' not in train_config:
        train_config['optimizer_builder'] = {'optimizer': {'type': 'SGD'}}

    optimizer_builder = train_config['optimizer_builder']
    assert isinstance(
        optimizer_builder, dict
    ), "Value of 'optimizer_builder' in train_config should be dict but got {}".format(
        type(optimizer_builder))
    if 'grad_clip' in optimizer_builder:
        g_clip_params = optimizer_builder['grad_clip']
        g_clip_type = g_clip_params.pop('type')
        grad_clip = getattr(paddle.nn, g_clip_type)(**g_clip_params)
    else:
        grad_clip = None

    ### build regularization
    if 'regularizer' in optimizer_builder:
        reg_params = optimizer_builder['regularizer']
        reg_type = reg_params.pop('type')
        reg = getattr(regularizer, reg_type)(**reg_params)
    elif 'weight_decay' in optimizer_builder:
        reg = optimizer_builder.pop('weight_decay')
    else:
        reg = None

    ### build learning rate
    lr = _create_lr_scheduler(train_config)

    ### build optimizer
    optim_params = optimizer_builder['optimizer']
    optim_type = optim_params.pop('type')
    opt = getattr(optimizer, optim_type)(learning_rate=lr,
                                         grad_clip=grad_clip,
                                         weight_decay=reg,
                                         **optim_params)
    return opt, lr


def _find_var_from_program(program, var_name):
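    """Find the variable named var_name in any block of program."""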
    for block in program.blocks:
        if block.has_var(var_name):
            return block.var(var_name)
    raise ValueError("var {} not in this program".format(var_name))


def _get_distill_node(student_program, config):
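    """Parse distillation nodes from config['node'] and pair each student
    node name with its 'teacher_'-prefixed counterpart."""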
    node = config.get('node')
    if len(node) == 0:
        return None

    ### node is either a list of node names or a list of lists of node names
    if isinstance(node[0], list):
        test_node = node[0][0]
    else:
        test_node = node[0]
    try:
        ### check that the node exists in the student program
        _find_var_from_program(student_program, test_node)
        distill_node_pair = []
        if isinstance(node[0], list):
            for n_list in node:
                tmp_node_pair = []
                for n in n_list:
                    tmp_node_pair.append('teacher_' + n)
                    tmp_node_pair.append(n)
                distill_node_pair.append(tmp_node_pair)
        else:
            for n in node:
                distill_node_pair.append('teacher_' + n)
                distill_node_pair.append(n)
        return distill_node_pair
    except ValueError:
        ### the node names are not in the student program;
        ### return the config value unchanged
        return node


def _get_target_node(distill_node):
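    """Collect the student-side nodes (odd indices) from distill_node."""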
    targets = []
    for idx, node in enumerate(distill_node):
        if idx % 2 != 0:
            targets.append(node)
    return targets


def _parse_distill_loss(distill_node_pair,
                        distill_loss='l2',
                        distill_lambda=1.0):
    """parse distill loss config"""
    loss_dist = 0.0
    losses = []
    if isinstance(distill_node_pair[0], str):
        assert isinstance(distill_loss, str)
        assert isinstance(distill_lambda, float)
        distill_node_pair = [distill_node_pair]
        distill_loss = [distill_loss]
        distill_lambda = [distill_lambda]

    assert len(distill_node_pair) == len(distill_loss)
    assert len(distill_node_pair) == len(distill_lambda)
    for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
        tmp_loss = 0.0
        _logger.info(
            "train config.distill_node_pair: {}, distill loss: {}, lambda: {}".
            format(node, loss, lam))
        assert len(node) % 2 == 0, \
            "distill_node_pair config is wrong, its length must be an even number"
        for i in range(len(node) // 2):
            tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
        loss_dist += lam * tmp_loss
        losses.append(tmp_loss)

    return loss_dist, losses


def _load_program_and_merge(executor,
                            place,
                            train_program,
                            config,
                            model_dir,
                            model_filename,
                            params_filename,
                            distill_node_pair,
                            teacher_idx=None,
                            feed_target_names=None):
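    """Load a teacher inference model into a separate scope, prune it to the
    distillation target nodes and merge it into train_program under a
    'teacher_' (or 'teacher{idx}_') name prefix."""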
    scope = paddle.static.global_scope()
    new_scope = paddle.static.Scope()

    if params_filename == 'None':
        params_filename = None

    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "Parameters saved in separate files are not supported. "
            "Please convert them to a single binary file first.")

    with paddle.static.scope_guard(new_scope):
        [teacher_program, teacher_feed_target_names,
         teacher_fetch_targets] = load_inference_model(
             model_dir,
             model_filename=model_filename,
             params_filename=params_filename,
             executor=executor)

    _remove_fetch_node(teacher_program)

    target_nodes = _get_target_node(distill_node_pair)
    teacher_program = teacher_program._prune(target_nodes)

    data_name_map = {}

    merge_feed = (
        sorted(feed_target_names) == sorted(teacher_feed_target_names))
    if merge_feed:
        for i, name in enumerate(feed_target_names):
            data_name_map[teacher_feed_target_names[i]] = name

    if teacher_idx is None:
        teacher_name_prefix = 'teacher_'
    else:
        teacher_name_prefix = 'teacher{}_'.format(str(teacher_idx))

    merge(
        teacher_program,
        train_program,
        data_name_map,
        place,
        teacher_scope=new_scope,
        name_prefix=teacher_name_prefix,
        merge_feed=merge_feed)
    return train_program, data_name_map


def build_distill_program(executor,
                          place,
                          config,
                          train_config,
                          train_program_info=None,
                          pruner=None,
                          dist_strategy=None,
                          default_distill_node_pair=None):
    """build distill program with infermodel"""
    startup_program = paddle.static.Program()
    if train_program_info is None:
        [train_program, feed_target_names, fetch_targets] = load_inference_model(
            path_prefix=config["model_dir"]
            if "model_dir" in config else config["model_path_prefix"],
            executor=executor)
        train_program = recover_inference_program(train_program)
    else:
        train_program = train_program_info.program
        feed_target_names = train_program_info.feed_target_names
        fetch_targets = train_program_info.fetch_targets

    distill_node_pair = _get_distill_node(train_program,
                                          config) or default_distill_node_pair

    test_program = train_program.clone(for_test=True)

    target_nodes = _get_target_node(distill_node_pair)

    def _prepend_feed(block, feed_idx, feed_target_names):
        for idx in feed_idx[::-1]:
            block._remove_op(idx)

        feed_var = block.create_var(
            name='feed',
            type=paddle.framework.core.VarDesc.VarType.FEED_MINIBATCH,
            persistable=True, )

        for i, name in enumerate(feed_target_names):
            out = block.var(name)
            block._prepend_op(
                type='feed',
                inputs={'X': [feed_var]},
                outputs={'Out': [out]},
                attrs={'col': i})

    judge_feed_pos = False
    if train_program.desc.block(0).op(0).type() != 'feed':
        judge_feed_pos = True
    if judge_feed_pos:
        feed_idx = []
        for op in train_program.global_block().ops:
            if op.type == 'feed':
                feed_idx.append(op.idx)
        _prepend_feed(train_program.global_block(), feed_idx, feed_target_names)
    train_program = train_program._prune(target_nodes)

    teacher_model_dir = config[
        "teacher_model_dir"] if "teacher_model_dir" in config else config[
            "teacher_model_path_prefix"]
    if isinstance(teacher_model_dir, list):
        for tea_idx in range(len(teacher_model_dir)):
            model_filename = config["teacher_model_filename"][
                tea_idx] if "teacher_model_filename" in config else None
            params_filename = config["teacher_params_filename"][
                tea_idx] if "teacher_params_filename" in config else None
            train_program, data_name_map = _load_program_and_merge(
                executor,
                place,
                train_program,
                config,
                teacher_model_dir[tea_idx],
                model_filename,
                params_filename,
                distill_node_pair,
                teacher_idx=(tea_idx + 1),
                feed_target_names=feed_target_names)

    else:
        model_filename = config[
            "teacher_model_filename"] if "teacher_model_filename" in config else None
        params_filename = config[
            "teacher_params_filename"] if "teacher_params_filename" in config else None
        train_program, data_name_map = _load_program_and_merge(
            executor,
            place,
            train_program,
            config,
            teacher_model_dir,
            model_filename,
            params_filename,
            distill_node_pair,
            teacher_idx=None,
            feed_target_names=feed_target_names)
    # All feed nodes should set stop_gradient to False so that the PACT
    # quantization algorithm can be applied.
    for var in train_program.list_vars():
        if var.name in data_name_map.values() or var.name in data_name_map.keys(
        ):
            var.stop_gradient = False

    train_fetch_list = []
    with paddle.static.program_guard(train_program, startup_program):
        with paddle.utils.unique_name.guard('merge'):
            optimizer, learning_rate = _create_optimizer(train_config)

            if train_config.get('use_fleet'):
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        dist_strategy)
            else:
                if train_config.get('amp_config') is not None:
                    custom_white_list = train_config['amp_config'].get(
                        'custom_white_list', None)
                    if custom_white_list is not None:
                        train_config['amp_config'].pop('custom_white_list')

                    custom_black_list = train_config['amp_config'].get(
                        'custom_black_list', None)
                    if custom_black_list is not None:
                        train_config['amp_config'].pop('custom_black_list')

                    custom_black_varnames = train_config['amp_config'].get(
                        'custom_black_varnames', None)
                    if custom_black_varnames is not None:
                        train_config['amp_config'].pop('custom_black_varnames')

                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_white_list=custom_white_list,
                        custom_black_list=custom_black_list,
                        custom_black_varnames=custom_black_varnames)
                    optimizer = paddle.static.amp.decorate(
                        optimizer=optimizer,
                        amp_lists=amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        **train_config['amp_config'])

            distill_loss, losses = _parse_distill_loss(
                distill_node_pair,
                config.get('loss') or 'l2',  ### default loss is l2
                config.get('alpha') or 1.0)  ### default alpha is 1.0
            loss = paddle.mean(distill_loss)
            loss.stop_gradient = False

            if 'prune_params_name' in config:  ### prune
                if 'pruned_ratio' not in config and not train_config.get(
                        'use_fleet'):  ### asp
                    optimizer = pruner.decorate(optimizer)
                optimizer.minimize(loss)
            elif 'prune_strategy' in config:  ### unstructured prune
                optimizer.minimize(loss, no_grad_set=pruner.no_grad_set)
            else:
                optimizer.minimize(loss)

            train_fetch_list.append(loss)

    train_program_info = ProgramInfo(startup_program, train_program,
                                     feed_target_names, train_fetch_list,
                                     optimizer, learning_rate)
    test_program_info = ProgramInfo(startup_program, test_program,
                                    feed_target_names, fetch_targets)
    return train_program_info, test_program_info


def build_quant_program(executor, place, config, train_program_info,
                        test_program_info):
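    """Insert quant-aware ops into the train and test programs and return the
    updated ProgramInfo objects together with the quant config."""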
    scope = paddle.static.global_scope()

    assert isinstance(config, dict), "quant config must be dict"

    use_pact = config.pop("use_pact")
    if use_pact:
        act_preprocess_func = pact
        optimizer_func = get_pact_optimizer
        pact_executor = executor
    else:
        act_preprocess_func = None
        optimizer_func = None
        pact_executor = None

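    # The test program is converted without the PACT hooks; the PACT preprocess
    # function and its optimizer are only attached to the training program.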
    test_program = quant_aware(
        test_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        for_test=True)

    train_program = quant_aware(
        train_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=pact_executor,
        for_test=False,
        return_program=True)

    train_program_info.program = train_program
    test_program_info.program = test_program
    return train_program_info, test_program_info, config


def _get_label_info(dataloader, feed_target_names):
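    """Peek at one batch from dataloader and return the name, dtype and shape
    (batch dim set to -1) of the first field that is not a feed target."""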
    label_info = {}
    for data in dataloader():
        if isinstance(data, list) or isinstance(data, tuple):
            data = data[0]
        for key, value in data.items():
            if key in feed_target_names:
                continue
            label_info['name'] = key
            label_info['dtype'] = np.array(value).dtype
            label_info['shape'] = list(np.array(value).shape)
            label_info['shape'][0] = -1
            break
        break
    return label_info


def _get_chn_prune_params(program):
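    """Collect the persistable inputs (weights) of non-grouped conv2d ops as
    the default parameters for channel pruning, with their original shapes."""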
    params = []
    original_shapes = {}
    for block in program.blocks:
        for op in block.ops:
            if op.type == 'conv2d' and op.attr('groups') == 1:
                for inp_name in op.input_arg_names:
                    var_ = block.var(inp_name)
                    if var_.persistable is True:
                        params.append(inp_name)
                        original_shapes[inp_name] = var_.shape
    return params, original_shapes


def _get_asp_prune_params(program):
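    """Collect the persistable inputs (weights) of conv2d (groups=1), mul and
    matmul_v2 ops as the default parameters for ASP pruning."""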
    params = []
    for block in program.blocks:
        for op in block.ops:
            if (op.type == 'conv2d' and op.attr('groups') == 1
                ) or op.type == 'mul' or op.type == 'matmul_v2':
                for inp_name in op.input_arg_names:
                    var_ = block.var(inp_name)
                    if var_.persistable is True:
                        params.append(inp_name)
    return params


def build_prune_program(executor,
                        place,
                        config,
                        train_program_info,
                        strategy,
                        patterns,
                        eval_dataloader=None):
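    """Create the pruner for the given strategy (one starting with
    'unstructure', 'channel_prune', 'asp' or 'transformer_prune'); channel and
    transformer pruning also replace train_program_info.program with the
    pruned program."""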
    if strategy.startswith('unstructure'):
        from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
        if config["prune_strategy"] is None:
            pruner = UnstructuredPruner(
                train_program_info.program,
                mode=config['prune_mode'],
                ratio=config['ratio'],
                threshold=config['threshold'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'], )
        elif config["prune_strategy"] == "gmp":
            pruner = GMPUnstructuredPruner(
                train_program_info.program,
                ratio=config['ratio'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'],
                configs=config['gmp_config'])
    elif strategy.startswith('channel_prune'):
        from ..prune import Pruner
        pruner = Pruner(config["criterion"])
        if config['prune_params_name'] is None:
            params, original_shapes = _get_chn_prune_params(
                train_program_info.program)
        else:
            params = []
            original_shapes = {}
            for param in train_program_info.program.global_block(
            ).all_parameters():
                if config[
                        'prune_params_name'] is not None and param.name in config[
                            'prune_params_name']:
                    params.append(param.name)
                    original_shapes[param.name] = param.shape

        pruned_program, _, _ = pruner.prune(
            train_program_info.program,
            paddle.static.global_scope(),
            params=params,
            ratios=[config['pruned_ratio']] * len(params)
            if isinstance(config['pruned_ratio'], float) else
            config['pruned_ratio'],
            place=place)
        _logger.info(
            "####################channel pruning##########################")
        for param in pruned_program.all_parameters():
            if param.name in original_shapes:
                _logger.info("{}, from {} to {}".format(
                    param.name, original_shapes[param.name], param.shape))
        _logger.info(
            "####################channel pruning end##########################")
        train_program_info.program = pruned_program

    elif strategy.startswith('asp'):
        from paddle.static import sparsity
        pruner = sparsity
        excluded_params_name = []
        if config['prune_params_name'] is None:
            config['prune_params_name'] = _get_asp_prune_params(
                train_program_info.program)

        for param in train_program_info.program.global_block(
        ).all_parameters():
            if config['prune_params_name'] is not None:
                if param.name not in config['prune_params_name']:
                    excluded_params_name.append(param.name)
                else:
                    pruner.add_supported_layer(param.name)
            if "teacher_" in param.name:
                excluded_params_name.append(param.name)
        pruner.set_excluded_layers(train_program_info.program,
                                   excluded_params_name)
    elif strategy.startswith('transformer_prune'):
        from .transformer_pruner import TransformerPruner
        assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
        label_info = _get_label_info(eval_dataloader,
                                     train_program_info.feed_target_names)
        assert len(label_info) != 0, \
            "failed to infer the label name from eval_dataloader, please check your eval_dataloader"
        pruner = TransformerPruner(
            executor,
            place,
            train_program_info.program,
            patterns,
            label_info,
            width_mult=(1.0 - config['pruned_ratio']),
            dataloader=eval_dataloader,
            fetch_targets=train_program_info.fetch_targets)
        pruned_program = pruner.prune()
        train_program_info.program = pruned_program
    else:
        raise NotImplementedError(
            "strategy must start with one of ['unstructure', 'channel_prune', "
            "'asp', 'transformer_prune'], but got {}".format(strategy))

    return pruner, train_program_info


def remove_unused_var_nodes(program):
    '''
    This function is called before saving the sparse model to remove redundant nodes.
    Args:
        program(paddle.static.Program): The sparse model to be saved.
    Returns:
        program(paddle.static.Program): The sparse model.
    '''
    from paddle.framework import core
    from paddle.fluid.framework import IrGraph
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    removed_nodes = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            if '_mask' in input_node.name():
                removed_nodes.add(op_node)
    graph.safe_remove_nodes(removed_nodes)
    program = graph.to_program()
    return program