# create_compressed_program.py
# Copyright (c) 2022  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer

from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model

_logger = get_logger(__name__, level=logging.INFO)
# Public entry points of this module.
__all__ = [
    'build_distill_program', 'build_quant_program', 'build_prune_program',
    'remove_unused_var_nodes'
]


def _create_lr_scheduler(train_config):
    """Create the learning rate (or LR scheduler) described by ``train_config``.

    Args:
        train_config(dict): Training configuration. ``learning_rate`` is
            either a number, or a dict whose ``type`` key names a class in
            ``paddle.optimizer.lr``; its remaining keys are passed to that
            class's constructor.

    Returns:
        float|paddle.optimizer.lr.LRScheduler: The constant learning rate, or
        the constructed scheduler.

    Raises:
        RuntimeError: If ``learning_rate`` is missing from ``train_config``.
    """
    if 'learning_rate' not in train_config:
        raise RuntimeError(
            'No `learning_rate` specified in the configuration file.')
    lr_config = train_config['learning_rate']
    # Paddle optimizers take a float (or an LRScheduler), so an integer
    # learning rate such as `1` is normalized to float here.
    if isinstance(lr_config, (int, float)):
        return float(lr_config)

    # Work on a copy: popping `type` from the caller's dict would make a
    # second call with the same config fail with a KeyError.
    params = dict(lr_config)
    lr_type = params.pop('type')
    return getattr(optimizer.lr, lr_type)(**params)


def _create_optimizer(train_config):
    """Create the optimizer and learning rate described by ``train_config``.

    ``train_config['optimizer_builder']`` defaults to plain SGD when absent
    (the default is written back into ``train_config``). It may contain:

      - ``optimizer``: dict whose ``type`` names a class in
        ``paddle.optimizer``; the other keys go to its constructor.
      - ``grad_clip``: dict whose ``type`` names a class in ``paddle.nn``.
      - ``regularizer``: dict whose ``type`` names a class in
        ``paddle.regularizer``; takes precedence over ``weight_decay``.
      - ``weight_decay``: value forwarded as-is as ``weight_decay``.

    The nested dicts are copied before their ``type`` key is removed, so the
    caller's configuration is not consumed and the function can be called
    again with the same ``train_config``.

    Args:
        train_config(dict): Training configuration.

    Returns:
        tuple: ``(opt, lr)`` where ``lr`` is the value returned by
        ``_create_lr_scheduler``.
    """
    if 'optimizer_builder' not in train_config:
        train_config['optimizer_builder'] = {'optimizer': {'type': 'SGD'}}

    optimizer_builder = train_config['optimizer_builder']
    assert isinstance(
        optimizer_builder, dict
    ), "Value of 'optimizer_builder' in train_config should be dict but got {}".format(
        type(optimizer_builder))

    ### build gradient clipping
    if 'grad_clip' in optimizer_builder:
        g_clip_params = dict(optimizer_builder['grad_clip'])
        g_clip_type = g_clip_params.pop('type')
        grad_clip = getattr(paddle.nn, g_clip_type)(**g_clip_params)
    else:
        grad_clip = None

    ### build regularization
    if 'regularizer' in optimizer_builder:
        reg_params = dict(optimizer_builder['regularizer'])
        reg_type = reg_params.pop('type')
        reg = getattr(regularizer, reg_type)(**reg_params)
    else:
        # `weight_decay` is read rather than popped so the config survives;
        # it is None when the key is absent.
        reg = optimizer_builder.get('weight_decay')

    ### build learning rate
    lr = _create_lr_scheduler(train_config)

    ### build optimizer
    optim_params = dict(optimizer_builder['optimizer'])
    optim_type = optim_params.pop('type')
    opt = getattr(optimizer, optim_type)(learning_rate=lr,
                                         grad_clip=grad_clip,
                                         weight_decay=reg,
                                         **optim_params)
    return opt, lr


def _get_distill_node(student_program, config):
    """Build teacher/student variable-name pairs for distillation.

    ``config['node']`` is a list of variable names or a list of such lists.
    If its first name can be looked up in ``student_program``, every name
    ``n`` is expanded to the pair ``'teacher_' + n, n``. If the lookup fails,
    the names are taken to be pairs already and ``config['node']`` is
    returned unchanged.

    Args:
        student_program(paddle.static.Program): The student program.
        config(dict): Distillation config; only the ``node`` key is read.

    Returns:
        list|None: Flat pair list, or one pair list per group when ``node``
        is nested; None when ``node`` is missing or empty.
    """
    node = config.get('node')
    if not node:
        return None

    ### the type of node is list or list(list)
    nested = isinstance(node[0], list)
    test_node = node[0][0] if nested else node[0]
    try:
        student_program.global_block().var(test_node)
    except (ValueError, TypeError):
        # The lookup failed, so `node` is not a list of student variable
        # names and is used as given.
        return node

    def _expand(names):
        return [pair_name for n in names for pair_name in ('teacher_' + n, n)]

    if nested:
        return [_expand(n_list) for n_list in node]
    return _expand(node)


def _parse_distill_loss(distill_node_pair,
                        distill_loss='l2',
                        distill_lambda=1.0):
    """Build the weighted distillation loss from groups of node pairs.

    Args:
        distill_node_pair(list): A flat list of variable names
            ``[teacher_0, student_0, teacher_1, student_1, ...]`` (one group),
            or a list of such flat lists (one per group).
        distill_loss(str|list[str]): Name of the loss function applied to
            each pair of a group. A single name is used for every group.
        distill_lambda(float|list[float]): Weight of each group's loss. A
            single number is used for every group.

    Returns:
        tuple: ``(loss_dist, losses)`` where ``loss_dist`` is the weighted sum
        over groups and ``losses`` holds each group's unweighted loss.

    Raises:
        ValueError: If ``distill_node_pair`` is empty or None.
    """
    if not distill_node_pair:
        raise ValueError(
            "No distillation node pair found: set `node` in the distillation "
            "config or provide a default distill node pair.")
    if isinstance(distill_node_pair[0], str):
        # A flat list of names describes a single group.
        distill_node_pair = [distill_node_pair]
    num_groups = len(distill_node_pair)
    # A single loss name / weight applies to every group.
    if isinstance(distill_loss, str):
        distill_loss = [distill_loss] * num_groups
    if isinstance(distill_lambda, (int, float)):
        distill_lambda = [distill_lambda] * num_groups

    assert len(distill_loss) == num_groups, \
        "got {} distill losses for {} node groups".format(
            len(distill_loss), num_groups)
    assert len(distill_lambda) == num_groups, \
        "got {} distill lambdas for {} node groups".format(
            len(distill_lambda), num_groups)

    loss_dist = 0.0
    losses = []
    for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
        _logger.info(
            "train config.distill_node_pair: {}, loss: {}, lambda: {}".format(
                node, loss, lam))
        assert len(node) % 2 == 0, \
            "distill_node_pair config wrong, the length needs to be an even number"
        # NOTE(review): `eval` resolves the loss function by name in this
        # module's namespace (e.g. names star-imported from `..dist`). The
        # name comes from the user's config, so only trusted configs are safe.
        loss_func = eval(loss)
        tmp_loss = 0.0
        for i in range(0, len(node), 2):
            tmp_loss += loss_func(node[i], node[i + 1])
        loss_dist += lam * tmp_loss
        losses.append(tmp_loss)

    return loss_dist, losses


def _load_program_and_merge(executor,
                            place,
                            train_program,
                            config,
                            model_dir,
                            model_filename,
                            params_filename,
                            teacher_idx=None,
                            feed_target_names=None):
    """Load a teacher inference model and merge it into ``train_program``.

    Args:
        executor(paddle.static.Executor): Executor used to load the teacher.
        place: Device place passed to ``merge``.
        train_program(paddle.static.Program): Student program that the
            teacher is merged into.
        config(dict): Distillation config (not read by this function).
        model_dir(str): Teacher model directory / path prefix.
        model_filename(str|None): Teacher model file name.
        params_filename(str|None): Teacher params file name; the string
            ``'None'`` is treated as None.
        teacher_idx(int|None): 1-based teacher index for multi-teacher
            distillation, or None for a single teacher.
        feed_target_names(list[str]|None): Student feed variable names.

    Returns:
        tuple: ``(train_program, test_program, data_name_map)``. ``test_program``
        is a test clone of the student taken before merging; it is produced
        only for the single or first teacher and is None otherwise.

    Raises:
        NotImplementedError: If a model file is given without a params file.
    """
    if params_filename == 'None':
        params_filename = None

    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first."
        )

    # The teacher is loaded into its own scope, which is then handed to
    # `merge` as `teacher_scope`.
    teacher_scope = paddle.static.Scope()
    with paddle.static.scope_guard(teacher_scope):
        teacher_program, teacher_feed_target_names, _ = load_inference_model(
            model_dir,
            model_filename=model_filename,
            params_filename=params_filename,
            executor=executor)

    _remove_fetch_node(teacher_program)

    # The test program is cloned from the student before any teacher is merged.
    is_first_teacher = teacher_idx is None or teacher_idx == 1
    test_program = train_program.clone(
        for_test=True) if is_first_teacher else None

    # Teacher and student share inputs only when their feed names are the
    # same set. Map by name, not by position: equal sets may be listed in a
    # different order, and positional mapping would then swap the inputs.
    merge_feed = feed_target_names is not None and sorted(
        feed_target_names) == sorted(teacher_feed_target_names)
    if merge_feed:
        data_name_map = {name: name for name in teacher_feed_target_names}
    else:
        data_name_map = {}

    if teacher_idx is None:
        teacher_name_prefix = 'teacher_'
    else:
        teacher_name_prefix = 'teacher{}_'.format(str(teacher_idx))

    merge(
        teacher_program,
        train_program,
        data_name_map,
        place,
        teacher_scope=teacher_scope,
        name_prefix=teacher_name_prefix,
        merge_feed=merge_feed)
    return train_program, test_program, data_name_map


def _merge_teachers(executor, place, train_program, config,
                    feed_target_names):
    """Merge every teacher model named in ``config`` into ``train_program``.

    Returns ``(train_program, test_program, data_name_map)``. ``test_program``
    comes from the first teacher merged, and ``data_name_map`` from the last.
    """
    teacher_model_dir = config[
        "teacher_model_dir"] if "teacher_model_dir" in config else config[
            "teacher_model_path_prefix"]

    if not isinstance(teacher_model_dir, list):
        return _load_program_and_merge(
            executor,
            place,
            train_program,
            config,
            teacher_model_dir,
            config.get("teacher_model_filename"),
            config.get("teacher_params_filename"),
            teacher_idx=None,
            feed_target_names=feed_target_names)

    if not teacher_model_dir:
        raise ValueError("`teacher_model_dir` must not be an empty list.")

    test_program = None
    data_name_map = {}
    for tea_idx, model_dir in enumerate(teacher_model_dir):
        model_filename = config["teacher_model_filename"][
            tea_idx] if "teacher_model_filename" in config else None
        params_filename = config["teacher_params_filename"][
            tea_idx] if "teacher_params_filename" in config else None
        train_program, tea_test_program, data_name_map = _load_program_and_merge(
            executor,
            place,
            train_program,
            config,
            model_dir,
            model_filename,
            params_filename,
            teacher_idx=(tea_idx + 1),
            feed_target_names=feed_target_names)
        if tea_idx == 0:
            test_program = tea_test_program
    return train_program, test_program, data_name_map


def _build_amp_optimizer(base_optimizer, amp_config):
    """Wrap ``base_optimizer`` with ``paddle.static.amp.decorate``.

    ``custom_white_list``, ``custom_black_list`` and ``custom_black_varnames``
    are removed from a copy of ``amp_config`` to build the op lists. Every
    other key is forwarded to ``decorate``. ``amp_config`` itself is not
    modified.
    """
    amp_config = dict(amp_config)
    amp_list = paddle.static.amp.CustomOpLists(
        custom_white_list=amp_config.pop('custom_white_list', None),
        custom_black_list=amp_config.pop('custom_black_list', None),
        custom_black_varnames=amp_config.pop('custom_black_varnames', None))
    return paddle.static.amp.decorate(
        optimizer=base_optimizer,
        amp_lists=amp_list,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        **amp_config)


def build_distill_program(executor,
                          place,
                          config,
                          train_config,
                          train_program_info=None,
                          pruner=None,
                          dist_strategy=None,
                          default_distill_node_pair=None):
    """Build distillation train/test programs from inference models.

    The student is loaded from ``config['model_dir']`` (or
    ``config['model_path_prefix']``) unless ``train_program_info`` is given.
    Each teacher from ``config['teacher_model_dir']`` (or
    ``config['teacher_model_path_prefix']``; a path or a list of paths) is
    merged into the student program. The distillation loss is built from
    ``config['node']`` (falling back to ``default_distill_node_pair``) and is
    minimized by the optimizer described in ``train_config``.

    Args:
        executor(paddle.static.Executor): Executor used to load models.
        place: Device place used when merging teachers.
        config(dict): Distillation config (model paths, ``node``, ``loss``,
            ``alpha`` and optional prune keys).
        train_config(dict): Training config (``learning_rate``,
            ``optimizer_builder``, ``use_fleet``, ``amp_config``).
        train_program_info(ProgramInfo|None): Existing student program info.
        pruner: Pruner whose ``decorate`` / ``no_grad_set`` is used when
            ``config`` carries prune keys.
        dist_strategy: Strategy passed to ``fleet.distributed_optimizer``.
        default_distill_node_pair(list|None): Node pairs used when ``config``
            does not yield any.

    Returns:
        tuple: ``(train_program_info, test_program_info)``.
    """
    startup_program = paddle.static.Program()
    if train_program_info is None:
        model_path = config["model_dir"] if "model_dir" in config else config[
            "model_path_prefix"]
        train_program, feed_target_names, fetch_targets = load_inference_model(
            path_prefix=model_path, executor=executor)
        train_program = recover_inference_program(train_program)
    else:
        train_program = train_program_info.program
        feed_target_names = train_program_info.feed_target_names
        fetch_targets = train_program_info.fetch_targets

    distill_node_pair = _get_distill_node(train_program,
                                          config) or default_distill_node_pair

    train_program, test_program, data_name_map = _merge_teachers(
        executor, place, train_program, config, feed_target_names)

    # All feed nodes get stop_gradient=False so the PACT quantization
    # algorithm can be used.
    feed_names = set(data_name_map) | set(data_name_map.values())
    for var in train_program.list_vars():
        if var.name in feed_names:
            var.stop_gradient = False

    train_fetch_list = []
    with paddle.static.program_guard(train_program, startup_program):
        with paddle.utils.unique_name.guard('merge'):
            opt, learning_rate = _create_optimizer(train_config)

            if train_config.get('use_fleet'):
                opt = fleet.distributed_optimizer(opt, dist_strategy)
            elif train_config.get('amp_config') is not None:
                opt = _build_amp_optimizer(opt, train_config['amp_config'])

            distill_loss, _ = _parse_distill_loss(
                distill_node_pair,
                config.get('loss') or 'l2',  ### default loss is l2
                config.get('alpha') or 1.0)  ### default alpha is 1.0
            loss = paddle.mean(distill_loss)
            loss.stop_gradient = False

            if 'prune_params_name' in config:  ### prune
                if 'pruned_ratio' not in config and not train_config.get(
                        'use_fleet'):  ### asp
                    opt = pruner.decorate(opt)
                opt.minimize(loss)
            elif 'prune_strategy' in config:  ### unstructured prune
                opt.minimize(loss, no_grad_set=pruner.no_grad_set)
            else:
                opt.minimize(loss)

            train_fetch_list.append(loss)

    train_program_info = ProgramInfo(startup_program, train_program,
                                     feed_target_names, train_fetch_list,
                                     opt, learning_rate)
    test_program_info = ProgramInfo(startup_program, test_program,
                                    feed_target_names, fetch_targets)
    return train_program_info, test_program_info


def build_quant_program(executor, place, config, train_program_info,
                        test_program_info):
    """Insert quantization ops into the train and test programs.

    Args:
        executor(paddle.static.Executor): Executor passed to ``quant_aware``
            for the train program when PACT is enabled.
        place: Device place passed to ``quant_aware``.
        config(dict): Quantization config. ``use_pact`` (default False) is
            removed from it in place before it is passed to ``quant_aware``.
        train_program_info(ProgramInfo): Train program info; its ``program``
            is replaced by the quantized train program.
        test_program_info(ProgramInfo): Test program info; its ``program`` is
            replaced by the quantized test program.

    Returns:
        tuple: ``(train_program_info, test_program_info, config)``.
    """
    assert isinstance(config, dict), "quant config must be dict"
    scope = paddle.static.global_scope()

    # `use_pact` is consumed here; a config without it means no PACT.
    use_pact = config.pop("use_pact", False)
    if use_pact:
        act_preprocess_func = pact
        optimizer_func = get_pact_optimizer
        pact_executor = executor
    else:
        act_preprocess_func = None
        optimizer_func = None
        pact_executor = None

    # NOTE(review): PACT preprocessing is only applied to the train program
    # below; confirm the test program is meant to be quantized without it.
    test_program = quant_aware(
        test_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        for_test=True)

    train_program = quant_aware(
        train_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=pact_executor,
        for_test=False,
        return_program=True)

    train_program_info.program = train_program
    test_program_info.program = test_program
    return train_program_info, test_program_info, config


def _get_label_info(dataloader, feed_target_names):
    """Describe the label input found in the first batch of ``dataloader``.

    Only the first batch is inspected. The first key of ``batch[0]`` that is
    not in ``feed_target_names`` is treated as the label.

    Args:
        dataloader(callable): Called with no arguments; returns an iterable
            of batches whose first element is a dict of input name -> data.
        feed_target_names(list[str]): Names of the model's feed variables.

    Returns:
        dict: ``name``, ``dtype`` and ``shape`` of the label, with the first
        dimension of ``shape`` set to -1. The dict is empty when the loader
        yields nothing or every key is a feed name.
    """
    label_info = {}
    for data in dataloader():
        for key, value in data[0].items():
            if key in feed_target_names:
                continue
            # Convert once and reuse the array for both dtype and shape.
            array = np.array(value)
            label_info['name'] = key
            label_info['dtype'] = array.dtype
            label_info['shape'] = list(array.shape)
            label_info['shape'][0] = -1
            break
        break
    return label_info


def build_prune_program(executor,
                        place,
                        config,
                        train_program_info,
                        strategy,
                        patterns,
                        eval_dataloader=None):
    """Apply the pruning ``strategy`` to ``train_program_info.program``.

    ``strategy`` is matched by prefix against ``unstructure``,
    ``channel_prune``, ``asp`` and ``transformer_prune``.

    Args:
        executor(paddle.static.Executor): Executor (transformer pruning).
        place: Device place used by the pruners.
        config(dict): Pruning config for the selected strategy.
        train_program_info(ProgramInfo): Train program info; its ``program``
            is replaced for channel and transformer pruning.
        strategy(str): Pruning strategy name.
        patterns: Patterns passed to ``TransformerPruner``.
        eval_dataloader: Dataloader required by transformer pruning.

    Returns:
        tuple: ``(pruner, train_program_info)``. For ``asp`` the pruner is the
        ``paddle.static.sparsity`` module.

    Raises:
        NotImplementedError: For an unknown ``strategy`` or an unknown
            unstructured ``prune_strategy``.
    """
    if strategy.startswith('unstructure'):
        from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
        if config["prune_strategy"] is None:
            pruner = UnstructuredPruner(
                train_program_info.program,
                mode=config['prune_mode'],
                ratio=config['ratio'],
                threshold=config['threshold'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'], )
        elif config["prune_strategy"] == "gmp":
            pruner = GMPUnstructuredPruner(
                train_program_info.program,
                ratio=config['ratio'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'],
                configs=config['gmp_config'])
        else:
            # Any other value would leave no pruner to return.
            raise NotImplementedError(
                "prune_strategy must be None or \"gmp\", {} is not support".
                format(config["prune_strategy"]))
    elif strategy.startswith('channel_prune'):
        from ..prune import Pruner
        pruner = Pruner(config["criterion"])
        prune_params_name = config['prune_params_name']
        params = []
        original_shapes = {}
        ### TODO(ceci3): set default prune weight
        for param in train_program_info.program.global_block().all_parameters():
            if prune_params_name is not None and param.name in prune_params_name:
                params.append(param.name)
                original_shapes[param.name] = param.shape

        pruned_program, _, _ = pruner.prune(
            train_program_info.program,
            paddle.static.global_scope(),
            params=params,
            ratios=[config['pruned_ratio']] * len(params),
            place=place)
        _logger.info(
            "####################channel pruning##########################")
        for param in pruned_program.global_block().all_parameters():
            if param.name in original_shapes:
                _logger.info("{}, from {} to {}".format(
                    param.name, original_shapes[param.name], param.shape))
        _logger.info(
            "####################channel pruning end##########################")
        train_program_info.program = pruned_program

    elif strategy.startswith('asp'):
        from paddle.static import sparsity
        pruner = sparsity
        excluded_params_name = []
        ### TODO(ceci3): set default prune weight
        for param in train_program_info.program.global_block().all_parameters():
            if config['prune_params_name'] is not None:
                if param.name not in config['prune_params_name']:
                    excluded_params_name.append(param.name)
                else:
                    pruner.add_supported_layer(param.name)
            if "teacher_" in param.name:
                excluded_params_name.append(param.name)
        pruner.set_excluded_layers(train_program_info.program,
                                   excluded_params_name)
    elif strategy.startswith('transformer_prune'):
        from .transformer_pruner import TransformerPruner
        assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
        label_info = _get_label_info(eval_dataloader,
                                     train_program_info.feed_target_names)
        assert len(label_info) != 0, \
            "maybe something wrong in get label name from eval_dataloader, please check your eval_dataloader"
        pruner = TransformerPruner(
            executor,
            place,
            train_program_info.program,
            patterns,
            label_info,
            width_mult=(1.0 - config['pruned_ratio']),
            dataloader=eval_dataloader,
            fetch_targets=train_program_info.fetch_targets)
        pruned_program = pruner.prune()
        train_program_info.program = pruned_program
    else:
        # Report the strategy actually requested; `config` is not guaranteed
        # to carry any particular key for an unknown strategy.
        raise NotImplementedError(
            "prune strategy must be one of [\"unstructure\", \"channel_prune\", \"asp\", \"transformer_prune\"], {} is not support".
            format(strategy))

    return pruner, train_program_info


def remove_unused_var_nodes(program):
    '''
    This function is called before saving the sparse model to remove redundant nodes.
    Every op node that takes an input whose name contains ``_mask`` is removed
    from the program's graph.
    Args:
        program(paddle.static.Program): The sparse model to be saved.
    Returns:
        program(paddle.static.Program): The sparse model.
    '''
    from paddle.framework import core
    from paddle.fluid.framework import IrGraph
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    # `any` stops scanning an op's inputs at the first `_mask` match.
    removed_nodes = {
        op_node
        for op_node in graph.all_op_nodes()
        if any('_mask' in input_node.name() for input_node in op_node.inputs)
    }
    graph.safe_remove_nodes(removed_nodes)
    return graph.to_program()