From 09bc9c06470aa9c5c2a5d3e6022b849dc460674c Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Sat, 30 Oct 2021 00:14:39 +0800 Subject: [PATCH] Move the ASP training API to paddle.static.sparsity. (#36525) (#36860) Cherry-pick #36525 --- .../paddle/fluid/contrib/sparsity/__init__.py | 6 +- python/paddle/fluid/contrib/sparsity/asp.py | 162 ++++++++++++------ python/paddle/fluid/contrib/sparsity/utils.py | 7 +- .../tests/unittests/asp/asp_pruning_base.py | 9 +- .../tests/unittests/asp/test_asp_optimize.py | 14 +- .../unittests/asp/test_asp_pruning_1d.py | 10 +- .../unittests/asp/test_asp_pruning_2d_best.py | 10 +- .../asp/test_asp_pruning_2d_greedy.py | 12 +- .../tests/unittests/asp/test_asp_utils.py | 136 +++++++++------ .../unittests/asp/test_fleet_with_asp.py | 8 +- .../unittests/asp/test_fleet_with_asp_amp.py | 14 +- python/paddle/static/__init__.py | 2 + python/paddle/static/sparsity/__init__.py | 28 +++ python/setup.py.in | 1 + 14 files changed, 277 insertions(+), 142 deletions(-) create mode 100644 python/paddle/static/sparsity/__init__.py diff --git a/python/paddle/fluid/contrib/sparsity/__init__.py b/python/paddle/fluid/contrib/sparsity/__init__.py index b36a79b8ca8..9bf45f42727 100644 --- a/python/paddle/fluid/contrib/sparsity/__init__.py +++ b/python/paddle/fluid/contrib/sparsity/__init__.py @@ -25,8 +25,10 @@ from .utils import create_mask from .utils import check_sparsity from .utils import MaskAlgo from .utils import CheckMethod -from .asp import decorate, prune_model -from .asp import set_excluded_layers, reset_excluded_layers +from .asp import decorate +from .asp import prune_model +from .asp import set_excluded_layers +from .asp import reset_excluded_layers __all__ = [ 'calculate_density', 'check_mask_1d', 'get_mask_1d', 'check_mask_2d', diff --git a/python/paddle/fluid/contrib/sparsity/asp.py b/python/paddle/fluid/contrib/sparsity/asp.py index 77c61faf23d..61e3a61fc9c 100644 --- a/python/paddle/fluid/contrib/sparsity/asp.py +++ b/python/paddle/fluid/contrib/sparsity/asp.py @@ -19,10 +19,9 @@ Functions for Auto SParsity (ASP) training and inference. import copy import numpy as np import paddle -from paddle.fluid import framework, global_scope, program_guard, layers +from paddle.fluid import global_scope, program_guard, layers from paddle.fluid.initializer import ConstantInitializer from paddle.fluid.contrib import sparsity -from paddle.fluid import core __all__ = [ 'decorate', 'prune_model', 'set_excluded_layers', 'reset_excluded_layers' @@ -36,6 +35,35 @@ def set_excluded_layers(main_program, param_names): Args: main_program (Program, optional): Program with model definition and its parameters. param_names (list): A list contains names of parameters. + Examples: + .. code-block:: python + + import paddle + from paddle.static import sparsity + + paddle.enable_static() + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 128]) + label = paddle.static.data(name='label', shape=[None, 10]) + hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="need_sparse_fc") + hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="need_dense_fc") + prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + + # Setup exluded layers out from ASP workflow. + # Please note, excluded_layers must be set before calling `optimizer.minimize()`. + sparsity.set_excluded_layers(main_program, ["need_dense_fc"]) + + optimizer = paddle.optimizer.SGD(learning_rate=0.1) + optimizer = paddle.static.amp.decorate(optimizer ) + # Calling sparsity.decorate() to wrap minimize() in optimizer, which + # will insert necessary masking operations for ASP workflow. + optimizer = sparsity.decorate(optimizer) + optimizer.minimize(loss, startup_program) """ ASPHelper.set_excluded_layers( main_program=main_program, param_names=param_names) @@ -48,6 +76,33 @@ def reset_excluded_layers(main_program=None): Args: main_program (Program, optional): Program with model definition and its parameters. + Examples: + .. code-block:: python + + import paddle + from paddle.static import sparsity + + paddle.enable_static() + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 128]) + label = paddle.static.data(name='label', shape=[None, 10]) + hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="my_first_fc") + hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="my_second_fc") + prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + + # Setup exluded layers out from ASP workflow. + # Please note, excluded_layers must be set before calling `optimizer.minimize()`. + sparsity.set_excluded_layers(main_program, ["my_second_fc"]) + # Now the weights of "my_second_fc" would not be included in Automatic SParsity's workflow. + + # Reset excluded_layers, all FC layers would be included into Automatic SParsity's workflow. + # Please note, reset_excluded_layers also must be called before calling `optimizer.minimize()`. + sparsity.reset_excluded_layers(main_program) """ ASPHelper.reset_excluded_layers(main_program=main_program) @@ -65,22 +120,21 @@ def decorate(optimizer): .. code-block:: python import paddle - import paddle.fluid as fluid - from paddle.fluid.contrib import sparsity + from paddle.static import sparsity - main_program = fluid.Program() - startup_program = fluid.Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() paddle.enable_static() - with fluid.program_guard(main_program, startup_program): - input_data = fluid.layers.data(name='data', shape=[None, 128]) - label = fluid.layers.data(name='label', shape=[None, 10]) - hidden = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None) - prob = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=10, act=None) - loss = fluid.layers.mean(fluid.layers.square_error_cost(prob, label)) + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 128]) + label = paddle.static.data(name='label', shape=[None, 10]) + hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None) + prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - optimizer = fluid.optimizer.SGD(learning_rate=0.1) + optimizer = paddle.optimizer.SGD(learning_rate=0.1) optimizer = sparsity.decorate(optimizer) # if do sparse training with Fleet, please replace above decorate with: # strategy = paddle.distributed.fleet.DistributedStrategy() @@ -92,15 +146,14 @@ def decorate(optimizer): return ASPHelper.decorate(optimizer) -def prune_model(place, - main_program=None, +def prune_model(main_program=None, n=2, m=4, - func_name=sparsity.MaskAlgo.MASK_1D, + mask_algo='mask_1d', with_mask=True): r""" Pruning parameters of supported layers in :attr:`main_program` via - specified mask generation function given by :attr:`func_name`. This + specified mask generation function given by :attr:`mask_algo`. This function supports both training and inference controlled by :attr:`with_mask`. If :attr:`with_mask` is True, it would also prune parameter related ASP mask Variables, else only prunes parameters. @@ -114,11 +167,11 @@ def prune_model(place, inference only. To obtain OptimizerWithSparsityGuarantee, please see `sparsity.decoreate()`. Args: - place (fluid.CPUPlace()|fluid.CUDAPlace(N)): Device place for pruned parameter and mask Variables, and N means the GPU's id. It should be the same as created instance of Executor. main_program (Program, optional): Program with model definition and its parameters. Default is `paddle.static.default_main_program() n (int): n of `n:m` sparse pattern. m (int): m of `n:m` sparse pattern. - func_name (MaskAlgo, optional): The function name to generate spase mask. Default is `MaskAlgo.MASK_1D`. All options please refer to `MaskAlgo`. + mask_algo (string, optional): The function name to generate spase mask. Default is `mask_1d`. + The vaild inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'. with_mask (bool, optional): To prune mask Variables related to parameters or not. Ture is purning also, False is not. Defalut is True. Returns: dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable. @@ -126,50 +179,58 @@ def prune_model(place, .. code-block:: python import paddle - import paddle.fluid as fluid - import paddle.fluid.core as core - from paddle.fluid.contrib import sparsity + from paddle.static import sparsity paddle.enable_static() - main_program = fluid.Program() - startup_program = fluid.Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() - place = paddle.CPUPlace() - if core.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - - with fluid.program_guard(main_program, startup_program): - input_data = fluid.layers.data(name='data', shape=[None, 128]) - label = fluid.layers.data(name='label', shape=[None, 10]) - hidden = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None, name="need_sparse") - hidden = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=32, act=None, name="need_dense") - prob = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=10, act=None) - loss = fluid.layers.mean(fluid.layers.square_error_cost(prob, label)) + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 128]) + label = paddle.static.data(name='label', shape=[None, 10]) + hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="need_sparse_fc") + hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="need_dense_fc") + prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) # Setup exluded layers out from ASP workflow. # Please note, excluded_layers must be set before calling `optimizer.minimize()`. - sparsity.set_excluded_layers(main_program, ["need_dense"]) + sparsity.set_excluded_layers(main_program, ["need_dense_fc"]) - optimizer = fluid.optimizer.SGD(learning_rate=0.1) - optimizer = fluid.contrib.mixed_precision.decorator.decorate(optimizer ) + optimizer = paddle.optimizer.SGD(learning_rate=0.1) + optimizer = paddle.static.amp.decorate(optimizer ) # Calling sparsity.decorate() to wrap minimize() in optimizer, which # will insert necessary masking operations for ASP workflow. optimizer = sparsity.decorate(optimizer) optimizer.minimize(loss, startup_program) - exe = fluid.Executor(place) + device = paddle.device.get_device() + place = paddle.set_device(device) + + exe = paddle.static.Executor(place) exe.run(startup_program) # Must call `exe.run(startup_program)` first before calling `sparsity.prune_model` - sparsity.prune_model(place, main_program, func_name=sparsity.MaskAlgo.MASK_2D_BEST) + sparsity.prune_model(main_program, mask_algo='mask_2d_best') """ + device = paddle.device.get_device() + place = paddle.set_device(device) + + MaskAlgo_mapping = { + 'mask_1d': sparsity.MaskAlgo.MASK_1D, + 'mask_2d_greedy': sparsity.MaskAlgo.MASK_2D_GREEDY, + 'mask_2d_best': sparsity.MaskAlgo.MASK_2D_BEST + } + assert (mask_algo in MaskAlgo_mapping), \ + 'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]' + return ASPHelper.prune_model( place=place, main_program=main_program, n=n, m=m, - func_name=func_name, + mask_algo=MaskAlgo_mapping[mask_algo], with_mask=with_mask) @@ -256,12 +317,12 @@ class ASPHelper(object): main_program=None, n=2, m=4, - func_name=sparsity.MaskAlgo.MASK_1D, + mask_algo=sparsity.MaskAlgo.MASK_1D, with_mask=True): r""" This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`. """ - checked_func_name = sparsity.CheckMethod.get_checking_method(func_name) + checked_func_name = sparsity.CheckMethod.get_checking_method(mask_algo) if main_program is None: main_program = paddle.static.default_main_program() @@ -284,7 +345,7 @@ class ASPHelper(object): # matrices beforce invoking create_mask. Then we transpose the result maks to make # sure its shape to be the same as the input weight. weight_sparse_mask = sparsity.create_mask( - weight_nparray.T, func_name=func_name, n=n, m=m).T + weight_nparray.T, func_name=mask_algo, n=n, m=m).T weight_pruned_nparray = np.multiply(weight_nparray, weight_sparse_mask) weight_tensor.set(weight_pruned_nparray, place) @@ -347,15 +408,14 @@ class ASPHelper(object): Examples: .. code-block:: python - import paddle.fluid as fluid - from paddle.fluid.contrib.sparsity.asp import ASPHelper + from paddle.static.sparsity.asp import ASPHelper - main_program = fluid.Program() - startup_program = fluid.Program() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() - with fluid.program_guard(main_program, startup_program): - input_data = fluid.layers.data(name='data', shape=[None, 128]) - fc = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None) + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 128]) + fc = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None) for param in main_program.global_block().all_parameters(): ASPHelper._is_supported_layer(main_program, param.name) diff --git a/python/paddle/fluid/contrib/sparsity/utils.py b/python/paddle/fluid/contrib/sparsity/utils.py index bb030cbac1b..e1f68a6fef0 100644 --- a/python/paddle/fluid/contrib/sparsity/utils.py +++ b/python/paddle/fluid/contrib/sparsity/utils.py @@ -64,7 +64,8 @@ class CheckMethod(Enum): .. code-block:: python import numpy as np - from paddle.fluid.contrib.sparsity import MaskAlgo, CheckMethod + from paddle.static.sparsity import MaskAlgo + from paddle.fluid.contrib.sparsity import CheckMethod CheckMethod.get_checking_method(MaskAlgo.MASK_1D) # CheckMethod.CHECK_1D @@ -95,7 +96,7 @@ def calculate_density(x): .. code-block:: python import numpy as np - import paddle.fluid.contrib.sparsity as sparsity + import paddle.static.sparsity as sparsity x = np.array([[0, 1, 3, 0], [1, 1, 0, 1]]) @@ -446,7 +447,7 @@ def get_mask_2d_best(mat, n, m): [5, 6, 3, 9], [2, 4, 6, 9]]) mask_greedy = sparsity.get_mask_2d_greedy(mat, 2, 4) - mask_greedy = sparsity.get_mask_2d_best(mat, 2, 4) + mask_best = sparsity.get_mask_2d_best(mat, 2, 4) print("L1 norm of `greedy` sparse matrix", np.multiply(mat, mask_greedy).sum()) # 56 print("L1 norm of `best` sparse matrix", np.multiply(mat, mask_best).sum()) # 61 """ diff --git a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py b/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py index 370d73cc35a..d41a7b2b842 100644 --- a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py +++ b/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py @@ -20,7 +20,7 @@ import threading, time import paddle import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np @@ -76,14 +76,11 @@ class TestASPHelperPruningBase(unittest.TestCase): check_func_name, with_mask): exe.run(self.startup_program) sparsity.prune_model( - place, - self.main_program, - func_name=mask_func_name, - with_mask=with_mask) + self.main_program, mask_algo=mask_func_name, with_mask=with_mask) for param in self.main_program.global_block().all_parameters(): if ASPHelper._is_supported_layer(self.main_program, param.name): mat = np.array(fluid.global_scope().find_var(param.name) .get_tensor()) self.assertTrue( - sparsity.check_sparsity( + paddle.fluid.contrib.sparsity.check_sparsity( mat.T, func_name=check_func_name, n=2, m=4)) diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py index 402861ad5d9..9e5e3c924f1 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py @@ -20,7 +20,7 @@ import threading, time import paddle import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np @@ -129,7 +129,7 @@ class TestASPHelper(unittest.TestCase): feeder = fluid.DataFeeder(feed_list=[self.img, self.label], place=place) exe.run(self.startup_program) - sparsity.prune_model(place, self.main_program) + sparsity.prune_model(self.main_program) data = (np.random.randn(64, 3, 32, 32), np.random.randint( 10, size=(64, 1))) @@ -139,7 +139,9 @@ class TestASPHelper(unittest.TestCase): if ASPHelper._is_supported_layer(self.main_program, param.name): mat = np.array(fluid.global_scope().find_var(param.name) .get_tensor()) - self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) def test_asp_training_with_amp(self): if core.is_compiled_with_cuda(): @@ -155,7 +157,7 @@ class TestASPHelper(unittest.TestCase): feed_list=[self.img, self.label], place=place) exe.run(self.startup_program) - sparsity.prune_model(place, self.main_program) + sparsity.prune_model(self.main_program) data = (np.random.randn(64, 3, 32, 32), np.random.randint( 10, size=(64, 1))) @@ -165,7 +167,9 @@ class TestASPHelper(unittest.TestCase): if ASPHelper._is_supported_layer(self.main_program, param.name): mat = np.array(fluid.global_scope().find_var(param.name) .get_tensor()) - self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) def __get_param_names(self, params): param_names = [] diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py index 6ebc89b1873..7a3fa024493 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py @@ -17,7 +17,7 @@ from __future__ import print_function import unittest import paddle -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase paddle.enable_static() @@ -25,12 +25,12 @@ paddle.enable_static() class TestASPHelperPruning1D(TestASPHelperPruningBase): def test_1D_inference_pruning(self): - self.run_inference_pruning_test(sparsity.MaskAlgo.MASK_1D, - sparsity.CheckMethod.CHECK_1D) + self.run_inference_pruning_test( + 'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) def test_1D_training_pruning(self): - self.run_training_pruning_test(sparsity.MaskAlgo.MASK_1D, - sparsity.CheckMethod.CHECK_1D) + self.run_training_pruning_test( + 'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py index b21f8edf4f4..e9950918703 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py @@ -17,7 +17,7 @@ from __future__ import print_function import paddle import unittest -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase paddle.enable_static() @@ -25,12 +25,12 @@ paddle.enable_static() class TestASPHelperPruning2DBest(TestASPHelperPruningBase): def test_2D_best_inference_pruning(self): - self.run_inference_pruning_test(sparsity.MaskAlgo.MASK_2D_BEST, - sparsity.CheckMethod.CHECK_2D) + self.run_inference_pruning_test( + 'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) def test_2D_best_training_pruning(self): - self.run_training_pruning_test(sparsity.MaskAlgo.MASK_2D_BEST, - sparsity.CheckMethod.CHECK_2D) + self.run_training_pruning_test( + 'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py index 8ec8ab48525..7ad6c3ae022 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py @@ -17,7 +17,7 @@ from __future__ import print_function import unittest import paddle -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase paddle.enable_static() @@ -25,12 +25,14 @@ paddle.enable_static() class TestASPHelperPruning2DGreedy(TestASPHelperPruningBase): def test_2D_greedy_inference_pruning(self): - self.run_inference_pruning_test(sparsity.MaskAlgo.MASK_2D_GREEDY, - sparsity.CheckMethod.CHECK_2D) + self.run_inference_pruning_test( + 'mask_2d_greedy', + paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) def test_2D_greedy_training_pruning(self): - self.run_training_pruning_test(sparsity.MaskAlgo.MASK_2D_GREEDY, - sparsity.CheckMethod.CHECK_2D) + self.run_training_pruning_test( + 'mask_2d_greedy', + paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py index 387cb55e5c3..4aac878763b 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py @@ -18,22 +18,24 @@ from __future__ import print_function import unittest import threading, time import paddle -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity import numpy as np class TestASPUtils(unittest.TestCase): def test_get_check_method(self): self.assertEqual( - sparsity.CheckMethod.get_checking_method(sparsity.MaskAlgo.MASK_1D), - sparsity.CheckMethod.CHECK_1D) + paddle.fluid.contrib.sparsity.CheckMethod.get_checking_method( + paddle.fluid.contrib.sparsity.MaskAlgo.MASK_1D), + paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) self.assertEqual( - sparsity.CheckMethod.get_checking_method( - sparsity.MaskAlgo.MASK_2D_GREEDY), - sparsity.CheckMethod.CHECK_2D) + paddle.fluid.contrib.sparsity.CheckMethod.get_checking_method( + paddle.fluid.contrib.sparsity.MaskAlgo.MASK_2D_GREEDY), + paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) self.assertEqual( - sparsity.CheckMethod.get_checking_method( - sparsity.MaskAlgo.MASK_2D_BEST), sparsity.CheckMethod.CHECK_2D) + paddle.fluid.contrib.sparsity.CheckMethod.get_checking_method( + paddle.fluid.contrib.sparsity.MaskAlgo.MASK_2D_BEST), + paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) def test_density(self): x = np.array([[1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], @@ -47,53 +49,59 @@ class TestASPUtils(unittest.TestCase): x = np.array([[1.0, 0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]]) - self.assertTrue(sparsity.check_mask_1d(x, 2, 4)) - self.assertFalse(sparsity.check_mask_1d(x, 3, 4)) - self.assertTrue(sparsity.check_mask_1d(x, 2, 5)) - self.assertFalse(sparsity.check_mask_1d(x, 3, 5)) - self.assertTrue(sparsity.check_mask_1d(x, 3, 6)) - self.assertFalse(sparsity.check_mask_1d(x, 4, 6)) + self.assertTrue(paddle.fluid.contrib.sparsity.check_mask_1d(x, 2, 4)) + self.assertFalse(paddle.fluid.contrib.sparsity.check_mask_1d(x, 3, 4)) + self.assertTrue(paddle.fluid.contrib.sparsity.check_mask_1d(x, 2, 5)) + self.assertFalse(paddle.fluid.contrib.sparsity.check_mask_1d(x, 3, 5)) + self.assertTrue(paddle.fluid.contrib.sparsity.check_mask_1d(x, 3, 6)) + self.assertFalse(paddle.fluid.contrib.sparsity.check_mask_1d(x, 4, 6)) def test_get_mask_1d(self): for _ in range(10): x = np.random.randint(10, size=(5, 5)) - x = sparsity.get_mask_1d(x, 2, 4) - self.assertTrue(sparsity.check_mask_1d(x, 2, 4)) + x = paddle.fluid.contrib.sparsity.get_mask_1d(x, 2, 4) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_mask_1d(x, 2, 4)) x = np.random.randn(5, 4) - x = sparsity.get_mask_1d(x, 2, 4) - self.assertTrue(sparsity.check_mask_1d(x, 2, 4)) + x = paddle.fluid.contrib.sparsity.get_mask_1d(x, 2, 4) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_mask_1d(x, 2, 4)) def test_check_mask_2d(self): x = np.array([[1.0, 0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0]]) - self.assertTrue(sparsity.check_mask_2d(x, 2, 4)) - self.assertFalse(sparsity.check_mask_2d(x, 3, 4)) - self.assertTrue(sparsity.check_mask_2d(x, 2, 5)) - self.assertFalse(sparsity.check_mask_2d(x, 3, 5)) - self.assertTrue(sparsity.check_mask_2d(x, 3, 6)) - self.assertFalse(sparsity.check_mask_2d(x, 4, 6)) + self.assertTrue(paddle.fluid.contrib.sparsity.check_mask_2d(x, 2, 4)) + self.assertFalse(paddle.fluid.contrib.sparsity.check_mask_2d(x, 3, 4)) + self.assertTrue(paddle.fluid.contrib.sparsity.check_mask_2d(x, 2, 5)) + self.assertFalse(paddle.fluid.contrib.sparsity.check_mask_2d(x, 3, 5)) + self.assertTrue(paddle.fluid.contrib.sparsity.check_mask_2d(x, 3, 6)) + self.assertFalse(paddle.fluid.contrib.sparsity.check_mask_2d(x, 4, 6)) def test_get_mask_2d_greedy(self): for _ in range(10): x = np.random.randint(10, size=(5, 5)) - x = sparsity.get_mask_2d_greedy(x, 2, 4) - self.assertTrue(sparsity.check_mask_2d(x, 2, 4)) + x = paddle.fluid.contrib.sparsity.get_mask_2d_greedy(x, 2, 4) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_mask_2d(x, 2, 4)) x = np.random.randn(5, 4) - x = sparsity.get_mask_2d_greedy(x, 2, 4) - self.assertTrue(sparsity.check_mask_2d(x, 2, 4)) + x = paddle.fluid.contrib.sparsity.get_mask_2d_greedy(x, 2, 4) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_mask_2d(x, 2, 4)) def test_get_mask_2d_best(self): for _ in range(10): x = np.random.randint(10, size=(5, 5)) - x = sparsity.get_mask_2d_best(x, 2, 4) - self.assertTrue(sparsity.check_mask_2d(x, 2, 4)) + x = paddle.fluid.contrib.sparsity.get_mask_2d_best(x, 2, 4) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_mask_2d(x, 2, 4)) x = np.random.randn(5, 4) - x = sparsity.get_mask_2d_best(x, 2, 4) - self.assertTrue(sparsity.check_mask_2d(x, 2, 4)) + x = paddle.fluid.contrib.sparsity.get_mask_2d_best(x, 2, 4) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_mask_2d(x, 2, 4)) def test_threadsafe_valid_2d_patterns(self): def get_reference(m=4, n=2): @@ -160,30 +168,54 @@ class TestASPUtils(unittest.TestCase): self.__test_1D_2D_sparse_mask_generation_methods(x) def __test_1D_2D_sparsity_checking_methods(self, x_2d): - mask = sparsity.get_mask_1d(x_2d, 2, 4) + mask = paddle.fluid.contrib.sparsity.get_mask_1d(x_2d, 2, 4) self.assertEqual( - sparsity.check_sparsity( - mask, func_name=sparsity.CheckMethod.CHECK_1D, n=2, m=4), - sparsity.check_mask_1d(mask, 2, 4)) - mask = sparsity.get_mask_2d_best(x_2d, 2, 4) + paddle.fluid.contrib.sparsity.check_sparsity( + mask, + func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D, + n=2, + m=4), + paddle.fluid.contrib.sparsity.check_mask_1d(mask, 2, 4)) + mask = paddle.fluid.contrib.sparsity.get_mask_2d_best(x_2d, 2, 4) self.assertEqual( - sparsity.check_sparsity( - mask, func_name=sparsity.CheckMethod.CHECK_2D, n=2, m=4), - sparsity.check_mask_2d(mask, 2, 4)) + paddle.fluid.contrib.sparsity.check_sparsity( + mask, + func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D, + n=2, + m=4), + paddle.fluid.contrib.sparsity.check_mask_2d(mask, 2, 4)) def __test_1D_2D_sparse_mask_generation_methods(self, x): - mask = sparsity.create_mask( - x, func_name=sparsity.MaskAlgo.MASK_1D, n=2, m=4) + mask = paddle.fluid.contrib.sparsity.create_mask( + x, + func_name=paddle.fluid.contrib.sparsity.MaskAlgo.MASK_1D, + n=2, + m=4) self.assertTrue( - sparsity.check_sparsity( - mask, func_name=sparsity.CheckMethod.CHECK_1D, n=2, m=4)) - mask = sparsity.create_mask( - x, func_name=sparsity.MaskAlgo.MASK_2D_GREEDY, n=2, m=4) + paddle.fluid.contrib.sparsity.check_sparsity( + mask, + func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + mask = paddle.fluid.contrib.sparsity.create_mask( + x, + func_name=paddle.fluid.contrib.sparsity.MaskAlgo.MASK_2D_GREEDY, + n=2, + m=4) self.assertTrue( - sparsity.check_sparsity( - mask, func_name=sparsity.CheckMethod.CHECK_2D, n=2, m=4)) - mask = sparsity.create_mask( - x, func_name=sparsity.MaskAlgo.MASK_2D_BEST, n=2, m=4) + paddle.fluid.contrib.sparsity.check_sparsity( + mask, + func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D, + n=2, + m=4)) + mask = paddle.fluid.contrib.sparsity.create_mask( + x, + func_name=paddle.fluid.contrib.sparsity.MaskAlgo.MASK_2D_BEST, + n=2, + m=4) self.assertTrue( - sparsity.check_sparsity( - mask, func_name=sparsity.CheckMethod.CHECK_2D, n=2, m=4)) + paddle.fluid.contrib.sparsity.check_sparsity( + mask, + func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D, + n=2, + m=4)) diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py index 34d17f570e4..074aedb9476 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py +++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py @@ -20,7 +20,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import os -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') @@ -73,7 +73,7 @@ class TestFleetWithASP(unittest.TestCase): feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place) exe.run(startup_prog) - sparsity.prune_model(place, train_prog) + sparsity.prune_model(train_prog) data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1))) exe.run(train_prog, feed=feeder.feed([data])) @@ -82,7 +82,9 @@ class TestFleetWithASP(unittest.TestCase): if ASPHelper._is_supported_layer(train_prog, param.name): mat = np.array(fluid.global_scope().find_var(param.name) .get_tensor()) - self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py index c4074b2ae7a..a34d7e69872 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py +++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py @@ -20,7 +20,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import os -from paddle.fluid.contrib import sparsity +from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') @@ -76,7 +76,7 @@ class TestFleetWithASP(unittest.TestCase): optimizer.amp_init(place) - sparsity.prune_model(place, train_prog) + sparsity.prune_model(train_prog) data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1))) exe.run(train_prog, feed=feeder.feed([data])) @@ -85,7 +85,9 @@ class TestFleetWithASP(unittest.TestCase): if ASPHelper._is_supported_layer(train_prog, param.name): mat = np.array(fluid.global_scope().find_var(param.name) .get_tensor()) - self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) def test_with_asp_and_pure_fp16(self): fleet.init(is_collective=True) @@ -114,7 +116,7 @@ class TestFleetWithASP(unittest.TestCase): optimizer.amp_init(place) - sparsity.prune_model(place, train_prog) + sparsity.prune_model(train_prog) data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1))) exe.run(train_prog, feed=feeder.feed([data])) @@ -123,7 +125,9 @@ class TestFleetWithASP(unittest.TestCase): if ASPHelper._is_supported_layer(train_prog, param.name): mat = np.array(fluid.global_scope().find_var(param.name) .get_tensor()) - self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) if __name__ == "__main__": diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 0f463b0c7d9..c4c2f27146f 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,7 @@ # limitations under the License. from . import amp # noqa: F401 +from . import sparsity # noqa: F401 from . import nn # noqa: F401 from .io import save_inference_model # noqa: F401 from .io import load_inference_model # noqa: F401 diff --git a/python/paddle/static/sparsity/__init__.py b/python/paddle/static/sparsity/__init__.py new file mode 100644 index 00000000000..59f794ef28a --- /dev/null +++ b/python/paddle/static/sparsity/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...fluid.contrib.sparsity import calculate_density #noqa: F401 +from ...fluid.contrib.sparsity import decorate #noqa: F401 +from ...fluid.contrib.sparsity import prune_model #noqa: F401 +from ...fluid.contrib.sparsity import set_excluded_layers #noqa: F401 +from ...fluid.contrib.sparsity import reset_excluded_layers #noqa: F401 + +__all__ = [ #noqa + 'calculate_density', + 'decorate', + 'prune_model', + 'set_excluded_layers', + 'reset_excluded_layers' +] diff --git a/python/setup.py.in b/python/setup.py.in index 03b0555c965..bdcdd26815e 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -314,6 +314,7 @@ packages=['paddle', 'paddle.static', 'paddle.static.nn', 'paddle.static.amp', + 'paddle.static.sparsity', 'paddle.tensor', 'paddle.onnx', 'paddle.autograd', -- GitLab