# Copyright (c) 2022  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model
from ..analysis import flops

_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
    'build_distill_program', 'build_quant_program', 'build_prune_program',
    'remove_unused_var_nodes'
]


def _create_lr_scheduler(train_config):
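    """Create a learning-rate value or scheduler from ``train_config``.

    ``learning_rate`` may be a plain float, or a dict whose ``type`` names a
    class in ``paddle.optimizer.lr`` and whose remaining keys are passed to it
    as keyword arguments. A minimal sketch (the scheduler and values below are
    only illustrative):

        learning_rate:
            type: PiecewiseDecay
            boundaries: [4500]
            values: [0.005, 0.0005]
    """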
    if 'learning_rate' not in train_config:
        raise RuntimeError(
            'No `learning_rate` specified in the configuration file.')
    if isinstance(train_config.get('learning_rate'), float):
        return train_config.get('learning_rate')

    params = train_config.get('learning_rate')
    lr_type = params.pop('type')
    return getattr(optimizer.lr, lr_type)(**params)


def _create_optimizer(train_config):
    """create optimizer"""
    if 'optimizer_builder' not in train_config:
        train_config['optimizer_builder'] = {'optimizer': {'type': 'SGD'}}

    optimizer_builder = train_config['optimizer_builder']
    assert isinstance(
        optimizer_builder, dict
    ), "Value of 'optimizer_builder' in train_config should be dict but got {}".format(
        type(optimizer_builder))
    if 'grad_clip' in optimizer_builder:
        g_clip_params = optimizer_builder['grad_clip']
        g_clip_type = g_clip_params.pop('type')
        grad_clip = getattr(paddle.nn, g_clip_type)(**g_clip_params)
    else:
        grad_clip = None

    ### build regularization
    if 'regularizer' in optimizer_builder:
        reg_params = optimizer_builder['regularizer']
        reg_type = reg_params.pop('type')
        reg = getattr(regularizer, reg_type)(**reg_params)
    elif 'weight_decay' in optimizer_builder:
        reg = optimizer_builder.pop('weight_decay')
    else:
        reg = None

    ### build learning rate
    lr = _create_lr_scheduler(train_config)

    ### build optimizer
    optim_params = optimizer_builder['optimizer']
    optim_type = optim_params.pop('type')
    opt = getattr(optimizer, optim_type)(learning_rate=lr,
                                         grad_clip=grad_clip,
                                         weight_decay=reg,
                                         **optim_params)
    return opt, lr


def _find_var_from_program(program, var_name):
    for block in program.blocks:
        if block.has_var(var_name):
            return block.var(var_name)
    raise ValueError("var {} not in this program".format(var_name))


def _get_distill_node(student_program, config):
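    """Turn ``config['node']`` into (teacher, student) distillation pairs.

    ``node`` is either a flat list of student variable names or a list of such
    lists. A minimal sketch with illustrative names:

        ['relu_1.tmp_0']  ->  ['teacher_relu_1.tmp_0', 'relu_1.tmp_0']

    If the names cannot be found in ``student_program``, the config is assumed
    to already hold explicit teacher/student pairs and is returned unchanged.
    """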
    node = config.get('node')
    if len(node) == 0:
        return None

    ### the type of node is list or list(list)
    if isinstance(node[0], list):
        test_node = node[0][0]
    else:
        test_node = node[0]
    try:
        test_var = _find_var_from_program(student_program, test_node)
        distill_node_pair = []
        if isinstance(node[0], list):
            for n_list in node:
                tmp_node_pair = []
                for n in n_list:
                    tmp_node_pair.append('teacher_' + n)
                    tmp_node_pair.append(n)
                distill_node_pair.append(tmp_node_pair)
        else:
            for n in node:
                distill_node_pair.append('teacher_' + n)
                distill_node_pair.append(n)
        return distill_node_pair
    except ValueError:
        return node


def _get_target_node(distill_node, teacher=False):
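    """Collect the variable names used to prune the student or teacher graph.

    With ``teacher=True`` the teacher-side names are returned with the
    ``teacher_`` prefix stripped; otherwise the student-side names are returned.
    """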
    tmp_nodes = set()
    if isinstance(distill_node[0], list):
        for n_list in distill_node:
            for n in n_list:
                tmp_nodes.add(n)
    else:
        for n in distill_node:
            tmp_nodes.add(n)

    targets = []
    for node in tmp_nodes:
        if teacher and 'teacher_' in node:
            tmp = node.split('teacher_')[-1]
            targets.append(tmp)
        if not teacher and 'teacher_' not in node:
            targets.append(node)

    return targets


def _parse_distill_loss(distill_node_pair,
                        distill_loss='l2',
                        distill_lambda=1.0):
    """parse distill loss config"""
    loss_dist = 0.0
    losses = {}
    if isinstance(distill_node_pair[0], str):
        assert isinstance(distill_loss, str)
        assert isinstance(distill_lambda, float)
        distill_node_pair = [distill_node_pair]
        distill_loss = [distill_loss]
        distill_lambda = [distill_lambda]

    assert len(distill_node_pair) == len(distill_loss)
    assert len(distill_node_pair) == len(distill_lambda)
    for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
                                    distill_lambda):
        tmp_loss = losses.get(loss_clas, 0.0)
        _logger.info("train config.distill_node_pair: {}".format(
            node, loss_clas, lam))
        assert len(node) % 2 == 0, \
            "Invalid distill_node_pair config: the length of each pair list must be an even number."
        for i in range(len(node) // 2):
            tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
        loss_dist += tmp_loss
        losses[loss_clas] = tmp_loss

    return loss_dist, losses


def _load_program_and_merge(executor,
                            place,
                            train_program,
                            config,
                            model_dir,
                            model_filename,
                            params_filename,
                            distill_node_pair,
                            teacher_idx=None,
                            feed_target_names=None):
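    """Load a teacher inference model and merge it into ``train_program``.

    The teacher program is loaded into its own scope, pruned to the teacher
    distillation nodes, and merged with a ``teacher_`` (or ``teacher{idx}_``)
    prefix so its variables do not clash with the student's.
    """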
    scope = paddle.static.global_scope()
    new_scope = paddle.static.Scope()

    if params_filename == 'None':
        params_filename = None

    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "Parameters saved in separate files are not supported. Please convert them to a single binary file first."
        )

    with paddle.static.scope_guard(new_scope):
        [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= (load_inference_model( \
            model_dir, \
            model_filename=model_filename, \
            params_filename=params_filename, \
            executor=executor))

    _remove_fetch_node(teacher_program)

    target_nodes = _get_target_node(distill_node_pair, True)
    teacher_program = teacher_program._prune(target_nodes)

    data_name_map = {}

    merge_feed = (
        sorted(feed_target_names) == sorted(teacher_feed_target_names))
    if merge_feed:
        for i, name in enumerate(feed_target_names):
            data_name_map[teacher_feed_target_names[i]] = name

    if teacher_idx is None:
        teacher_name_prefix = 'teacher_'
    else:
        teacher_name_prefix = 'teacher{}_'.format(str(teacher_idx))

    merge(
        teacher_program,
        train_program,
        data_name_map,
        place,
        teacher_scope=new_scope,
        name_prefix=teacher_name_prefix,
        merge_feed=merge_feed)
    return train_program, data_name_map


def build_distill_program(executor,
                          place,
                          config,
                          train_config,
                          train_program_info=None,
                          pruner=None,
                          dist_strategy=None,
                          default_distill_node_pair=None):
    """build distill program with infermodel"""
    startup_program = paddle.static.Program()
    if train_program_info is None:
        [train_program, feed_target_names, fetch_targets]= (load_inference_model( \
            path_prefix=config["model_dir"] if "model_dir" in config else config["model_path_prefix"], \
            executor=executor))
        train_program = recover_inference_program(train_program)
    else:
        train_program = train_program_info.program
        feed_target_names = train_program_info.feed_target_names
        fetch_targets = train_program_info.fetch_targets

    distill_node_pair = _get_distill_node(train_program,
                                          config) or default_distill_node_pair

    test_program = train_program.clone(for_test=True)

    target_nodes = _get_target_node(distill_node_pair)

    def _prepend_feed(block, feed_idx, feed_target_names):
        for idx in feed_idx[::-1]:
            block._remove_op(idx)

        feed_var = block.create_var(
            name='feed',
            type=paddle.framework.core.VarDesc.VarType.FEED_MINIBATCH,
            persistable=True, )

        for i, name in enumerate(feed_target_names):
            out = block.var(name)
            block._prepend_op(
                type='feed',
                inputs={'X': [feed_var]},
                outputs={'Out': [out]},
                attrs={'col': i})

    if train_program.desc.block(0).op(0).type() != 'feed':
        feed_idx = []
        for op in train_program.global_block().ops:
            if op.type == 'feed':
                feed_idx.append(op.idx)
        _prepend_feed(train_program.global_block(), feed_idx, feed_target_names)
    train_program = train_program._prune(target_nodes)

    teacher_model_dir = config[
        "teacher_model_dir"] if "teacher_model_dir" in config else config[
            "teacher_model_path_prefix"]
    if isinstance(teacher_model_dir, list):
        for tea_idx in range(len(teacher_model_dir)):
            model_filename = config["teacher_model_filename"][
                tea_idx] if "teacher_model_filename" in config else None
            params_filename = config["teacher_params_filename"][
                tea_idx] if "teacher_params_filename" in config else None
            if tea_idx == 0:
                train_program, data_name_map = _load_program_and_merge(
                    executor,
                    place,
                    train_program,
                    config,
                    teacher_model_dir[tea_idx],
                    model_filename,
                    params_filename,
                    distill_node_pair,
                    teacher_idx=(tea_idx + 1),
                    feed_target_names=feed_target_names)
            else:
                train_program, data_name_map = _load_program_and_merge(
                    executor,
                    place,
                    train_program,
                    config,
                    teacher_model_dir[tea_idx],
                    model_filename,
                    params_filename,
                    distill_node_pair,
                    teacher_idx=(tea_idx + 1),
                    feed_target_names=feed_target_names)

    else:
        model_filename = config[
            "teacher_model_filename"] if "teacher_model_filename" in config else None
        params_filename = config[
            "teacher_params_filename"] if "teacher_params_filename" in config else None
        train_program, data_name_map = _load_program_and_merge(
            executor,
            place,
            train_program,
            config,
            teacher_model_dir,
            model_filename,
            params_filename,
            distill_node_pair,
            teacher_idx=None,
            feed_target_names=feed_target_names)
    # All feed nodes must have stop_gradient set to False so the PACT quant algorithm can work.
    for var in train_program.list_vars():
        if var.name in data_name_map.values() or var.name in data_name_map:
            var.stop_gradient = False

    train_fetch_list = []
    with paddle.static.program_guard(train_program, startup_program):
        with paddle.utils.unique_name.guard('merge'):
            optimizer, learning_rate = _create_optimizer(train_config)

            if train_config.get('use_fleet'):
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        dist_strategy)
            else:
                if train_config.get('amp_config') is not None:
                    custom_white_list = train_config['amp_config'].get(
                        'custom_white_list', None)
                    if custom_white_list is not None:
                        train_config['amp_config'].pop('custom_white_list')

                    custom_black_list = train_config['amp_config'].get(
                        'custom_black_list', None)
                    if custom_black_list is not None:
                        train_config['amp_config'].pop('custom_black_list')

                    custom_black_varnames = train_config['amp_config'].get(
                        'custom_black_varnames', None)
                    if custom_black_varnames is not None:
                        train_config['amp_config'].pop('custom_black_varnames')

                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_white_list=custom_white_list,
                        custom_black_list=custom_black_list,
                        custom_black_varnames=custom_black_varnames)
                    optimizer = paddle.static.amp.decorate(
                        optimizer=optimizer,
                        amp_lists=amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        **train_config['amp_config'])

            distill_loss, loss_dict = _parse_distill_loss(
                distill_node_pair,
                config.get('loss') or 'l2',  ### default loss is l2
                config.get('alpha') or 1.0)  ### default alpha is 1.0
            loss = paddle.mean(distill_loss)
            loss.stop_gradient = False

            if 'prune_params_name' in config:  ### prune
                if 'pruned_ratio' not in config and not train_config.get(
                        'use_fleet'):  ### asp
                    optimizer = pruner.decorate(optimizer)
                optimizer.minimize(loss)
            elif 'prune_strategy' in config:  ###unstructure prune
                optimizer.minimize(loss, no_grad_set=pruner.no_grad_set)
            else:
                optimizer.minimize(loss)

            train_fetch_list.append(loss)

    train_program_info = ProgramInfo(startup_program, train_program,
                                     feed_target_names, train_fetch_list,
                                     optimizer, learning_rate, loss_dict)
    test_program_info = ProgramInfo(startup_program, test_program,
                                    feed_target_names, fetch_targets)
    return train_program_info, test_program_info


def build_quant_program(executor, place, config, train_program_info,
                        test_program_info):
    scope = paddle.static.global_scope()

    assert isinstance(config, dict), "quant config must be dict"

    use_pact = config.pop("use_pact")
    if use_pact:
        act_preprocess_func = pact
        optimizer_func = get_pact_optimizer
        pact_executor = executor
    else:
        act_preprocess_func = None
        optimizer_func = None
        pact_executor = None

    test_program = quant_aware(
        test_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        for_test=True)

    train_program = quant_aware(
        train_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=pact_executor,
        for_test=False,
        return_program=True)

    train_program_info.program = train_program
    test_program_info.program = test_program
    return train_program_info, test_program_info, config


def _get_label_info(dataloader, feed_target_names):
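    """Peek at one batch of ``dataloader`` to infer the label field.

    Returns a dict such as ``{'name': 'label', 'dtype': int64, 'shape': [-1, 1]}``
    (values here are illustrative); the first key that is not a feed target is
    treated as the label.
    """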
    label_info = {}
    for data in dataloader():
        if isinstance(data, list) or isinstance(data, tuple):
            data = data[0]
        for key, value in data.items():
            if key in feed_target_names:
                continue
            label_info['name'] = key
            label_info['dtype'] = np.array(value).dtype
            label_info['shape'] = list(np.array(value).shape)
            label_info['shape'][0] = -1
            break
        break
    return label_info


def _get_chn_prune_params(program):
    params = []
    original_shapes = {}
    for block in program.blocks:
        for op in block.ops:
            if op.type == 'conv2d' and op.attr('groups') == 1:
                for inp_name in op.input_arg_names:
                    var_ = block.var(inp_name)
                    if var_.persistable is True:
                        params.append(inp_name)
                        original_shapes[inp_name] = var_.shape
    return params, original_shapes


def _get_asp_prune_params(program):
    params = []
    for block in program.blocks:
        for op in block.ops:
            if (op.type == 'conv2d' and op.attr('groups') == 1
                ) or op.type == 'mul' or op.type == 'matmul_v2':
                for inp_name in op.input_arg_names:
                    var_ = block.var(inp_name)
                    if var_.persistable is True:
                        params.append(inp_name)
    return params


def build_prune_program(executor,
                        place,
                        config,
                        train_program_info,
                        strategy,
                        patterns,
                        eval_dataloader=None):
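    """Build the pruner and pruned train program for the given strategy.

    ``strategy`` is expected to start with one of 'unstructure',
    'channel_prune', 'asp' or 'transformer_prune'; each branch below creates
    the matching pruner and, where applicable, rewrites
    ``train_program_info.program`` in place.
    """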
    if strategy.startswith('unstructure'):
        from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
        if config["prune_strategy"] is None:
            pruner = UnstructuredPruner(
                train_program_info.program,
                mode=config['prune_mode'],
                ratio=config['ratio'],
                threshold=config['threshold'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'], )
        elif config["prune_strategy"] == "gmp":
            pruner = GMPUnstructuredPruner(
                train_program_info.program,
                ratio=config['ratio'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'],
                configs=config['gmp_config'])
    elif strategy.startswith('channel_prune'):
        from ..prune import Pruner
        pruner = Pruner(config["criterion"])
        if config['prune_params_name'] is None:
            params, original_shapes = _get_chn_prune_params(
                train_program_info.program)
        else:
            params = []
            original_shapes = {}
            for param in train_program_info.program.global_block(
            ).all_parameters():
                if param.name in config['prune_params_name']:
                    params.append(param.name)
                    original_shapes[param.name] = param.shape

        origin_flops = flops(train_program_info.program)

        pruned_program, _, _ = pruner.prune(
            train_program_info.program,
            paddle.static.global_scope(),
            params=params,
            ratios=[config['pruned_ratio']] * len(params)
            if isinstance(config['pruned_ratio'], float) else
            config['pruned_ratio'],
            place=place)
        _logger.info(
            "####################channel pruning##########################")
        for param in pruned_program.global_block().all_parameters():
            if param.name in original_shapes:
                _logger.info("{}, from {} to {}".format(
                    param.name, original_shapes[param.name], param.shape))
        _logger.info(
            "####################channel pruning end##########################")

        final_flops = flops(pruned_program)
        pruned_flops = abs(origin_flops - final_flops) / origin_flops
        _logger.info("FLOPs before pruning: {}".format(origin_flops))
        _logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
            final_flops, round(pruned_flops * 100, 2)))
        train_program_info.program = pruned_program

    elif strategy.startswith('asp'):
        from paddle.incubate import asp
        pruner = asp
        excluded_params_name = []
        if config['prune_params_name'] is None:
            config['prune_params_name'] = _get_asp_prune_params(
                train_program_info.program)

        for param in train_program_info.program.global_block().all_parameters():
            if config['prune_params_name'] is not None:
                if param.name not in config['prune_params_name']:
                    excluded_params_name.append(param.name)
                else:
                    pruner.add_supported_layer(param.name)
            if "teacher_" in param.name:
                excluded_params_name.append(param.name)
        pruner.set_excluded_layers(train_program_info.program,
                                   excluded_params_name)
    elif strategy.startswith('transformer_prune'):
        from .transformer_pruner import TransformerPruner
        assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
        label_info = _get_label_info(eval_dataloader,
                                     train_program_info.feed_target_names)
        assert len(label_info) != 0, \
            "Failed to get the label info from eval_dataloader; please check your eval_dataloader."
        pruner = TransformerPruner(
            executor,
            place,
            train_program_info.program,
            patterns,
            label_info,
            width_mult=(1.0 - config['pruned_ratio']),
            dataloader=eval_dataloader,
            fetch_targets=train_program_info.fetch_targets)
        pruned_program = pruner.prune()
        train_program_info.program = pruned_program
    else:
        raise NotImplementedError(
            "strategy must start with one of ['unstructure', 'channel_prune', 'asp', 'transformer_prune'], but got {}.".
            format(strategy))

    return pruner, train_program_info


def remove_unused_var_nodes(program):
    '''
    This function is called before saving the sparse model to remove redundant nodes.
    Args:
        program(paddle.static.Program): The sparse model to be saved.
    Returns:
        program(paddle.static.Program): The sparse model.
    '''
    from paddle.framework import core
    from paddle.fluid.framework import IrGraph
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    removed_nodes = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            if '_mask' in input_node.name():
                removed_nodes.add(op_node)
    graph.safe_remove_nodes(removed_nodes)
    program = graph.to_program()
    return program