# Copyright (c) 2022  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model
from ..analysis import flops

_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
    'build_distill_program', 'build_quant_program', 'build_prune_program',
    'remove_unused_var_nodes'
]


def _create_lr_scheduler(train_config):
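    # Illustrative shape of the `learning_rate` entry this helper consumes: a
    # plain float is returned as-is, while a dict is dispatched to the class of
    # the same name in paddle.optimizer.lr, e.g. (values are made up):
    #
    #   train_config['learning_rate'] = {
    #       'type': 'PiecewiseDecay',    # any scheduler class in paddle.optimizer.lr
    #       'boundaries': [4500],        # remaining keys are passed as kwargs
    #       'values': [0.005, 0.0005],
    #   }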
    if 'learning_rate' not in train_config:
        raise RuntimeError(
            'No `learning_rate` specified in the configuration file.')
    if isinstance(train_config.get('learning_rate'), float):
        return train_config.get('learning_rate')

    params = train_config.get('learning_rate')
    lr_type = params.pop('type')
    return getattr(optimizer.lr, lr_type)(**params)


def _create_optimizer(train_config):
    """create optimizer"""
    if 'optimizer_builder' not in train_config:
        train_config['optimizer_builder'] = {'optimizer': {'type': 'SGD'}}

    optimizer_builder = train_config['optimizer_builder']
    assert isinstance(
        optimizer_builder, dict
    ), "Value of 'optimizer_builder' in train_config should be dict but got {}".format(
        type(optimizer_builder))
    if 'grad_clip' in optimizer_builder:
        g_clip_params = optimizer_builder['grad_clip']
        g_clip_type = g_clip_params.pop('type')
        grad_clip = getattr(paddle.nn, g_clip_type)(**g_clip_params)
    else:
        grad_clip = None

    ### build regularization
    if 'regularizer' in optimizer_builder:
        reg_params = optimizer_builder['regularizer']
        reg_type = reg_params.pop('type')
        reg = getattr(regularizer, reg_type)(**reg_params)
    elif 'weight_decay' in optimizer_builder:
        reg = optimizer_builder.pop('weight_decay')
    else:
        reg = None

    ### build learning rate
    lr = _create_lr_scheduler(train_config)

    ### build optimizer
    optim_params = optimizer_builder['optimizer']
    optim_type = optim_params.pop('type')
    opt = getattr(optimizer, optim_type)(
        learning_rate=lr, grad_clip=grad_clip, weight_decay=reg, **optim_params)
    return opt, lr


def _find_var_from_program(program, var_name):
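    """Return the variable named `var_name` from any block of `program`;
    raise ValueError if no block contains it."""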
    for block in program.blocks:
        if block.has_var(var_name):
            return block.var(var_name)
    raise ValueError("var {} not in this program".format(var_name))


def _get_distill_node(student_program, config):
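    """Build teacher/student distillation node pairs from config['node'].

    Every configured student node is paired with its 'teacher_'-prefixed
    counterpart. If the configured names cannot be found in the student
    program, the list is assumed to already contain explicit pairs and is
    returned unchanged.
    """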
    node = config.get('node')
    if len(node) == 0:
        return None

    ### `node` is either a flat list of node names or a list of lists of names
    if isinstance(node[0], list):
        test_node = node[0][0]
    else:
        test_node = node[0]
    try:
        test_var = _find_var_from_program(student_program, test_node)
        distill_node_pair = []
        if isinstance(node[0], list):
            for n_list in node:
                tmp_node_pair = []
                for n in n_list:
                    tmp_node_pair.append('teacher_' + n)
                    tmp_node_pair.append(n)
                distill_node_pair.append(tmp_node_pair)
        else:
            for n in node:
                distill_node_pair.append('teacher_' + n)
                distill_node_pair.append(n)
        return distill_node_pair
    except Exception:
        return node


def _get_target_node(distill_node, teacher=False):
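    """Collect the unique node names referenced in `distill_node`.

    With `teacher=True`, return teacher-side names with the 'teacher_' prefix
    stripped (used to prune the teacher program); otherwise return the
    student-side names.
    """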
    tmp_nodes = set()
    if isinstance(distill_node[0], list):
        for n_list in distill_node:
            for n in n_list:
                tmp_nodes.add(n)
    else:
        for n in distill_node:
            tmp_nodes.add(n)

    targets = []
    for node in tmp_nodes:
        if teacher and 'teacher_' in node:
            tmp = node.split('teacher_')[-1]
            targets.append(tmp)
        if not teacher and 'teacher_' not in node:
            targets.append(node)

    return targets


def _parse_distill_loss(distill_node_pair,
                        distill_loss='l2',
                        distill_lambda=1.0):
    """parse distill loss config"""
    loss_dist = 0.0
    losses = {}
    if isinstance(distill_node_pair[0], str):
        assert isinstance(distill_loss, str)
        assert isinstance(distill_lambda, float)
        distill_node_pair = [distill_node_pair]
        distill_loss = [distill_loss]
        distill_lambda = [distill_lambda]

    assert len(distill_node_pair) == len(distill_loss)
    assert len(distill_node_pair) == len(distill_lambda)
    for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
                                    distill_lambda):
        tmp_loss = losses.get(loss_clas, 0.0)
        _logger.info(
            "train config.distill_node_pair: {}, loss: {}, lambda: {}".format(
                node, loss_clas, lam))
        assert len(node) % 2 == 0, \
            "distill_node_pair config wrong, the length needs to be an even number"
        for i in range(len(node) // 2):
            tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
        loss_dist += tmp_loss
        losses[loss_clas] = tmp_loss

    return loss_dist, losses


def _load_program_and_merge(executor,
                            place,
                            train_program,
                            config,
                            model_dir,
                            model_filename,
                            params_filename,
                            distill_node_pair,
                            teacher_idx=None,
                            feed_target_names=None):
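    # Load the teacher inference model into a fresh scope, prune it down to the
    # teacher-side distillation targets, then merge it into `train_program`
    # under a 'teacher_' (or 'teacher{idx}_') variable prefix.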
    scope = paddle.static.global_scope()
    new_scope = paddle.static.Scope()

    if params_filename == 'None':
        params_filename = None

    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first."
        )

    with paddle.static.scope_guard(new_scope):
        [teacher_program, teacher_feed_target_names,
         teacher_fetch_targets] = load_inference_model(
             model_dir,
             model_filename=model_filename,
             params_filename=params_filename,
             executor=executor)

    _remove_fetch_node(teacher_program)

    target_nodes = _get_target_node(distill_node_pair, True)
    teacher_program = teacher_program._prune(target_nodes)

    data_name_map = {}

    merge_feed = (
        sorted(feed_target_names) == sorted(teacher_feed_target_names))
    if merge_feed:
        for i, name in enumerate(feed_target_names):
            data_name_map[teacher_feed_target_names[i]] = name

    if teacher_idx is None:
        teacher_name_prefix = 'teacher_'
    else:
        teacher_name_prefix = 'teacher{}_'.format(str(teacher_idx))

    merge(
        teacher_program,
        train_program,
        data_name_map,
        place,
        teacher_scope=new_scope,
        name_prefix=teacher_name_prefix,
        merge_feed=merge_feed)
    return train_program, data_name_map


def build_distill_program(executor,
                          place,
                          config,
                          train_config,
                          train_program_info=None,
                          pruner=None,
                          dist_strategy=None,
                          default_distill_node_pair=None):
    """build distill program with infermodel"""
    startup_program = paddle.static.Program()
    if train_program_info is None:
        [train_program, feed_target_names, fetch_targets] = load_inference_model(
            path_prefix=config["model_dir"] if "model_dir" in config else config["model_path_prefix"],
            executor=executor)
        train_program = recover_inference_program(train_program)
    else:
        train_program = train_program_info.program
        feed_target_names = train_program_info.feed_target_names
        fetch_targets = train_program_info.fetch_targets

    distill_node_pair = _get_distill_node(train_program,
                                          config) or default_distill_node_pair

    test_program = train_program.clone(for_test=True)

    target_nodes = _get_target_node(distill_node_pair)

    def _prepend_feed(block, feed_idx, feed_target_names):
        for idx in feed_idx[::-1]:
            block._remove_op(idx)

        feed_var = block.create_var(
            name='feed',
            type=paddle.framework.core.VarDesc.VarType.FEED_MINIBATCH,
            persistable=True, )

        for i, name in enumerate(feed_target_names):
            out = block.var(name)
            block._prepend_op(
                type='feed',
                inputs={'X': [feed_var]},
                outputs={'Out': [out]},
                attrs={'col': i})

    judge_feed_pos = False
    if train_program.desc.block(0).op(0).type() != 'feed':
        judge_feed_pos = True
    if judge_feed_pos:
        feed_idx = []
        for op in train_program.global_block().ops:
            if op.type == 'feed':
                feed_idx.append(op.idx)
        _prepend_feed(train_program.global_block(), feed_idx, feed_target_names)
    train_program = train_program._prune(target_nodes)

    teacher_model_dir = config[
        "teacher_model_dir"] if "teacher_model_dir" in config else config[
            "teacher_model_path_prefix"]
    if isinstance(teacher_model_dir, list):
        for tea_idx in range(len(teacher_model_dir)):
            model_filename = config["teacher_model_filename"][
                tea_idx] if "teacher_model_filename" in config else None
            params_filename = config["teacher_params_filename"][
                tea_idx] if "teacher_params_filename" in config else None
            if tea_idx == 0:
                train_program, data_name_map = _load_program_and_merge(
                    executor,
                    place,
                    train_program,
                    config,
                    teacher_model_dir[tea_idx],
                    model_filename,
                    params_filename,
                    distill_node_pair,
                    teacher_idx=(tea_idx + 1),
                    feed_target_names=feed_target_names)
            else:
                train_program, data_name_map = _load_program_and_merge(
                    executor,
                    place,
                    train_program,
                    config,
                    teacher_model_dir[tea_idx],
                    model_filename,
                    params_filename,
                    distill_node_pair,
                    teacher_idx=(tea_idx + 1),
                    feed_target_names=feed_target_names)

    else:
        model_filename = config[
            "teacher_model_filename"] if "teacher_model_filename" in config else None
        params_filename = config[
            "teacher_params_filename"] if "teacher_params_filename" in config else None
        train_program, data_name_map = _load_program_and_merge(
            executor,
            place,
            train_program,
            config,
            teacher_model_dir,
            model_filename,
            params_filename,
            distill_node_pair,
            teacher_idx=None,
            feed_target_names=feed_target_names)
    # All feed nodes must have stop_gradient set to False so that the PACT quant algo can be applied.
    for var in train_program.list_vars():
        if var.name in data_name_map.values() or var.name in data_name_map.keys(
        ):
            var.stop_gradient = False

    train_fetch_list = []
    with paddle.static.program_guard(train_program, startup_program):
        with paddle.utils.unique_name.guard('merge'):
            optimizer, learning_rate = _create_optimizer(train_config)

            if train_config.get('use_fleet'):
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        dist_strategy)
            else:
                if train_config.get('amp_config') is not None:
                    custom_white_list = train_config['amp_config'].get(
                        'custom_white_list', None)
                    if custom_white_list is not None:
                        train_config['amp_config'].pop('custom_white_list')

                    custom_black_list = train_config['amp_config'].get(
                        'custom_black_list', None)
                    if custom_black_list is not None:
                        train_config['amp_config'].pop('custom_black_list')

                    custom_black_varnames = train_config['amp_config'].get(
                        'custom_black_varnames', None)
                    if custom_black_varnames is not None:
                        train_config['amp_config'].pop('custom_black_varnames')

                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_white_list=custom_white_list,
                        custom_black_list=custom_black_list,
                        custom_black_varnames=custom_black_varnames)
                    optimizer = paddle.static.amp.decorate(
                        optimizer=optimizer,
                        amp_lists=amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        **train_config['amp_config'])

            distill_loss, loss_dict = _parse_distill_loss(
                distill_node_pair,
                config.get('loss') or 'l2',  ### default loss is l2
                config.get('alpha') or 1.0)  ### default alpha is 1.0
            loss = paddle.mean(distill_loss)
            loss.stop_gradient = False

            if 'prune_params_name' in config:  ### prune
                if 'pruned_ratio' not in config and not train_config.get(
                        'use_fleet'):  ### asp
                    optimizer = pruner.decorate(optimizer)
                optimizer.minimize(loss)
            elif 'prune_strategy' in config:  ### unstructured prune
                optimizer.minimize(loss, no_grad_set=pruner.no_grad_set)
            else:
                optimizer.minimize(loss)

            train_fetch_list.append(loss)

    train_program_info = ProgramInfo(startup_program, train_program,
                                     feed_target_names, train_fetch_list,
                                     optimizer, learning_rate, loss_dict)
    test_program_info = ProgramInfo(startup_program, test_program,
                                    feed_target_names, fetch_targets)
    return train_program_info, test_program_info


def build_quant_program(executor, place, config, train_program_info,
                        test_program_info):
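    """Insert fake quantize/dequantize ops into the train and test programs via
    `quant_aware`. When `use_pact` is enabled, the PACT activation preprocess
    function and its optimizer are applied to the training program only."""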
    scope = paddle.static.global_scope()

    assert isinstance(config, dict), "quant config must be dict"

    use_pact = config.pop("use_pact")
    if use_pact:
        act_preprocess_func = pact
        optimizer_func = get_pact_optimizer
        pact_executor = executor
    else:
        act_preprocess_func = None
        optimizer_func = None
        pact_executor = None

    test_program = quant_aware(
        test_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        for_test=True)

    train_program = quant_aware(
        train_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=pact_executor,
        for_test=False,
        return_program=True)

    train_program_info.program = train_program
    test_program_info.program = test_program
    return train_program_info, test_program_info, config


def _get_label_info(dataloader, feed_target_names):
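    """Peek at one batch of `dataloader` and return the name, dtype and shape
    (with a -1 batch dimension) of the first field that is not a feed target,
    i.e. the label."""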
    label_info = {}
    for data in dataloader():
        if isinstance(data, list) or isinstance(data, tuple):
            data = data[0]
        for key, value in data.items():
            if key in feed_target_names:
                continue
            label_info['name'] = key
            label_info['dtype'] = np.array(value).dtype
            label_info['shape'] = list(np.array(value).shape)
            label_info['shape'][0] = -1
            break
        break
    return label_info


def _get_chn_prune_params(program):
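    """Collect the persistable inputs (weights) of non-grouped conv2d ops as
    channel-pruning candidates, together with their original shapes."""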
    params = []
    original_shapes = {}
    for block in program.blocks:
        for op in block.ops:
            if op.type == 'conv2d' and op.attr('groups') == 1:
                for inp_name in op.input_arg_names:
                    var_ = block.var(inp_name)
                    if var_.persistable is True:
                        params.append(inp_name)
                        original_shapes[inp_name] = var_.shape
    return params, original_shapes


def _get_asp_prune_params(program):
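    """Collect the persistable inputs (weights) of non-grouped conv2d, mul and
    matmul_v2 ops as candidates for ASP (automatic sparsity) pruning."""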
    params = []
    for block in program.blocks:
        for op in block.ops:
            if (op.type == 'conv2d' and op.attr('groups') == 1
                ) or op.type == 'mul' or op.type == 'matmul_v2':
                for inp_name in op.input_arg_names:
                    var_ = block.var(inp_name)
                    if var_.persistable is True:
                        params.append(inp_name)
    return params


def build_prune_program(executor,
                        place,
                        config,
                        train_program_info,
                        strategy,
                        patterns,
                        eval_dataloader=None):
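    """Create the pruner for the given `strategy` ('unstructure*',
    'channel_prune*', 'asp*' or 'transformer_prune*'). For strategies that
    rewrite the graph (channel and transformer pruning), the pruned program
    replaces `train_program_info.program` before returning."""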
    if strategy.startswith('unstructure'):
        from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
        if config["prune_strategy"] is None:
            pruner = UnstructuredPruner(
                train_program_info.program,
                mode=config['prune_mode'],
                ratio=config['ratio'],
                threshold=config['threshold'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'], )
        elif config["prune_strategy"] == "gmp":
            pruner = GMPUnstructuredPruner(
                train_program_info.program,
                ratio=config['ratio'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'],
                configs=config['gmp_config'])
    elif strategy.startswith('channel_prune'):
        from ..prune import Pruner
        pruner = Pruner(config["criterion"])
        if config['prune_params_name'] is None:
            params, original_shapes = _get_chn_prune_params(
                train_program_info.program)
        else:
            params = []
            original_shapes = {}
            for param in train_program_info.program.global_block(
            ).all_parameters():
                if config['prune_params_name'] is not None and \
                        param.name in config['prune_params_name']:
                    params.append(param.name)
                    original_shapes[param.name] = param.shape

        origin_flops = flops(train_program_info.program)

        pruned_program, _, _ = pruner.prune(
            train_program_info.program,
            paddle.static.global_scope(),
            params=params,
            ratios=[config['pruned_ratio']] * len(params) if isinstance(
                config['pruned_ratio'], float) else config['pruned_ratio'],
            place=place)
        _logger.info(
            "####################channel pruning##########################")
        for param in pruned_program.global_block().all_parameters():
            if param.name in original_shapes:
                _logger.info("{}, from {} to {}".format(
                    param.name, original_shapes[param.name], param.shape))
        _logger.info(
            "####################channel pruning end##########################")

        final_flops = flops(pruned_program)
        pruned_flops = abs(origin_flops - final_flops) / origin_flops
        _logger.info("FLOPs before pruning: {}".format(origin_flops))
        _logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
            final_flops, round(pruned_flops * 100, 2)))
        train_program_info.program = pruned_program

    elif strategy.startswith('asp'):
        from paddle.incubate import asp
        pruner = asp
        excluded_params_name = []
        if config['prune_params_name'] is None:
            config['prune_params_name'] = _get_asp_prune_params(
                train_program_info.program)

        for param in train_program_info.program.global_block().all_parameters():
            if config['prune_params_name'] is not None:
                if param.name not in config['prune_params_name']:
                    excluded_params_name.append(param.name)
                else:
                    pruner.add_supported_layer(param.name)
            if "teacher_" in param.name:
                excluded_params_name.append(param.name)
        pruner.set_excluded_layers(
            main_program=train_program_info.program,
            param_names=excluded_params_name)
    elif strategy.startswith('transformer_prune'):
        from .transformer_pruner import TransformerPruner
        assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
        label_info = _get_label_info(eval_dataloader,
                                     train_program_info.feed_target_names)
        assert len(label_info) != 0, \
            "maybe something wrong in get label name from eval_dataloader, please check your eval_dataloader"
        pruner = TransformerPruner(
            executor,
            place,
            train_program_info.program,
            patterns,
            label_info,
            width_mult=(1.0 - config['pruned_ratio']),
            dataloader=eval_dataloader,
            fetch_targets=train_program_info.fetch_targets)
        pruned_program = pruner.prune()
        train_program_info.program = pruned_program
    else:
        raise NotImplementedError(
            "strategy must be one of [\"unstructure\", \"channel_prune\", \"asp\", \"transformer_prune\"], but {} is not supported".
            format(strategy))

    return pruner, train_program_info


def remove_unused_var_nodes(program):
    '''
    This function is called before saving the sparse model to remove redundant nodes.
    Args:
        program(paddle.static.Program): The sparse model to be saved.
    Returns:
        program(paddle.static.Program): The sparse model.
    '''
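    # Ops that take a '*_mask' variable as input (typically inserted by the
    # unstructured pruner) are dropped from the graph before saving.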
    from paddle.framework import core
    from paddle.fluid.framework import IrGraph
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    removed_nodes = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            if '_mask' in input_node.name():
                removed_nodes.add(op_node)
    graph.safe_remove_nodes(removed_nodes)
    program = graph.to_program()
    return program