# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Functions for Auto SParsity (ASP) training and inference.
"""

import os
import copy
import numpy as np
import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import global_scope, program_guard, layers
from paddle.fluid.initializer import ConstantInitializer
from paddle.fluid.contrib import sparsity
from paddle.fluid import core
from paddle.fluid.contrib.sparsity.supported_layer_list import supported_layers_and_prune_func_map
from paddle.fluid.contrib.sparsity.supported_layer_list import _default_pruning

# Operator-role enum and the name of the op attribute that records an op's
# role, both taken from core.op_proto_and_checker_maker.
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

__all__ = [
    'decorate', 'prune_model', 'set_excluded_layers', 'reset_excluded_layers'
]


def set_excluded_layers(param_names, main_program=None):
    r"""
    Set parameter names of layers which would not be pruned as sparse weights.

    Args:
        param_names (list of string): A list containing the names to exclude from ASP. The examples
                                      below pass layer full names obtained from `Layer.full_name()`.
        main_program (Program, optional): Program with model definition and its parameters.
                                          If None is given, then it would be set as `paddle.static.default_main_program()`.
                                          Default is None.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Need to set excluded layers before calling decorate
                paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()])

                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    # Exclude layers from the ASP workflow.
                    # Please note, excluded_layers must be set before calling optimizer.minimize().
                    paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program)

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    optimizer = paddle.static.amp.decorate(optimizer)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    # Both graph modes key ASP bookkeeping by a Program; dynamic graph uses the
    # default main program as that key (see ASPHelper.decorate).
    if main_program is None:
        main_program = paddle.static.default_main_program()

    ASPHelper.set_excluded_layers(param_names=param_names,
                                  main_program=main_program)


def reset_excluded_layers(main_program=None):
    r"""
    Reset the excluded-layers setting of :attr:`main_program`. If :attr:`main_program`
    is None, the excluded-layers settings of all programs are cleared.

    Args:
        main_program (Program, optional): Program with model definition and its parameters.
                                          If None is given, then this function would reset all excluded_layers.
                                          Default is None.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Need to set excluded layers before calling decorate
                paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()])
                # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow.
                # Please note, reset_excluded_layers also must be called before calling paddle.incubate.asp.decorate().
                paddle.incubate.asp.reset_excluded_layers()

                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    # Exclude layers from the ASP workflow.
                    # Please note, excluded_layers must be set before calling optimizer.minimize().
                    paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program)
                    # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow.
                    # Please note, reset_excluded_layers also must be called before calling optimizer.minimize().
                    paddle.incubate.asp.reset_excluded_layers(main_program)

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    optimizer = paddle.static.amp.decorate(optimizer)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    ASPHelper.reset_excluded_layers(main_program=main_program)


def decorate(optimizer):
    r"""
    Wrap the given optimizer as an OptimizerWithSparsityGuarantee.
    If running in dynamic graph mode, ASP creates mask variables for supported parameters
    at decoration time. In static graph mode, ASP creates mask variables and inserts the
    necessary ops when minimize() is called.

    Args:
        optimizer (Optimizer): An Optimizer used for training.
    Returns:
        OptimizerWithSparsityGuarantee: A wrapper for ASP to decorate `minimize` function of the given optimizer.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                my_layer = MyLayer()
                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which
                # will apply necessary masking operations for ASP workflow.
                # In dynamic graph mode, ASP would create related mask variables during decoration.
                optimizer = paddle.incubate.asp.decorate(optimizer)

        2. Usage of Static Graph

            .. code-block:: python

                import paddle

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 100)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        prediction = self.linear1(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224])
                    label = paddle.static.data(name='label', shape=[None, 100])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    # In static graph mode, ASP creates related mask variables
                    # during minimize().
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)
    """
    return ASPHelper.decorate(optimizer)


def prune_model(model, n=2, m=4, mask_algo='mask_1d', with_mask=True):
    r"""
    Prune the parameters of supported layers in :attr:`model` using the mask
    generation function given by :attr:`mask_algo`. This function supports both
    training and inference, controlled by :attr:`with_mask`. If :attr:`with_mask`
    is True, it also prunes the ASP mask Variables related to the parameters;
    otherwise it only prunes the parameters.

    *Note*: (Static graph mode) When :attr:`with_mask` is True, `OptimizerWithSparsityGuarantee.minimize`
    and initialization (`exe.run(startup_program)`) must be called first so that the mask Variables exist.
    Typically set `with_mask` to True for training (after calling `OptimizerWithSparsityGuarantee.minimize`)
    and to False for inference only. To obtain an OptimizerWithSparsityGuarantee, see `paddle.incubate.asp.decorate()`.

    Args:
        model (Program|nn.Layer): Program with model definition and its parameters, or an object of `paddle.nn.Layer`.
        n (int, optional): n of `n:m` sparse pattern. Default is 2.
        m (int, optional): m of `n:m` sparse pattern. Default is 4.
        mask_algo (string, optional): The function name to generate sparse mask. Default is `mask_1d`.
                                      The valid inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'.
        with_mask (bool, optional): Whether to also prune the mask Variables related to parameters. True prunes them too, False does not. Default is True.
    Returns:
        dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding sparse mask (NumPy array).
    Raises:
        AssertionError: If :attr:`mask_algo` is not one of the valid names.
        TypeError: If :attr:`model` is neither a `paddle.nn.Layer` nor a `paddle.static.Program`.
    Examples:
        1. Usage of Dynamic Graph

            .. code-block:: python

                import paddle
                import numpy as np

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                my_layer = MyLayer()
                loss_fn = paddle.nn.MSELoss(reduction='mean')

                optimizer = paddle.optimizer.SGD(
                    learning_rate=0.01, parameters=my_layer.parameters())

                # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which
                # will apply necessary masking operations for ASP workflow.
                # In dynamic graph mode, ASP would create related mask variables during decoration.
                optimizer = paddle.incubate.asp.decorate(optimizer)

                # Must call paddle.incubate.asp.decorate() first before calling paddle.incubate.asp.prune_model()
                paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best')

                for i in range(10):
                    imgs = paddle.to_tensor(
                        np.random.randn(64, 3, 32, 32),
                        dtype='float32', stop_gradient=False)
                    labels = paddle.to_tensor(
                        np.random.randint(10, size=(64, 1)),
                        dtype='float32', stop_gradient=False)
                    output = my_layer(imgs)
                    loss = loss_fn(output, labels)
                    loss.backward()
                    optimizer.step()
                    optimizer.clear_grad()

        2. Usage of Static Graph

            .. code-block:: python

                import paddle
                import numpy as np

                paddle.enable_static()

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self.conv1 = paddle.nn.Conv2D(
                            in_channels=3, out_channels=4, kernel_size=3, padding=2)
                        self.linear1 = paddle.nn.Linear(4624, 32)
                        self.linear2 = paddle.nn.Linear(32, 32)
                        self.linear3 = paddle.nn.Linear(32, 10)

                    def forward(self, img):
                        hidden = self.conv1(img)
                        hidden = paddle.flatten(hidden, start_axis=1)
                        hidden = self.linear1(hidden)
                        hidden = self.linear2(hidden)
                        prediction = self.linear3(hidden)
                        return prediction

                main_program = paddle.static.Program()
                startup_program = paddle.static.Program()

                with paddle.static.program_guard(main_program, startup_program):
                    input_data = paddle.static.data(name='data', shape=[None, 3, 32, 32])
                    label = paddle.static.data(name='label', shape=[None, 1])
                    my_layer = MyLayer()
                    prob = my_layer(input_data)
                    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

                    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
                    # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which
                    # will insert necessary masking operations for ASP workflow.
                    # In static graph mode, ASP creates related mask variables
                    # during minimize().
                    optimizer = paddle.incubate.asp.decorate(optimizer)
                    optimizer.minimize(loss, startup_program)

                device = paddle.device.get_device()
                place = paddle.set_device(device)

                exe = paddle.static.Executor(place)
                exe.run(startup_program)

                # Must call exe.run(startup_program) first before calling paddle.incubate.asp.prune_model()
                paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best')
                # It is also acceptable to call
                # paddle.incubate.asp.prune_model(main_program, mask_algo='mask_2d_best')

                for i in range(10):
                    imgs = np.random.randn(64, 3, 32, 32).astype('float32')
                    labels = np.random.randint(10, size=(64, 1)).astype('float32')
                    exe.run(main_program, feed={'data':imgs, 'label':labels})
    """
    device = paddle.device.get_device()
    place = paddle.set_device(device)

    MaskAlgo_mapping = {
        'mask_1d': sparsity.MaskAlgo.MASK_1D,
        'mask_2d_greedy': sparsity.MaskAlgo.MASK_2D_GREEDY,
        'mask_2d_best': sparsity.MaskAlgo.MASK_2D_BEST
    }
    assert (mask_algo in MaskAlgo_mapping), \
        'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]'

    prune_func = None
    if isinstance(model, paddle.nn.Layer):
        prune_func = ASPHelper.prune_model_by_layer
    elif isinstance(model, paddle.static.Program):
        prune_func = ASPHelper.prune_model_by_program
        # A program without "distributed_info_" (or without a "sharding_degree"
        # entry) is treated as unsharded instead of raising KeyError.
        distributed_info = getattr(model, "distributed_info_", None) or {}
        if distributed_info.get("sharding_degree", 1) > 1 and \
           paddle.fluid.is_compiled_with_cuda():
            # Sharded program: write pruned tensors to the GPU selected for this
            # process via FLAGS_selected_gpus (GPU 0 when the flag is unset).
            gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
            place = paddle.CUDAPlace(gpu_id)
    else:
        raise TypeError(
            "model should be paddle.nn.Layer or paddle.static.Program, but got {}"
            .format(type(model)))

    return prune_func(place,
                      model,
                      n=n,
                      m=m,
                      mask_algo=MaskAlgo_mapping[mask_algo],
                      with_mask=with_mask)


class ProgramASPInfo(object):
    r"""
    ProgramASPInfo is a container that keeps the ASP-relevant information of a Program. It contains three inner variables:
    1. __mask_vars (Dictionary): Key is a parameter's name and value is its corresponding sparse mask Variable object, which is created by `ASPHelper._create_mask_variables`.
    2. __masks (Dictionary): Key is a parameter's name and value is its corresponding sparse mask Numpy array, which is created by `ASPHelper.prune_model_by_program` or `ASPHelper.prune_model_by_layer`.
    3. __excluded_layers (List): It stores the names of layers which should not be involved in the ASP workflow.
    """

    def __init__(self):
        self.__mask_vars = {}
        self.__masks = {}
        self.__excluded_layers = []

    def update_mask_vars(self, param_name, var):
        """Record (or overwrite) the mask Variable of parameter `param_name`."""
        self.__mask_vars[param_name] = var

    def update_masks(self, param_name, var):
        """Record (or overwrite) the sparse mask array of parameter `param_name`."""
        self.__masks[param_name] = var

    def update_excluded_layers(self, param_names):
        """Append `param_names` to the excluded layers.

        The names are deep-copied, so later changes to the caller's list do not
        affect the stored setting. Duplicates are not filtered out.
        """
        self.__excluded_layers.extend(copy.deepcopy(param_names))

    def reset_excluded_layers(self):
        """Clear all excluded layers."""
        self.__excluded_layers = []

    @property
    def mask_vars(self):
        """dict: parameter name -> mask Variable (the live dict, not a copy)."""
        return self.__mask_vars

    @property
    def masks(self):
        """dict: parameter name -> sparse mask array (the live dict, not a copy)."""
        return self.__masks

    @property
    def excluded_layers(self):
        """list: names of layers excluded from ASP (the live list, not a copy)."""
        return self.__excluded_layers


class ASPHelper(object):
    r"""
    ASPHelper is a collection of Auto SParsity (ASP) functions to enable 

    1. training models with weights in 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 from scratch.
    2. pruning well-trained models into 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 for fine-tuning.
    """

514 515
    MASK_APPENDDED_NAME = 'asp_mask'
    PADDLE_WEIGHT_SUFFIX = "w_"
516 517 518 519

    __asp_info = {}

    @classmethod
    def set_excluded_layers(cls, param_names, main_program):
        r"""
        Implementation of the module-level `set_excluded_layers`; see its docstring for details.

        Args:
            param_names (list of string): Names to append to the excluded layers of :attr:`main_program`.
            main_program (Program): Program whose ASP information is updated.
        """
        asp_info = cls._get_program_asp_info(main_program)
        asp_info.update_excluded_layers(param_names)

    @classmethod
    def reset_excluded_layers(cls, main_program=None):
        r"""
        Implementation of the module-level `reset_excluded_layers`; see its docstring for details.

        Args:
            main_program (Program, optional): Program whose excluded layers are cleared. When None,
                the excluded layers of every Program already known to ASPHelper are cleared.
        """
        if main_program is None:
            for asp_info in cls.__asp_info.values():
                asp_info.reset_excluded_layers()
        else:
            cls._get_program_asp_info(main_program).reset_excluded_layers()

    @staticmethod
    def decorate(optimizer):
        r"""
        Implementation of the module-level `decorate`; see its docstring for details.

        Args:
            optimizer (Optimizer): The optimizer to wrap.
        Returns:
            OptimizerWithSparsityGuarantee: The wrapped optimizer.
        """
        if paddle.in_dynamic_mode():
            # main_prog and startup_prog would be used with paddle.static.program_guard
            # to create ASP masks. Moreover, main_prog is a key to map paddle.static.Program
            # to its own ASP information, like ASP mask variables. For dynamic graph, we use
            # default_main_program as the key.
            main_prog = paddle.static.default_main_program()
            startup_prog = paddle.static.default_startup_program()
            ASPHelper._create_mask_variables(main_prog, startup_prog,
                                             optimizer._parameter_list)
        return OptimizerWithSparsityGuarantee(optimizer)

    @classmethod
    def prune_model_by_program(cls,
                               place,
                               main_program=None,
                               n=2,
                               m=4,
                               mask_algo=sparsity.MaskAlgo.MASK_1D,
                               with_mask=True):
        r"""
        Static-graph implementation of the module-level `prune_model`; see its docstring for details.

        Every supported parameter of :attr:`main_program` is pruned in the global scope and,
        when :attr:`with_mask` is True, its ASP mask Variable is overwritten with the new mask.

        Args:
            place (Place): Device the pruned tensors are written to.
            main_program (Program, optional): Program to prune. `paddle.static.default_main_program()`
                is used when None. Default is None.
            n (int, optional): n of the `n:m` sparse pattern. Default is 2.
            m (int, optional): m of the `n:m` sparse pattern. Default is 4.
            mask_algo (MaskAlgo, optional): Mask generation algorithm. Default is `MaskAlgo.MASK_1D`.
            with_mask (bool, optional): Whether to also update the mask Variables. Default is True.
        Returns:
            dict: A copy of the parameter-name -> sparse-mask mapping recorded for :attr:`main_program`.
        """
        if main_program is None:
            main_program = paddle.static.default_main_program()

        asp_info = cls._get_program_asp_info(main_program)
        for param in main_program.global_block().all_parameters():
            if not ASPHelper._is_supported_layer(main_program, param.name):
                continue

            weight_var = global_scope().find_var(param.name)
            assert weight_var is not None, \
                'Cannot find {} variable, please run initialization ' \
                '(exe.run(startup_program)) first!'.format(param.name)
            weight_tensor = weight_var.get_tensor()
            weight_nparray = np.array(weight_tensor)

            prune_func = ASPHelper._get_prune_func_by_name(param.name)

            weight_pruned_nparray, weight_sparse_mask = \
                prune_func(weight_nparray, m, n, mask_algo, param.name)
            # Keep the parameter's original dtype when writing it back.
            weight_pruned_nparray = weight_pruned_nparray.astype(
                weight_nparray.dtype)
            weight_tensor.set(weight_pruned_nparray, place)

            if with_mask:
                mask_name = ASPHelper._get_mask_name(param.name)
                weight_mask_param = global_scope().find_var(mask_name)
                assert weight_mask_param is not None, \
                    'Cannot find {} variable, please call optimizer.minimize (' \
                    'paddle.incubate.asp.decorate(optimizer).minimize(loss))' \
                    ' and initialization (exe.run(startup_program)) first!'.format(mask_name)
                weight_mask_tensor = weight_mask_param.get_tensor()
                # Cast the mask to the mask variable's dtype; the cast mask is
                # also what gets recorded in asp_info below.
                weight_sparse_mask = weight_sparse_mask.astype(
                    np.array(weight_mask_tensor).dtype)
                weight_mask_tensor.set(weight_sparse_mask, place)
            asp_info.update_masks(param.name, weight_sparse_mask)
        return asp_info.masks.copy()

597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625
    @classmethod
    def prune_model_by_layer(cls,
                             place,
                             layer,
                             n=2,
                             m=4,
                             mask_algo=sparsity.MaskAlgo.MASK_1D,
                             with_mask=True):
        r"""
        This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`.
        """
        if paddle.in_dynamic_mode():
            main_program = paddle.static.default_main_program()
            asp_info = cls._get_program_asp_info(main_program)

            for param in layer.parameters():
                if ASPHelper._is_supported_layer(main_program, param.name):
                    weight_nparray = param.numpy()

                    prune_func = ASPHelper._get_prune_func_by_name(param.name)

                    weight_pruned_nparray, weight_sparse_mask = \
                        prune_func(weight_nparray, m, n, mask_algo, param.name)

                    weight_pruned_nparray = weight_pruned_nparray.astype(
                        weight_nparray.dtype)
                    param.set_value(weight_pruned_nparray)

                    if with_mask:
626 627
                        weight_mask_param = asp_info.mask_vars.get(
                            param.name, None)
628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643
                        assert weight_mask_param is not None, \
                            'Cannot find {} variable, please call sparsity.decorate() to' \
                            ' decorate your optimizer first!'.format(ASPHelper._get_mask_name(param.name))
                        weight_mask_param.set_value(weight_sparse_mask)

                    asp_info.update_masks(param.name, weight_sparse_mask)

            return asp_info.masks.copy()
        else:
            # This for loop is only used to obtain Block and Program from
            # first parameters.
            target_program = None
            for param in layer.parameters():
                target_program = param.block.program
            assert target_program is not None, \
                    'Cannot get paddle.static.Program from Paddle.nn.Layer.'
644 645 646 647 648 649
            return ASPHelper.prune_model_by_program(place,
                                                    target_program,
                                                    n=n,
                                                    m=m,
                                                    mask_algo=mask_algo,
                                                    with_mask=with_mask)
650

651 652 653 654 655 656 657 658 659 660
    @staticmethod
    def _get_mask_name(param_name):
        r"""
        Build the name of the ASP mask variable that belongs to a parameter.

        The mask name is :attr:`param_name` followed by a dot and
        ``ASPHelper.MASK_APPENDDED_NAME``.

        Args:
            param_name (string): The name of parameter.
        Returns:
            string: The mask name of :attr:`param_name`.
        """
        return ".".join((param_name, ASPHelper.MASK_APPENDDED_NAME))

    @staticmethod
    def _get_not_ASP_relevant_vars(main_program):
        r"""
        Collect the parameters of :attr:`main_program` that are not ASP masks.

        A parameter is treated as an ASP mask when one of the dot-separated
        components of its name equals ``ASPHelper.MASK_APPENDDED_NAME``.

        Args:
            main_program (Program): Program with model definition and its parameters.
        Returns:
            list: A list of parameter Variables in :attr:`main_program` (excluded ASP mask Variables).
        """
        all_params = main_program.global_block().all_parameters()
        return [
            candidate for candidate in all_params
            if ASPHelper.MASK_APPENDDED_NAME not in candidate.name.split('.')
        ]

    @classmethod
    def _get_program_asp_info(cls, main_program):
        r"""
        Return the ASP bookkeeping object associated with :attr:`main_program`.

        On first access for a given program a new ``ProgramASPInfo()`` is
        created and cached, so later calls return the same object.

        Args:
            main_program (Program): Program used as the cache key.
        Returns:
            ProgramASPInfo: The ASP info registered for :attr:`main_program`.
        """
        info = cls.__asp_info.get(main_program)
        if info is None:
            info = ProgramASPInfo()
            cls.__asp_info[main_program] = info
        return info

    @classmethod
    def _is_supported_layer(cls, main_program, param_name):
        r"""
        Verify if given :attr:`param_name` is supported by ASP.

        Args:
            main_program (Program): Program whose ASP info (excluded layers) is consulted.
            param_name (string): The name of parameter.
        Returns:
            bool: True if it is supported, else False.
        Examples:
            .. code-block:: python

              import paddle
              from paddle.static.sparsity.asp import ASPHelper

              main_program = paddle.static.Program()
              startup_program = paddle.static.Program()

              with paddle.static.program_guard(main_program, startup_program):
                  input_data = paddle.static.data(name='data', shape=[None, 128])
                  fc = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None)

              for param in main_program.global_block().all_parameters():
                  ASPHelper._is_supported_layer(main_program, param.name)
              # fc_0.w_0 -> True
              # fc_0.b_0 -> False
        """
        param_name_list = param_name.split('.')

        # Mask variables (named "<param>.<MASK_APPENDDED_NAME>") are never pruned.
        if ASPHelper.MASK_APPENDDED_NAME in param_name_list:
            return False

        # Excluded layers are matched as substrings of the parameter name.
        for layer in cls._get_program_asp_info(main_program).excluded_layers:
            if layer in param_name:
                return False

        # Exact parameter names registered in the map are always supported.
        if param_name in supported_layers_and_prune_func_map:
            return True

        # The parameter's name is neither in *.* format nor added to
        # supported_layers_and_prune_func_map, return False.
        if len(param_name_list) == 1:
            return False

        param_name_no_weight_suffix = param_name_list[0]
        param_type_suffix = param_name_list[1]
        if ASPHelper.PADDLE_WEIGHT_SUFFIX not in param_type_suffix:
            return False

        # Strip the trailing "_<index>" (e.g. "fc_0" -> "fc"). When there is
        # no underscore keep the name whole: slicing with rfind()'s -1 would
        # silently drop the last character and could match an unrelated,
        # shorter layer name (e.g. "linears" -> "linear").
        underscore_pos = param_name_no_weight_suffix.rfind('_')
        if underscore_pos >= 0:
            layer_name = param_name_no_weight_suffix[:underscore_pos]
        else:
            layer_name = param_name_no_weight_suffix

        if param_name_no_weight_suffix in supported_layers_and_prune_func_map or \
            layer_name in supported_layers_and_prune_func_map:
            return True

        return False

742 743 744 745 746 747 748 749 750
    @classmethod
    def _get_prune_func_by_name(cls, param_name):
        r"""
        Return the pruning function to use for the parameter :attr:`param_name`.

        Lookup order in ``supported_layers_and_prune_func_map``:

        1. the full parameter name (e.g. ``fc_0.w_0``);
        2. the name before the first dot (e.g. ``fc_0``);
        3. that name with its trailing ``_<index>`` removed (e.g. ``fc``).

        Falls back to ``_default_pruning`` when none of them is registered.

        Args:
            param_name (string): The name of parameter.
        Returns:
            callable: The pruning function. Callers in this module invoke it as
            ``func(weight_nparray, m, n, mask_algo, param_name)``.
        """
        func = supported_layers_and_prune_func_map.get(param_name, None)
        param_name_no_weight_suffix = param_name.split('.')[0]
        if func is None:
            func = supported_layers_and_prune_func_map.get(
                param_name_no_weight_suffix, None)
        if func is None:
            # rfind() returns -1 when there is no underscore; slicing with it
            # would drop the last character, so keep the whole name instead.
            underscore_pos = param_name_no_weight_suffix.rfind('_')
            if underscore_pos >= 0:
                layer_name = param_name_no_weight_suffix[:underscore_pos]
            else:
                layer_name = param_name_no_weight_suffix
            func = supported_layers_and_prune_func_map.get(
                layer_name, _default_pruning)
        return func

757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795
    @classmethod
    def _minimize(cls,
                  optimizer,
                  loss,
                  main_program=None,
                  startup_program=None,
                  parameter_list=None,
                  no_grad_set=None):
        r"""
        This function is a decorator of `minimize` function in `Optimizer`.
        There are three steps:

        1. Call :attr:`optimizer`.minimize(:attr:`loss`)
        2. Create sparse mask Tensors according to supported layers in :attr:`main_program`.
        3. Insert masking ops in the end of parameters update.

        *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`.
        (There is an invisible graph optimization in `Fleet.minimize()` after which the
        training graph cannot be modified anymore.)

        Args:
            optimizer (Optimizer): An Optimizer used for training.
            loss (Variable): A Variable containing the value to minimize.
            main_program (Program, optional): Program with model definition and its parameters. Default is `loss.block.program`.
            startup_program (Program, optional): Program for initializing parameters in `parameter_list`. Default is `paddle.static.default_startup_program()`.
            parameter_list (Iterable, optional): Iterable of `Variable` or `Variable.name` to update to minimize `loss`. The default value is None, at this time all parameters will be updated.
            no_grad_set (set, optional): Set of `Variable` or `Variable.name` that don't need to be updated. The default value is None.
        Returns:
            list: operators from :attr:`optimizer`.minimize(:attr:`loss`).
            list: pairs of parameters and their gradients.
        """
        program = loss.block.program if main_program is None else main_program
        startup = (paddle.static.default_startup_program()
                   if startup_program is None else startup_program)

        optimizer_ops, params_and_grads = optimizer.minimize(
            loss, startup, parameter_list, no_grad_set=no_grad_set)

        # Masks are attached to the parameters, i.e. the first item of each
        # (parameter, gradient) pair returned by the wrapped optimizer.
        updated_params = [pair[0] for pair in params_and_grads]
        cls._create_mask_variables(program, startup, updated_params)
        cls._insert_sparse_mask_ops(program, updated_params)
        return optimizer_ops, params_and_grads

    @classmethod
    @dygraph_only
    def _step(cls, optimizer):
        r"""
        This function is a decorator of `step` function in `Optimizer` (dynamic graph mode only).
        There are two steps:

        1. Call :attr:`optimizer`.step()
        2. Mask parameters with sparse masks: every parameter in
           ``optimizer._parameter_list`` that has a mask registered for the
           default main program is multiplied by that mask, under
           ``paddle.fluid.dygraph.no_grad()``.

        *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`.

        Args:
            optimizer (Optimizer): An Optimizer used for training.
        """
        optimizer.step()
        program = paddle.static.default_main_program()
        parameters = optimizer._parameter_list
        with paddle.fluid.dygraph.no_grad():
            ASPHelper._insert_sparse_mask_ops(program, parameters)
    @classmethod
    def _create_mask_variables(cls, main_program, startup_program, params):
        r"""
        Create sparse mask Tensors according to supported layers in :attr:`main_program`.
        This function is called in second step of `ASPHelper._minimize`.

        For each parameter in :attr:`params` that is supported by ASP and has
        no mask yet, a parameter named ``ASPHelper._get_mask_name(param.name)``
        is created inside ``program_guard(main_program, startup_program)``. It
        has the same shape and dtype, is initialized to 1.0, is marked
        ``stop_gradient=True`` / ``trainable=False``, and is registered via
        ``update_mask_vars``.

        Args:
            main_program (Program): Program with model definition and its parameters.
            startup_program (Program): Program for initializing parameters.
            params (list): Variable parameters.
        """
        asp_info = cls._get_program_asp_info(main_program)
        with program_guard(main_program, startup_program):
            for weight in params:
                # Skip unsupported parameters and those that already own a mask.
                if (not ASPHelper._is_supported_layer(main_program, weight.name)
                        or weight.name in asp_info.mask_vars):
                    continue
                mask_var = layers.create_parameter(
                    name=ASPHelper._get_mask_name(weight.name),
                    shape=weight.shape,
                    dtype=weight.dtype,
                    default_initializer=ConstantInitializer(value=1.0))
                mask_var.stop_gradient = True
                mask_var.trainable = False
                asp_info.update_mask_vars(weight.name, mask_var)

    @classmethod
    def _insert_sparse_mask_ops(cls, main_program, params):
        r"""
        Insert masking ops in the end of parameters update.
        This function is called in third step of `ASPHelper._minimize`
        (and by `ASPHelper._step` in dynamic graph mode).

        For every parameter in :attr:`params` that has a registered mask, an
        in-place ``elementwise_mul`` op (``param = param * mask``) with the
        Optimize op role is appended to the global block of :attr:`main_program`.

        Args:
            main_program (Program): Program with model definition and its parameters.
            params (list): Variable parameters.
        """
        block = main_program.global_block()
        asp_info = cls._get_program_asp_info(main_program)
        for weight in params:
            if weight.name not in asp_info.mask_vars:
                continue
            op_inputs = {"X": weight, 'Y': asp_info.mask_vars[weight.name]}
            op_attrs = {
                'axis': -1,
                'use_mkldnn': False,
                OP_ROLE_KEY: int(OpRole.Optimize)
            }
            # Output is the parameter itself, so the mask is applied in place.
            block.append_op(type='elementwise_mul',
                            inputs=op_inputs,
                            outputs={'Out': weight},
                            attrs=op_attrs)


class OptimizerWithSparsityGuarantee(object):
    r"""
    OptimizerWithSparsityGuarantee is a wrapper to decorate `minimize` function of given optimizer by `_minimize` of ASPHelper.
    The decorated `minimize` function would do three things (exactly same as `ASPHelper._minimize`):
    1. Call `minimize` function of given optimizer.
    2. Call `ASPHelper._create_mask_variables` to create mask Variables.
    3. Call `ASPHelper._insert_sparse_mask_ops` to insert weight masking ops in the end of `loss`'s Program.

    In dynamic graph mode, `step`, `state_dict` and `set_state_dict` are also
    decorated: `step` re-applies the sparse masks after the update, and the
    state-dict methods save/restore the mask Variables together with the
    optimizer state. Any other attribute is looked up on the wrapped optimizer.
    """

    def __init__(self, optimizer):
        self._optimizer = optimizer

    def __getattr__(self, item):
        # __getattr__ only runs when normal attribute lookup fails. Read the
        # wrapped optimizer from the instance dict directly: if `_optimizer`
        # has not been set (e.g. copy/pickle probing attributes on an instance
        # built without __init__), `self._optimizer` would re-enter
        # __getattr__ and recurse until RecursionError instead of raising
        # AttributeError.
        try:
            optimizer = self.__dict__['_optimizer']
        except KeyError:
            raise AttributeError(item) from None
        return getattr(optimizer, item)

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        r"""
        Call `ASPHelper._minimize()` with the wrapped optimizer and return its result.

        Args:
            loss (Variable): A Variable containing the value to minimize.
            startup_program (Program, optional): Program for initializing parameters in `parameter_list`. Default is `paddle.static.default_startup_program()`.
            parameter_list (Iterable, optional): Iterable of `Variable` or `Variable.name` to update to minimize `loss`. The default value is None, at this time all parameters will be updated.
            no_grad_set (set, optional): Set of `Variable` or `Variable.name` that don't need to be updated. The default value is None.
        Returns:
            list: operators from :attr:`optimizer`.minimize(:attr:`loss`).
            list: pairs of parameters and their gradients.
        """
        return ASPHelper._minimize(self._optimizer,
                                   loss,
                                   startup_program=startup_program,
                                   parameter_list=parameter_list,
                                   no_grad_set=no_grad_set)

    @dygraph_only
    def step(self):
        r"""
        This function is a decorator of `step` function in `Optimizer` (dynamic graph mode only).
        It delegates to `ASPHelper._step`, which does two things:

        1. Call `step()` of the wrapped optimizer.
        2. Mask parameters with sparse masks.

        *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`.
        """
        ASPHelper._step(self._optimizer)

    @dygraph_only
    def state_dict(self):
        r"""
        This function is a decorator of `state_dict` function in `Optimizer`.

        The wrapped optimizer's state dict is extended with one entry per ASP
        mask registered for the default main program, keyed by
        `ASPHelper._get_mask_name(param_name)` and holding the mask Variable.

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer
        """
        state_dict = self._optimizer.state_dict()
        asp_info = ASPHelper._get_program_asp_info(
            paddle.static.default_main_program())
        for param_name, var in asp_info.mask_vars.items():
            state_dict.update({ASPHelper._get_mask_name(param_name): var})
        return state_dict

    @dygraph_only
    def set_state_dict(self, state_dict):
        r"""
        This function is a decorator of `set_state_dict` function in `Optimizer`.

        For every ASP mask registered for the default main program, the value
        stored under its mask name is loaded into the mask Variable and the
        cached mask is refreshed; the full dict is then passed on to the
        wrapped optimizer.

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer
        Raises:
            AssertionError: if a registered mask is missing from :attr:`state_dict`.
        Return:
            The return value of the wrapped optimizer's `set_state_dict`.
        """
        asp_info = ASPHelper._get_program_asp_info(
            paddle.static.default_main_program())
        for param_name, var in asp_info.mask_vars.items():
            param_mask_name = ASPHelper._get_mask_name(param_name)
            assert param_mask_name in state_dict, \
                "The {} is not found.".format(param_mask_name)
            var.set_value(state_dict[param_mask_name])
            asp_info.update_masks(param_name, var.numpy())
        return self._optimizer.set_state_dict(state_dict)