asp.py 41.9 KB
Newer Older
1 2
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.
3
#
4 5 6
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10 11 12 13 14 15 16 17 18
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Functions for Auto SParsity (ASP) training and inference.
"""

M
minghaoBD 已提交
19
import os
20 21 22
import copy
import numpy as np
import paddle
23
from paddle.fluid.framework import dygraph_only
24
from paddle.fluid import global_scope, program_guard, layers
25 26
from paddle.fluid.initializer import ConstantInitializer
from paddle.fluid.contrib import sparsity
27
from paddle.fluid import core
28 29 30
from paddle.fluid.contrib.sparsity.supported_layer_list import (
    supported_layers_and_prune_func_map,
)
31
from paddle.fluid.contrib.sparsity.supported_layer_list import _default_pruning
M
minghaoBD 已提交
32 33 34

# Aliases taken from `core.op_proto_and_checker_maker`: the `OpRole` enum and
# the attribute name returned by `kOpRoleAttrName()`.
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

# Public API exported by `from ... import *`.
__all__ = [
    'decorate',
    'prune_model',
    'set_excluded_layers',
    'reset_excluded_layers',
]


44
def set_excluded_layers(param_names, main_program=None):
    r"""
    Register parameter names whose layers must NOT be pruned into sparse weights.

    Args:
        param_names (list of str): Names of the parameters to exclude from ASP pruning.
        main_program (Program, optional): Program with model definition and its parameters.
                                          If None is given, `paddle.static.default_main_program()`
                                          is used. Default is None.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Need to set excluded layers before calling decorate
                paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()])

                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    # Set up excluded layers out of the ASP workflow.
                    # Please note, excluded_layers must be set before calling optimizer.minimize().
                    paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program)

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    optimizer = paddle.static.amp.decorate(optimizer)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    # Fall back to the global default program when the caller did not pick one.
    program = (
        paddle.static.default_main_program()
        if main_program is None
        else main_program
    )
    ASPHelper.set_excluded_layers(param_names=param_names, main_program=program)
129 130 131 132


def reset_excluded_layers(main_program=None):
    r"""
    Clear the excluded-layers setting recorded for :attr:`main_program`. When
    :attr:`main_program` is None, the excluded layers of every program are cleared.

    Args:
        main_program (Program, optional): Program with model definition and its parameters.
                                          If None is given, then this function resets all excluded_layers.
                                          Default is None.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Need to set excluded layers before calling decorate
                paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()])
                # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow.
                # Please note, reset_excluded_layers also must be called before calling paddle.incubate.asp.decorate().
                paddle.incubate.asp.reset_excluded_layers()

                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    # Set up excluded layers out of the ASP workflow.
                    # Please note, excluded_layers must be set before calling optimizer.minimize().
                    paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program)
                    # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow.
                    # Please note, reset_excluded_layers also must be called before calling optimizer.minimize().
                    paddle.incubate.asp.reset_excluded_layers(main_program)

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    optimizer = paddle.static.amp.decorate(optimizer)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    # None is forwarded as-is: the helper interprets it as "reset every program".
    ASPHelper.reset_excluded_layers(main_program)


def decorate(optimizer):
    r"""
    Wrap the given optimizer as an OptimizerWithSparsityGuarantee.

    In dynamic graph mode, ASP creates mask variables for the supported parameters
    while decorating. In static graph mode, ASP creates mask variables and inserts
    the necessary ops when minimize() is called.

    Args:
        optimizer (Optimizer): An Optimizer used for training.
    Returns:
        OptimizerWithSparsityGuarantee: A wrapper for ASP to decorate `minimize` function of the given optimizer.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which
                # will apply necessary masking operations for ASP workflow.
                # In dynamic graph mode, ASP would create related mask variables during decoration.
                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    # In static graph mode, ASP creates related mask variables
                    # during minimize().
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    # All the work happens in the helper; this is the public entry point.
    return ASPHelper.decorate(optimizer)


306
def prune_model(model, n=2, m=4, mask_algo='mask_1d', with_mask=True):
    r"""
    Prune parameters of supported layers in :attr:`model` via the mask generation
    function selected by :attr:`mask_algo`. This function supports both training and
    inference, controlled by :attr:`with_mask`. If :attr:`with_mask` is True, the ASP
    mask Variables related to the parameters are pruned as well; otherwise only the
    parameters are pruned.

    *Note*: (Static graph mode) If calling this function with :attr:`with_mask`, you must call
    `OptimizerWithSparsityGuarantee.minimize` and run initialization (`exe.run(startup_program)`)
    beforehand, so that the mask Variables can be found. Typically set `with_mask` to True for
    training (after `OptimizerWithSparsityGuarantee.minimize` has been called) and to False for
    inference only. To obtain an OptimizerWithSparsityGuarantee, see `paddle.incubate.asp.decorate()`.

    Args:
        model (Program|nn.Layer): Program with model definition and its parameters, or an object of `paddle.nn.Layer`.
        n (int, optional): n of `n:m` sparse pattern. Default is 2.
        m (int, optional): m of `n:m` sparse pattern. Default is 4.
        mask_algo (string, optional): The name of the function used to generate the sparse mask. Default is `mask_1d`.
                                      The valid inputs are 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'.
        with_mask (bool, optional): Whether to also prune the mask Variables related to parameters.
                                    True prunes them too, False does not. Default is True.
    Returns:
        dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding sparse mask array.
    Raises:
        AssertionError: If :attr:`mask_algo` is not one of the valid names.
        TypeError: If :attr:`model` is neither a `paddle.nn.Layer` nor a `paddle.static.Program`.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle
                import numpy as np

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                my_layer = MyLayer()
                loss_fn = paddle.nn.MSELoss(reduction='mean')

                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which
                # will apply necessary masking operations for ASP workflow.
                # In dynamic graph mode, ASP would create related mask variables during decoration.
                optimizer = paddle.incubate.asp.decorate(optimizer)

                # Must call paddle.incubate.asp.decorate() first before calling paddle.incubate.asp.prune_model()
                paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best')

                for i in range(10):
                    imgs = paddle.to_tensor(
                        np.random.randn(64, 3, 32, 32),
                        dtype='float32', stop_gradient=False)
                    labels = paddle.to_tensor(
                        np.random.randint(10, size=(64, 1)),
                        dtype='float32', stop_gradient=False)
                    output = my_layer(imgs)
                    loss = loss_fn(output, labels)
                    loss.backward()
                    optimizer.step()
                    optimizer.clear_grad()

        2. Usage of Static Graph

            .. code-block:: python

                import paddle
                import numpy as np

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 32, 32])
                    label = paddle.static.data(name='label', shape=[None, 1])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    # In static graph mode, ASP creates related mask variables
                    # during minimize().
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)

                device = paddle.device.get_device()
                place = paddle.set_device(device)

                exe = paddle.static.Executor(place)
                exe.run(startup_program)

                # Must call exe.run(startup_program) first before calling paddle.incubate.asp.prune_model()
                paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best')
                # it is also accepted to call
                # paddle.incubate.asp.prune_model(main_program, mask_algo='mask_2d_best')

                for i in range(10):
                    imgs = np.random.randn(64, 3, 32, 32).astype('float32')
                    labels = np.random.randint(10, size=(64, 1)).astype('float32')
                    exe.run(main_program, feed={'data':imgs, 'label':labels})
    """
    # Pruned weights/masks are written back on the current device's place.
    device = paddle.device.get_device()
    place = paddle.set_device(device)

    MaskAlgo_mapping = {
        'mask_1d': sparsity.MaskAlgo.MASK_1D,
        'mask_2d_greedy': sparsity.MaskAlgo.MASK_2D_GREEDY,
        'mask_2d_best': sparsity.MaskAlgo.MASK_2D_BEST,
    }
    assert (
        mask_algo in MaskAlgo_mapping
    ), 'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]'

    if isinstance(model, paddle.nn.Layer):
        prune_func = ASPHelper.prune_model_by_layer
    elif isinstance(model, paddle.static.Program):
        prune_func = ASPHelper.prune_model_by_program
        # For a sharded Program (sharding_degree > 1) on a CUDA build, prune on the
        # GPU selected for this process via FLAGS_selected_gpus. A distributed_info_
        # without a "sharding_degree" entry is treated as "not sharded" instead of
        # raising KeyError.
        if (
            hasattr(model, "distributed_info_")
            and model.distributed_info_.get("sharding_degree", 1) > 1
            and paddle.fluid.is_compiled_with_cuda()
        ):
            gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
            place = paddle.CUDAPlace(gpu_id)
    else:
        raise TypeError(
            "model should be paddle.nn.Layer or paddle.static.Program, but got {}".format(
                type(model)
            )
        )

    return prune_func(
        place,
        model,
        n=n,
        m=m,
        mask_algo=MaskAlgo_mapping[mask_algo],
        with_mask=with_mask,
    )
479 480


481
class ProgramASPInfo:
    r"""
    Per-Program container for ASP (Auto SParsity) bookkeeping.

    It tracks three private fields:

    1. ``__mask_vars`` (dict): parameter name -> its sparse mask Variable object,
       created by `ASPHelper._create_mask_variables`.
    2. ``__masks`` (dict): parameter name -> its sparse mask NumPy array,
       recorded when the model is pruned (`ASPHelper.prune_model_by_program` /
       `ASPHelper.prune_model_by_layer`).
    3. ``__excluded_layers`` (list): names of layers that must stay out of the ASP workflow.
    """

    def __init__(self):
        self.__mask_vars = dict()
        self.__masks = dict()
        self.__excluded_layers = list()

    def update_mask_vars(self, param_name, var):
        """Store ``var`` as the mask Variable of parameter ``param_name``."""
        self.__mask_vars[param_name] = var

    def update_masks(self, param_name, var):
        """Store ``var`` as the mask array of parameter ``param_name``."""
        self.__masks[param_name] = var

    def update_excluded_layers(self, param_names):
        """Append a deep copy of ``param_names`` to the excluded-layer list."""
        # Deep copy so later mutation of the caller's list does not leak in.
        self.__excluded_layers += copy.deepcopy(param_names)

    def reset_excluded_layers(self):
        """Replace the excluded-layer list with a fresh empty list."""
        self.__excluded_layers = list()

    @property
    def mask_vars(self):
        """dict: parameter name -> mask Variable (the live internal dict)."""
        return self.__mask_vars

    @property
    def masks(self):
        """dict: parameter name -> mask array (the live internal dict)."""
        return self.__masks

    @property
    def excluded_layers(self):
        """list: excluded layer names (the live internal list)."""
        return self.__excluded_layers


519
class ASPHelper:
520
    r"""
521
    ASPHelper is a collection of Auto SParsity (ASP) functions to enable
522 523 524 525 526

    1. training models with weights in 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 from scratch.
    2. pruning well-trained models into 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 for fine-tuning.
    """

527 528
    MASK_APPENDDED_NAME = 'asp_mask'
    PADDLE_WEIGHT_SUFFIX = "w_"
529 530 531 532

    __asp_info = {}

    @classmethod
    def set_excluded_layers(cls, param_names, main_program):
        r"""
        Implementation behind the module-level `set_excluded_layers`; see that
        function for full documentation.

        Appends :attr:`param_names` to the excluded layers recorded for
        :attr:`main_program` (creating its ASP info entry if needed).
        """
        cls._get_program_asp_info(main_program).update_excluded_layers(
            param_names
        )

    @classmethod
    def reset_excluded_layers(cls, main_program=None):
        r"""
        Implementation behind the module-level `reset_excluded_layers`; see that
        function for full documentation.

        With a program given, only that program's excluded layers are cleared;
        with None, every program tracked so far is cleared.
        """
        if main_program is not None:
            cls._get_program_asp_info(main_program).reset_excluded_layers()
            return
        for asp_info in cls.__asp_info.values():
            asp_info.reset_excluded_layers()

    @staticmethod
    def decorate(optimizer):
        r"""
        Implementation behind the module-level `decorate`; see that function for
        full documentation.

        In dynamic graph mode the mask variables are created immediately for the
        optimizer's parameter list; in both modes the optimizer is returned wrapped
        in an OptimizerWithSparsityGuarantee.
        """
        if paddle.in_dynamic_mode():
            # Dynamic graph has no user-supplied Program, so the default main and
            # startup programs are used: they act as the program_guard pair for
            # creating masks, and the main program is the key that maps to this
            # model's ASP information (e.g. its mask variables).
            ASPHelper._create_mask_variables(
                paddle.static.default_main_program(),
                paddle.static.default_startup_program(),
                optimizer._parameter_list,
            )
        return OptimizerWithSparsityGuarantee(optimizer)

    @classmethod
    def prune_model_by_program(
        cls,
        place,
        main_program=None,
        n=2,
        m=4,
        mask_algo=sparsity.MaskAlgo.MASK_1D,
        with_mask=True,
    ):
        r"""
        Static-graph implementation of the module-level `prune_model`; see that
        function for full documentation.

        Every supported parameter of :attr:`main_program` is read from the global
        scope, pruned with its registered prune function, and written back on
        :attr:`place`. When :attr:`with_mask` is True the matching mask Variable
        (which must already exist in the global scope) is overwritten too.

        Returns:
            dict: A copy of all masks recorded for the program, keyed by parameter name.
        """
        program = (
            paddle.static.default_main_program()
            if main_program is None
            else main_program
        )

        asp_info = cls._get_program_asp_info(program)
        for param in program.global_block().all_parameters():
            if not ASPHelper._is_supported_layer(program, param.name):
                continue

            weight_tensor = global_scope().find_var(param.name).get_tensor()
            dense_weight = np.array(weight_tensor)

            prune_func = ASPHelper._get_prune_func_by_name(param.name)
            pruned_weight, sparse_mask = prune_func(
                dense_weight, m, n, mask_algo, param.name
            )
            # Keep the stored weight in its original dtype.
            weight_tensor.set(pruned_weight.astype(dense_weight.dtype), place)

            if with_mask:
                mask_name = ASPHelper._get_mask_name(param.name)
                mask_var = global_scope().find_var(mask_name)
                assert mask_var is not None, (
                    'Cannot find {} variable, please call optimizer.minimize ('
                    'paddle.sparsity.decorate(optimizer).minimize(loss)'
                    ' and initialization (exe.run(startup_program)) first!'.format(
                        mask_name
                    )
                )
                mask_tensor = mask_var.get_tensor()
                # Match the dtype of the existing mask tensor before writing it.
                sparse_mask = sparse_mask.astype(np.array(mask_tensor).dtype)
                mask_tensor.set(sparse_mask, place)

            asp_info.update_masks(param.name, sparse_mask)

        return asp_info.masks.copy()

620
    @classmethod
    def prune_model_by_layer(
        cls,
        place,
        layer,
        n=2,
        m=4,
        mask_algo=sparsity.MaskAlgo.MASK_1D,
        with_mask=True,
    ):
        r"""
        `paddle.nn.Layer` implementation of the module-level `prune_model`; see that
        function for full documentation.

        In dynamic graph mode each supported parameter of :attr:`layer` is pruned
        in place via `set_value`, and (with :attr:`with_mask`) its mask Variable,
        created earlier by `decorate`, is overwritten. In static graph mode the
        Program owning the layer's parameters is located and pruning is delegated
        to `prune_model_by_program`.

        Returns:
            dict: A copy of all masks recorded for the program, keyed by parameter name.
        """
        if paddle.in_dynamic_mode():
            # Dynamic graph keys its ASP info on the default main program.
            main_program = paddle.static.default_main_program()
            asp_info = cls._get_program_asp_info(main_program)

            for param in layer.parameters():
                if not ASPHelper._is_supported_layer(main_program, param.name):
                    continue

                weight_nparray = param.numpy()

                prune_func = ASPHelper._get_prune_func_by_name(param.name)

                weight_pruned_nparray, weight_sparse_mask = prune_func(
                    weight_nparray, m, n, mask_algo, param.name
                )

                weight_pruned_nparray = weight_pruned_nparray.astype(
                    weight_nparray.dtype
                )
                param.set_value(weight_pruned_nparray)

                if with_mask:
                    weight_mask_param = asp_info.mask_vars.get(
                        param.name, None
                    )
                    assert weight_mask_param is not None, (
                        'Cannot find {} variable, please call sparsity.decorate() to'
                        ' decorate your optimizer first!'.format(
                            ASPHelper._get_mask_name(param.name)
                        )
                    )
                    weight_mask_param.set_value(weight_sparse_mask)

                asp_info.update_masks(param.name, weight_sparse_mask)

            return asp_info.masks.copy()

        # Static graph: obtain the Program from the layer's first parameter
        # (assumes all of the layer's parameters belong to the same Program).
        # Taking the first one directly avoids walking the whole parameter list.
        first_param = next(iter(layer.parameters()), None)
        assert (
            first_param is not None
        ), 'Cannot get paddle.static.Program from Paddle.nn.Layer.'
        target_program = first_param.block.program
        return ASPHelper.prune_model_by_program(
            place,
            target_program,
            n=n,
            m=m,
            mask_algo=mask_algo,
            with_mask=with_mask,
        )
684

685 686 687 688 689 690 691 692 693 694
    @staticmethod
    def _get_mask_name(param_name):
        r"""
        Build the name of the ASP mask Variable associated with a parameter.

        The mask name is :attr:`param_name`, a dot, and then
        ``ASPHelper.MASK_APPENDDED_NAME``.

        Args:
            param_name (string): The name of the parameter.
        Returns:
            string: The mask name of :attr:`param_name`.
        """
        return ".".join((param_name, ASPHelper.MASK_APPENDDED_NAME))

    @staticmethod
    def _get_not_ASP_relevant_vars(main_program):
        r"""
        Collect the parameters of :attr:`main_program` that are not ASP masks.

        A parameter counts as an ASP mask when one of the dot-separated parts of
        its name equals ``ASPHelper.MASK_APPENDDED_NAME``.

        Args:
            main_program (Program): Program with model definition and its parameters.
        Returns:
            list: Parameter Variables of the global block of :attr:`main_program`,
            in their original order, with ASP mask Variables filtered out.
        """
        mask_token = ASPHelper.MASK_APPENDDED_NAME
        return [
            param
            for param in main_program.global_block().all_parameters()
            if mask_token not in param.name.split('.')
        ]

    @classmethod
    def _get_program_asp_info(cls, main_program):
        r"""
        Return the ``ProgramASPInfo`` record kept for :attr:`main_program`.

        A fresh ``ProgramASPInfo`` is created and cached the first time a given
        program is seen; later calls return that same cached object.

        Args:
            main_program (Program): The program whose ASP bookkeeping is requested.
        Returns:
            ProgramASPInfo: The ASP information record of :attr:`main_program`.
        """
        registry = cls.__asp_info
        if main_program not in registry:
            registry[main_program] = ProgramASPInfo()
        return registry[main_program]

    @classmethod
    def _is_supported_layer(cls, main_program, param_name):
        r"""
        Verify if given :attr:`param_name` is supported by ASP.

        A parameter is reported as supported when all of the following hold:

        1. It is not an ASP mask Variable (no dot-separated part of its name
           equals ``ASPHelper.MASK_APPENDDED_NAME``).
        2. No excluded-layer entry registered for :attr:`main_program` occurs as
           a substring of :attr:`param_name`.
        3. Either :attr:`param_name` itself is a key of
           ``supported_layers_and_prune_func_map``, or the name has the form
           ``<prefix>.<suffix>...`` where ``<suffix>`` contains
           ``ASPHelper.PADDLE_WEIGHT_SUFFIX`` and either ``<prefix>`` or
           ``<prefix>`` with its trailing ``_<idx>`` removed is a key of that map.

        Args:
            main_program (Program): Program whose ASP info (excluded layers) is consulted.
            param_name (string): The name of parameter.
        Returns:
            bool: True if it is supported, else False.
        Examples:
            .. code-block:: python

              import paddle
              from paddle.static.sparsity.asp import ASPHelper

              paddle.enable_static()

              main_program = paddle.static.Program()
              startup_program = paddle.static.Program()

              with paddle.static.program_guard(main_program, startup_program):
                  input_data = paddle.static.data(name='data', shape=[None, 128])
                  fc = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None)

              for param in main_program.global_block().all_parameters():
                  ASPHelper._is_supported_layer(main_program, param.name)
              # fc_0.w_0 -> True
              # fc_0.b_0 -> False
        """
        param_name_list = param_name.split('.')

        # Mask Variables are never pruned themselves.
        if ASPHelper.MASK_APPENDDED_NAME in param_name_list:
            return False

        # NOTE: this is a substring test, so an excluded entry such as "fc_1"
        # also excludes e.g. "fc_10.w_0".
        for layer in cls._get_program_asp_info(main_program).excluded_layers:
            if layer in param_name:
                return False

        if param_name in supported_layers_and_prune_func_map:
            return True

        # The parameter's name is neither in *.* format nor added to
        # supported_layers_and_prune_func_map, return False.
        if len(param_name_list) == 1:
            return False

        param_name_no_weight_suffix = param_name_list[0]
        param_type_suffix = param_name_list[1]

        if ASPHelper.PADDLE_WEIGHT_SUFFIX not in param_type_suffix:
            return False

        # Strip the trailing "_<idx>" (e.g. "fc_0" -> "fc"). When the name has no
        # underscore keep it whole: slicing with rfind()'s -1 would silently drop
        # the last character and could match an unrelated layer ("fcx" -> "fc").
        head, sep, _ = param_name_no_weight_suffix.rpartition('_')
        layer_name = head if sep else param_name_no_weight_suffix

        return (
            param_name_no_weight_suffix in supported_layers_and_prune_func_map
            or layer_name in supported_layers_and_prune_func_map
        )

779 780 781 782 783 784
    @classmethod
    def _get_prune_func_by_name(cls, param_name):
        r"""
        Return the pruning function registered for :attr:`param_name`.

        Keys of ``supported_layers_and_prune_func_map`` are tried in this order:

        1. the full parameter name (e.g. ``fc_0.w_0``);
        2. the name before the first ``.`` (e.g. ``fc_0``);
        3. that name with its trailing ``_<idx>`` removed (e.g. ``fc``).

        ``_default_pruning`` is returned when none of them is registered.

        Args:
            param_name (string): The name of the parameter.
        Returns:
            callable: The pruning function to use for :attr:`param_name`.
        """
        func = supported_layers_and_prune_func_map.get(param_name, None)
        if func is not None:
            return func

        param_name_no_weight_suffix = param_name.split('.')[0]
        func = supported_layers_and_prune_func_map.get(
            param_name_no_weight_suffix, None
        )
        if func is not None:
            return func

        # Strip the trailing "_<idx>". Without an underscore keep the name whole;
        # slicing with rfind()'s -1 would drop the last character instead and
        # could pick the pruning function of an unrelated layer.
        head, sep, _ = param_name_no_weight_suffix.rpartition('_')
        layer_name = head if sep else param_name_no_weight_suffix
        return supported_layers_and_prune_func_map.get(
            layer_name, _default_pruning
        )

796
    @classmethod
    def _minimize(
        cls,
        optimizer,
        loss,
        main_program=None,
        startup_program=None,
        parameter_list=None,
        no_grad_set=None,
    ):
        r"""
        Decorate `Optimizer.minimize` so that supported weights stay sparse.

        The work happens in three stages:

        1. Run :attr:`optimizer`.minimize(:attr:`loss`).
        2. Create sparse mask Tensors for the supported layers in :attr:`main_program`.
        3. Append masking ops after the parameter updates.

        *Note*: Use `ASP.decorate` instead when training distributedly with `Fleet`,
        because `Fleet.minimize()` performs an invisible graph optimization after
        which the training graph can no longer be modified.

        Args:
            optimizer (Optimizer): The Optimizer used for training.
            loss (Variable): A Variable containing the value to minimize.
            main_program (Program, optional): Program with model definition and its parameters. Default is `loss.block.program`.
            startup_program (Program, optional): Program for initializing parameters in `parameter_list`. Default is `paddle.static.default_startup_program()`.
            parameter_list (Iterable, optional): Iterable of `Variable` or `Variable.name` to update to minimize `loss`. The default value is None, at this time all parameters will be updated.
            no_grad_set (set, optional): Set of `Variable` or `Variable.name` that don't need to be updated. The default value is None.
        Returns:
            list: operators from :attr:`optimizer`.minimize(:attr:`loss`).
            list: pairs of parameters and their gradients.
        """
        main_program = (
            loss.block.program if main_program is None else main_program
        )
        startup_program = (
            paddle.static.default_startup_program()
            if startup_program is None
            else startup_program
        )

        optimizer_ops, params_and_grads = optimizer.minimize(
            loss, startup_program, parameter_list, no_grad_set=no_grad_set
        )

        # Only the parameter half of each (param, grad) pair is needed for masking.
        updated_params = [pair[0] for pair in params_and_grads]
        cls._create_mask_variables(main_program, startup_program, updated_params)
        cls._insert_sparse_mask_ops(main_program, updated_params)
        return optimizer_ops, params_and_grads

    @classmethod
    @dygraph_only
    def _step(cls, optimizer):
        r"""
        Decorate `Optimizer.step` for dynamic-graph mode.

        Two stages are performed:

        1. Run :attr:`optimizer`.step().
        2. Multiply every parameter that owns a sparse mask by that mask (under
           ``no_grad``), using the ASP info of the default main program.

        *Note*: Use `ASP.decorate` instead when training distributedly with `Fleet`,
        because `Fleet.minimize()` performs an invisible graph optimization after
        which the training graph can no longer be modified.

        Args:
            optimizer (Optimizer): The Optimizer used for training.
        """
        optimizer.step()
        program = paddle.static.default_main_program()
        with paddle.fluid.dygraph.no_grad():
            ASPHelper._insert_sparse_mask_ops(
                program,
                optimizer._parameter_list,
            )

    @classmethod
    def _create_mask_variables(cls, main_program, startup_program, params):
        r"""
        Create one sparse mask parameter per supported parameter in :attr:`params`.

        Each mask has the same shape and dtype as its parameter, is initialized
        to 1.0, is excluded from gradient computation and training, and is
        registered in the ASP info of :attr:`main_program`. Parameters that
        already own a mask are skipped. This is the second stage of
        `ASPHelper._minimize`.

        Args:
            main_program (Program): Program with model definition and its parameters.
            startup_program (Program): Program for initializing parameters.
            params (list): Variable parameters.
        """
        asp_info = cls._get_program_asp_info(main_program)
        with program_guard(main_program, startup_program):
            for param in params:
                if not ASPHelper._is_supported_layer(main_program, param.name):
                    continue
                if param.name in asp_info.mask_vars:
                    # A mask was already created for this parameter.
                    continue

                mask = paddle.create_parameter(
                    name=ASPHelper._get_mask_name(param.name),
                    shape=param.shape,
                    dtype=param.dtype,
                    default_initializer=ConstantInitializer(value=1.0),
                )
                mask.stop_gradient = True
                mask.trainable = False
                asp_info.update_mask_vars(param.name, mask)

    @classmethod
    def _insert_sparse_mask_ops(cls, main_program, params):
        r"""
        Append, for every masked parameter in :attr:`params`, an
        ``elementwise_mul`` op that writes ``param * mask`` back into ``param``.

        The ops go into the global block of :attr:`main_program` and carry the
        Optimize op role. Parameters without a registered mask are ignored.
        This is the third stage of `ASPHelper._minimize`.

        Args:
            main_program (Program): Program with model definition and its parameters.
            params (list): Variable parameters.
        """
        global_block = main_program.global_block()
        asp_info = cls._get_program_asp_info(main_program)
        for param in params:
            if param.name not in asp_info.mask_vars:
                continue
            mask_var = asp_info.mask_vars[param.name]
            op_attrs = {
                'axis': -1,
                'use_mkldnn': False,
                OP_ROLE_KEY: int(OpRole.Optimize),
            }
            global_block.append_op(
                type='elementwise_mul',
                inputs={"X": param, 'Y': mask_var},
                outputs={'Out': param},
                attrs=op_attrs,
            )


920
class OptimizerWithSparsityGuarantee:
    r"""
    OptimizerWithSparsityGuarantee wraps a given optimizer and decorates its
    `minimize` function with `ASPHelper._minimize`.

    The decorated `minimize` function does three things (exactly the same as
    `ASPHelper._minimize`):

    1. Call `minimize` function of the given optimizer.
    2. Call `ASPHelper._create_mask_variables` to create mask Variables.
    3. Call `ASPHelper._insert_sparse_mask_ops` to insert weight masking ops at
       the end of `loss`'s Program.

    Any attribute not defined on this wrapper is forwarded to the wrapped optimizer.
    """

    def __init__(self, optimizer):
        self._optimizer = optimizer

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If `_optimizer` itself is
        # missing (the instance was created without __init__, as copy.copy,
        # copy.deepcopy and pickle do), reading self._optimizer here would re-enter
        # __getattr__ forever and end in RecursionError. Raising AttributeError
        # lets hasattr() and the copy/pickle protocols behave normally.
        if item == '_optimizer':
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(
                    type(self).__name__, item
                )
            )
        return getattr(self._optimizer, item)

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        r"""
        Call `ASPHelper._minimize()` with the wrapped optimizer and return its result.

        The main program is not passed, so `ASPHelper._minimize` uses
        `loss.block.program`.

        Args:
            loss (Variable): A Variable containing the value to minimize.
            startup_program (Program, optional): Program for initializing parameters in `parameter_list`. Default is `paddle.static.default_startup_program()`.
            parameter_list (Iterable, optional): Iterable of `Variable` or `Variable.name` to update to minimize `loss`. The default value is None, at this time all parameters will be updated.
            no_grad_set (set, optional): Set of `Variable` or `Variable.name` that don't need to be updated. The default value is None.
        Returns:
            list: operators from :attr:`optimizer`.minimize(:attr:`loss`).
            list: pairs of parameters and their gradients.
        """
        return ASPHelper._minimize(
            self._optimizer,
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set,
        )

    @dygraph_only
    def step(self):
        r"""
        Decorate the `step` function of the wrapped `Optimizer`.

        There are two steps:

        1. Call the wrapped optimizer's step().
        2. Mask parameters with their sparse masks.

        *Note*: Please use `ASP.decorate` instead when applying distributed training
        with `Fleet`, because `Fleet.minimize()` performs an invisible graph
        optimization after which the training graph cannot be modified anymore.
        """
        ASPHelper._step(self._optimizer)

    @dygraph_only
    def state_dict(self):
        r"""
        Decorate the `state_dict` function of the wrapped `Optimizer`.

        The wrapped optimizer's state dict is extended with one entry per ASP
        mask Variable of the default main program, keyed by the mask name.

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer
        """
        state_dict = self._optimizer.state_dict()
        asp_info = ASPHelper._get_program_asp_info(
            paddle.static.default_main_program()
        )
        for param_name, var in asp_info.mask_vars.items():
            state_dict[ASPHelper._get_mask_name(param_name)] = var
        return state_dict

    @dygraph_only
    def set_state_dict(self, state_dict):
        r"""
        Decorate the `set_state_dict` function of the wrapped `Optimizer`.

        Every ASP mask Variable of the default main program is restored from
        :attr:`state_dict`, and its cached mask is refreshed. The full dict is
        then forwarded to the wrapped optimizer's `set_state_dict`.

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer
        Return:
            The result of the wrapped optimizer's `set_state_dict`.
        Raises:
            AssertionError: If the entry for a registered mask is missing from
            :attr:`state_dict`.
        """
        asp_info = ASPHelper._get_program_asp_info(
            paddle.static.default_main_program()
        )
        for param_name, var in asp_info.mask_vars.items():
            param_mask_name = ASPHelper._get_mask_name(param_name)
            assert param_mask_name in state_dict, "The {} is not found.".format(
                param_mask_name
            )
            var.set_value(state_dict[param_mask_name])
            asp_info.update_masks(param_name, var.numpy())
        return self._optimizer.set_state_dict(state_dict)