diff --git a/python/paddle/fluid/contrib/sparsity/asp.py b/python/paddle/fluid/contrib/sparsity/asp.py index 30439ad736d26f3086a7f87d591aa68a59b7baa8..c366af7237d1bc1a47bcfa952e77397a0d82b7c1 100644 --- a/python/paddle/fluid/contrib/sparsity/asp.py +++ b/python/paddle/fluid/contrib/sparsity/asp.py @@ -20,12 +20,13 @@ import os import copy import numpy as np import paddle +from paddle.fluid.framework import dygraph_only from paddle.fluid import global_scope, program_guard, layers from paddle.fluid.initializer import ConstantInitializer from paddle.fluid.contrib import sparsity +from paddle.fluid import core from paddle.fluid.contrib.sparsity.supported_layer_list import supported_layers_and_prune_func_map from paddle.fluid.contrib.sparsity.supported_layer_list import _default_pruning -from paddle.fluid import core OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() @@ -35,45 +36,90 @@ __all__ = [ ] -def set_excluded_layers(main_program, param_names): +def set_excluded_layers(param_names, main_program=None): r""" Set parameter name of layers which would not be pruned as sparse weights. Args: + param_names (list of string): A list contains names of parameters. main_program (Program, optional): Program with model definition and its parameters. - param_names (list): A list contains names of parameters. + If None is given, then it would be set as `paddle.static.default_main_program(). + Default is None. Examples: - .. code-block:: python - - import paddle - from paddle.static import sparsity - - paddle.enable_static() - - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - - with paddle.static.program_guard(main_program, startup_program): - input_data = paddle.static.data(name='data', shape=[None, 128]) - label = paddle.static.data(name='label', shape=[None, 10]) - hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="need_sparse_fc") - hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="need_dense_fc") - prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) - loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - - # Setup exluded layers out from ASP workflow. - # Please note, excluded_layers must be set before calling `optimizer.minimize()`. - sparsity.set_excluded_layers(main_program, ["need_dense_fc"]) - - optimizer = paddle.optimizer.SGD(learning_rate=0.1) - optimizer = paddle.static.amp.decorate(optimizer ) - # Calling sparsity.decorate() to wrap minimize() in optimizer, which - # will insert necessary masking operations for ASP workflow. - optimizer = sparsity.decorate(optimizer) - optimizer.minimize(loss, startup_program) + 1. Usage of Dynamic Graph + + .. code-block:: python + + import paddle + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 100) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + prediction = self.linear1(hidden) + return prediction + + my_layer = MyLayer() + optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=my_layer.parameters()) + + # Need to set excluded layers before calling decorate + paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()]) + + optimizer = paddle.incubate.asp.decorate(optimizer) + + 2. Usage of Static Graph + + .. code-block:: python + + import paddle + + paddle.enable_static() + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 100) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + prediction = self.linear1(hidden) + return prediction + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) + label = paddle.static.data(name='label', shape=[None, 100]) + my_layer = MyLayer() + prob = my_layer(input_data) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + + # Setup exluded layers out from ASP workflow. + # Please note, excluded_layers must be set before calling optimizer.minimize(). + paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program) + + optimizer = paddle.optimizer.SGD(learning_rate=0.1) + optimizer = paddle.static.amp.decorate(optimizer ) + # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + # will insert necessary masking operations for ASP workflow. + optimizer = paddle.incubate.asp.decorate(optimizer) + optimizer.minimize(loss, startup_program) """ + if main_program is None: + main_program = paddle.static.default_main_program() ASPHelper.set_excluded_layers( - main_program=main_program, param_names=param_names) + param_names=param_names, main_program=main_program) def reset_excluded_layers(main_program=None): @@ -83,153 +129,310 @@ def reset_excluded_layers(main_program=None): Args: main_program (Program, optional): Program with model definition and its parameters. - Examples: - .. code-block:: python + If None is given, then this function would reset all excluded_layers. + Default is None. + Examples: + 1. Usage of Dynamic Graph - import paddle - from paddle.static import sparsity + .. code-block:: python - paddle.enable_static() + import paddle - main_program = paddle.static.Program() - startup_program = paddle.static.Program() + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 100) - with paddle.static.program_guard(main_program, startup_program): - input_data = paddle.static.data(name='data', shape=[None, 128]) - label = paddle.static.data(name='label', shape=[None, 10]) - hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="my_first_fc") - hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="my_second_fc") - prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) - loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + prediction = self.linear1(hidden) + return prediction - # Setup exluded layers out from ASP workflow. - # Please note, excluded_layers must be set before calling `optimizer.minimize()`. - sparsity.set_excluded_layers(main_program, ["my_second_fc"]) - # Now the weights of "my_second_fc" would not be included in Automatic SParsity's workflow. + my_layer = MyLayer() + optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=my_layer.parameters()) + + # Need to set excluded layers before calling decorate + paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()]) + # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow. + # Please note, reset_excluded_layers also must be called before calling sparsity.decorate(). + paddle.incubate.asp.reset_excluded_layers() + + optimizer = paddle.incubate.asp.decorate(optimizer) + + 2. Usage of Static Graph + + .. code-block:: python - # Reset excluded_layers, all FC layers would be included into Automatic SParsity's workflow. - # Please note, reset_excluded_layers also must be called before calling `optimizer.minimize()`. - sparsity.reset_excluded_layers(main_program) + import paddle + + paddle.enable_static() + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 100) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + prediction = self.linear1(hidden) + return prediction + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) + label = paddle.static.data(name='label', shape=[None, 100]) + my_layer = MyLayer() + prob = my_layer(input_data) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + + # Setup exluded layers out from ASP workflow. + # Please note, excluded_layers must be set before calling optimizer.minimize(). + paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program) + # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow. + # Please note, reset_excluded_layers also must be called before calling optimizer.minimize(). + paddle.incubate.asp.reset_excluded_layers(main_program) + + optimizer = paddle.optimizer.SGD(learning_rate=0.1) + optimizer = paddle.static.amp.decorate(optimizer ) + # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + # will insert necessary masking operations for ASP workflow. + optimizer = paddle.incubate.asp.decorate(optimizer) + optimizer.minimize(loss, startup_program) """ ASPHelper.reset_excluded_layers(main_program=main_program) def decorate(optimizer): r""" - Wrap the given optimizer as a OptimizerWithSparsityGuarantee, - which would insert necessary ops for ASP workflows when calling minimize() + Wrap the given optimizer as a OptimizerWithSparsityGuarantee, + If runnig with dynamic graph mode. ASP would creates mask variables for supported parameters. + Else if in static graph mode, ASP would creates mask variables and inserts necessary ops + when calling minimize() Args: optimizer (Optimizer): A Optimizer used for training. Returns: OptimizerWithSparsityGuarantee: A wrapper for ASP to decorate `minimize` function of the given optimizer. Examples: - .. code-block:: python + 1. Usage of Dynamic Graph - import paddle - from paddle.static import sparsity + .. code-block:: python + + import paddle + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) - main_program = paddle.static.Program() - startup_program = paddle.static.Program() + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction - paddle.enable_static() + my_layer = MyLayer() + optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=my_layer.parameters()) - with paddle.static.program_guard(main_program, startup_program): - input_data = paddle.static.data(name='data', shape=[None, 128]) - label = paddle.static.data(name='label', shape=[None, 10]) - hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None) - prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) - loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which + # will apply necessary masking operations for ASP workflow. + # In dynamic graph mode, ASP would create related mask variables during decoration. + optimizer = paddle.incubate.asp.decorate(optimizer) - optimizer = paddle.optimizer.SGD(learning_rate=0.1) - optimizer = sparsity.decorate(optimizer) - # if do sparse training with Fleet, please replace above decorate with: - # strategy = paddle.distributed.fleet.DistributedStrategy() - # strategy.asp = True - # optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + 2. Usage of Static Graph - optimizer.minimize(loss, startup_program) + .. code-block:: python + + import paddle + + paddle.enable_static() + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 100) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + prediction = self.linear1(hidden) + return prediction + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) + label = paddle.static.data(name='label', shape=[None, 100]) + my_layer = MyLayer() + prob = my_layer(input_data) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + + optimizer = paddle.optimizer.SGD(learning_rate=0.1) + # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + # will insert necessary masking operations for ASP workflow. + # In static graph mode, ASP creates related mask variables + # during minimize(). + optimizer = paddle.incubate.asp.decorate(optimizer) + optimizer.minimize(loss, startup_program) """ return ASPHelper.decorate(optimizer) -def prune_model(main_program=None, - n=2, - m=4, - mask_algo='mask_1d', - with_mask=True): +def prune_model(model, n=2, m=4, mask_algo='mask_1d', with_mask=True): r""" - Pruning parameters of supported layers in :attr:`main_program` via + Pruning parameters of supported layers in :attr:`model` via specified mask generation function given by :attr:`mask_algo`. This function supports both training and inference controlled by :attr:`with_mask`. If :attr:`with_mask` is True, it would also prune parameter related ASP mask Variables, else only prunes parameters. - *Note*: If parameters are supported and in FP16, please set :attr:`n`=2, :attr:`m`=4, - if they in FP32, then :attr:`n`=1, :attr:`m`=2` to further enable Sparse Tensor Core acceleration. - - *Note*: If calling this function with :attr:`with_mask`, it should call `OptimizerWithSparsityGuarantee.minimize` + *Note*: (Static graph mode) If calling this function with :attr:`with_mask`, it should call `OptimizerWithSparsityGuarantee.minimize` and initialization (`exe.run(startup_program`)) before (For successfully obtain mask Variable). Typically set `with_mask` as true for training (have called `OptimizerWithSparsityGuarantee.minimize`) and false for - inference only. To obtain OptimizerWithSparsityGuarantee, please see `sparsity.decoreate()`. + inference only. To obtain OptimizerWithSparsityGuarantee, please see `paddle.incubate.asp.decoreate()`. Args: - main_program (Program, optional): Program with model definition and its parameters. Default is `paddle.static.default_main_program() - n (int): n of `n:m` sparse pattern. - m (int): m of `n:m` sparse pattern. + model (Program|nn.Layer): Program with model definition and its parameters, or a object of `paddle.nn.Layer`. + n (int, optional): n of `n:m` sparse pattern. Default is 2. + m (int, optional): m of `n:m` sparse pattern. Default is 4. mask_algo (string, optional): The function name to generate spase mask. Default is `mask_1d`. The vaild inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'. with_mask (bool, optional): To prune mask Variables related to parameters or not. Ture is purning also, False is not. Defalut is True. Returns: dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable. Examples: - .. code-block:: python - - import paddle - from paddle.static import sparsity - - paddle.enable_static() - - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - - with paddle.static.program_guard(main_program, startup_program): - input_data = paddle.static.data(name='data', shape=[None, 128]) - label = paddle.static.data(name='label', shape=[None, 10]) - hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="need_sparse_fc") - hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="need_dense_fc") - prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None) - loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - - # Setup exluded layers out from ASP workflow. - # Please note, excluded_layers must be set before calling `optimizer.minimize()`. - sparsity.set_excluded_layers(main_program, ["need_dense_fc"]) + 1. Usage of Dynamic Graph - optimizer = paddle.optimizer.SGD(learning_rate=0.1) - optimizer = paddle.static.amp.decorate(optimizer ) - # Calling sparsity.decorate() to wrap minimize() in optimizer, which - # will insert necessary masking operations for ASP workflow. - optimizer = sparsity.decorate(optimizer) - optimizer.minimize(loss, startup_program) + .. code-block:: python - device = paddle.device.get_device() - place = paddle.set_device(device) + import paddle + import numpy as np + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction + + my_layer = MyLayer() + loss_fn = paddle.nn.MSELoss(reduction='mean') + + optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=my_layer.parameters()) + + # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which + # will apply necessary masking operations for ASP workflow. + # In dynamic graph mode, ASP would create related mask variables during decoration. + optimizer = paddle.incubate.asp.decorate(optimizer) + + # Must call paddle.incubate.asp.decorate() first before calling paddle.incubate.asp.prune_model() + paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best') + + for i in range(10): + imgs = paddle.to_tensor( + np.random.randn(64, 3, 32, 32), + dtype='float32', stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint(10, size=(64, 1)), + dtype='float32', stop_gradient=False) + output = my_layer(imgs) + loss = loss_fn(output, labels) + loss.backward() + optimizer.step() + optimizer.clear_grad() + + 2. Usage of Static Graph - exe = paddle.static.Executor(place) - exe.run(startup_program) + .. code-block:: python - # Must call `exe.run(startup_program)` first before calling `sparsity.prune_model` - sparsity.prune_model(main_program, mask_algo='mask_2d_best') + import paddle + import numpy as np + + paddle.enable_static() + + class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(main_program, startup_program): + input_data = paddle.static.data(name='data', shape=[None, 3, 32, 32]) + label = paddle.static.data(name='label', shape=[None, 1]) + my_layer = MyLayer() + prob = my_layer(input_data) + loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + + optimizer = paddle.optimizer.SGD(learning_rate=0.1) + # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + # will insert necessary masking operations for ASP workflow. + # In static graph mode, ASP creates related mask variables + # during minimize(). + optimizer = paddle.incubate.asp.decorate(optimizer) + optimizer.minimize(loss, startup_program) + + device = paddle.device.get_device() + place = paddle.set_device(device) + + exe = paddle.static.Executor(place) + exe.run(startup_program) + + # Must call exe.run(startup_program) first before calling paddle.asp.prune_model() + paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best') + # it also be accepted to call + # paddle.incubate.asp.prune_model(main_program, mask_algo='mask_2d_best') + + for i in range(10): + imgs = np.random.randn(64, 3, 32, 32).astype('float32') + labels = np.random.randint(10, size=(64, 1)).astype('float32') + exe.run(main_program, feed={'data':imgs, 'label':labels}) """ - if main_program is not None and hasattr( - main_program, - "distributed_info_") and main_program.distributed_info_[ - "sharding_degree"] > 1 and paddle.fluid.is_compiled_with_cuda(): - gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) - place = paddle.CUDAPlace(gpu_id) - else: - device = paddle.device.get_device() - place = paddle.set_device(device) + device = paddle.device.get_device() + place = paddle.set_device(device) MaskAlgo_mapping = { 'mask_1d': sparsity.MaskAlgo.MASK_1D, @@ -237,11 +440,26 @@ def prune_model(main_program=None, 'mask_2d_best': sparsity.MaskAlgo.MASK_2D_BEST } assert (mask_algo in MaskAlgo_mapping), \ - 'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]' + 'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]' + + prune_func = None + if isinstance(model, paddle.nn.Layer): + prune_func = ASPHelper.prune_model_by_layer + elif isinstance(model, paddle.static.Program): + prune_func = ASPHelper.prune_model_by_program + if hasattr(model, "distributed_info_") and \ + model.distributed_info_["sharding_degree"] > 1 and \ + paddle.fluid.is_compiled_with_cuda(): + gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) + place = paddle.CUDAPlace(gpu_id) + else: + raise TypeError( + "model should be paddle.nn.Layer or paddle.static.Program, but got {}". + format(type(model))) - return ASPHelper.prune_model( - place=place, - main_program=main_program, + return prune_func( + place, + model, n=n, m=m, mask_algo=MaskAlgo_mapping[mask_algo], @@ -300,7 +518,7 @@ class ASPHelper(object): __asp_info = {} @classmethod - def set_excluded_layers(cls, main_program, param_names): + def set_excluded_layers(cls, param_names, main_program): r""" This is the implementation of `sparsity.set_excluded_layers`, for details please see explanation in `sparsity.set_excluded_layers`. """ @@ -313,8 +531,8 @@ class ASPHelper(object): This is the implementation of `sparsity.reset_excluded_layers`, for details please see explanation in `sparsity.reset_excluded_layers`. """ if main_program is None: - for asp_info in cls.__asp_info: - asp_info.reset_excluded_layers() + for prog in cls.__asp_info: + cls.__asp_info[prog].reset_excluded_layers() else: cls._get_program_asp_info(main_program).reset_excluded_layers() @@ -323,16 +541,25 @@ class ASPHelper(object): r""" This is the implementation of `sparsity.decorate`, for details please see explanation in `sparsity.decorate`. """ + if paddle.in_dynamic_mode(): + # main_prog and startup_prog would be used with paddle.static.program_guard + # to create ASP masks. Moreover, main_prog is a key to map paddle.static.Program + # to its own ASP informantion, like ASP mask variables. For dynamic graph, we use + # default_main_program as the key. + main_prog = paddle.static.default_main_program() + startup_prog = paddle.static.default_startup_program() + ASPHelper._create_mask_variables(main_prog, startup_prog, + optimizer._parameter_list) return OptimizerWithSparsityGuarantee(optimizer) @classmethod - def prune_model(cls, - place, - main_program=None, - n=2, - m=4, - mask_algo=sparsity.MaskAlgo.MASK_1D, - with_mask=True): + def prune_model_by_program(cls, + place, + main_program=None, + n=2, + m=4, + mask_algo=sparsity.MaskAlgo.MASK_1D, + with_mask=True): r""" This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`. """ @@ -366,9 +593,63 @@ class ASPHelper(object): np.array(weight_mask_tensor).dtype) weight_mask_tensor.set(weight_sparse_mask, place) asp_info.update_masks(param.name, weight_sparse_mask) - return asp_info.masks.copy() + @classmethod + def prune_model_by_layer(cls, + place, + layer, + n=2, + m=4, + mask_algo=sparsity.MaskAlgo.MASK_1D, + with_mask=True): + r""" + This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`. + """ + if paddle.in_dynamic_mode(): + main_program = paddle.static.default_main_program() + asp_info = cls._get_program_asp_info(main_program) + + for param in layer.parameters(): + if ASPHelper._is_supported_layer(main_program, param.name): + weight_nparray = param.numpy() + + prune_func = ASPHelper._get_prune_func_by_name(param.name) + + weight_pruned_nparray, weight_sparse_mask = \ + prune_func(weight_nparray, m, n, mask_algo, param.name) + + weight_pruned_nparray = weight_pruned_nparray.astype( + weight_nparray.dtype) + param.set_value(weight_pruned_nparray) + + if with_mask: + weight_mask_param = asp_info.mask_vars.get(param.name, + None) + assert weight_mask_param is not None, \ + 'Cannot find {} variable, please call sparsity.decorate() to' \ + ' decorate your optimizer first!'.format(ASPHelper._get_mask_name(param.name)) + weight_mask_param.set_value(weight_sparse_mask) + + asp_info.update_masks(param.name, weight_sparse_mask) + + return asp_info.masks.copy() + else: + # This for loop is only used to obtain Block and Program from + # first parameters. + target_program = None + for param in layer.parameters(): + target_program = param.block.program + assert target_program is not None, \ + 'Cannot get paddle.static.Program from Paddle.nn.Layer.' + return ASPHelper.prune_model_by_program( + place, + target_program, + n=n, + m=m, + mask_algo=mask_algo, + with_mask=with_mask) + @staticmethod def _get_mask_name(param_name): r""" @@ -393,13 +674,15 @@ class ASPHelper(object): """ var_list = [] for param in main_program.global_block().all_parameters(): - if ASPHelper.MASK_APPENDDED_NAME not in param.name: + param_name_list = param.name.split('.') + + if ASPHelper.MASK_APPENDDED_NAME not in param_name_list: var_list.append(param) return var_list @classmethod def _get_program_asp_info(cls, main_program): - if not main_program in cls.__asp_info: + if main_program not in cls.__asp_info: cls.__asp_info[main_program] = ProgramASPInfo() return cls.__asp_info[main_program] @@ -508,14 +791,37 @@ class ASPHelper(object): optimizer_ops, params_and_grads = optimizer.minimize( loss, startup_program, parameter_list, no_grad_set=no_grad_set) - cls._create_mask_variables(main_program, startup_program, - params_and_grads) - cls._insert_sparse_mask_ops(main_program, params_and_grads) + + params_only = [pg[0] for pg in params_and_grads] + cls._create_mask_variables(main_program, startup_program, params_only) + cls._insert_sparse_mask_ops(main_program, params_only) return optimizer_ops, params_and_grads @classmethod - def _create_mask_variables(cls, main_program, startup_program, - params_and_grads): + @dygraph_only + def _step(cls, optimizer): + r""" + This function is a decorator of `step` function in `Optimizer`. + There are three steps: + + 1. Call :attr:`optimizer`.step() + 2. Mask parameters with sparse masks. + + *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`. + (Due to there is a invisiable graphs optimization in `Fleet.minimize()` which make training graph + cannot be modified anymore.) + + Args: + optimizer (Optimizer): A Optimizer used for training. + """ + optimizer.step() + main_prog = paddle.static.default_main_program() + with paddle.fluid.dygraph.no_grad(): + ASPHelper._insert_sparse_mask_ops(main_prog, + optimizer._parameter_list) + + @classmethod + def _create_mask_variables(cls, main_program, startup_program, params): r""" Create sparse mask Tensors according to supported layers in :attr:`main_program`. This function is called in second step of `ASPHelper._minimize` @@ -523,48 +829,45 @@ class ASPHelper(object): Args: main_program (Program): Program with model definition and its parameters. startup_program (Program): Program for initializing parameters. - params_and_grads (list): Variable pairs of parameters and their gradients. + params (list): Variable parameters. """ asp_info = cls._get_program_asp_info(main_program) with program_guard(main_program, startup_program): - for param_and_grad in params_and_grads: - if ASPHelper._is_supported_layer(main_program, - param_and_grad[0].name): - mask_param = layers.create_parameter( - name=ASPHelper._get_mask_name(param_and_grad[0].name), - shape=param_and_grad[0].shape, - dtype=param_and_grad[0].dtype, - default_initializer=ConstantInitializer(value=1.0)) - mask_param.stop_gradient = True - mask_param.trainable = False - asp_info.update_mask_vars(param_and_grad[0].name, - mask_param) + for param in params: + if ASPHelper._is_supported_layer(main_program, param.name): + if param.name not in asp_info.mask_vars: + mask_param = layers.create_parameter( + name=ASPHelper._get_mask_name(param.name), + shape=param.shape, + dtype=param.dtype, + default_initializer=ConstantInitializer(value=1.0)) + mask_param.stop_gradient = True + mask_param.trainable = False + asp_info.update_mask_vars(param.name, mask_param) @classmethod - def _insert_sparse_mask_ops(cls, main_program, param_grads): + def _insert_sparse_mask_ops(cls, main_program, params): r""" Insert masking ops in the end of parameters update. This function is called in third step of `ASPHelper._minimize` Args: main_program (Program): Program with model definition and its parameters. - params_and_grads (list): Variable pairs of parameters and their gradients. + params (list): Variable parameters. """ block = main_program.global_block() asp_info = cls._get_program_asp_info(main_program) - for param_grad in param_grads: - if param_grad[0].name in asp_info.mask_vars: + for param in params: + if param.name in asp_info.mask_vars: block.append_op( type='elementwise_mul', - inputs={ - "X": param_grad[0], - 'Y': asp_info.mask_vars[param_grad[0].name] - }, - outputs={'Out': param_grad[0]}, + inputs={"X": param, + 'Y': asp_info.mask_vars[param.name]}, + outputs={'Out': param}, attrs={ 'axis': -1, 'use_mkldnn': False, - OP_ROLE_KEY: OpRole.Optimize + OP_ROLE_KEY: int(OpRole.Optimize) }) @@ -579,8 +882,9 @@ class OptimizerWithSparsityGuarantee(object): def __init__(self, optimizer): self._optimizer = optimizer - self._learning_rate = optimizer._learning_rate - self._learning_rate_map = optimizer._learning_rate_map + + def __getattr__(self, item): + return getattr(self._optimizer, item) def minimize(self, loss, @@ -605,3 +909,55 @@ class OptimizerWithSparsityGuarantee(object): startup_program=startup_program, parameter_list=parameter_list, no_grad_set=no_grad_set) + + @dygraph_only + def step(self): + r""" + This function is a decorator of `step` function in `Optimizer`. + There are three steps: + + 1. Call :attr:`optimizer`.step() + 2. Mask parameters with sparse masks. + + *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`. + (Due to there is a invisiable graphs optimization in `Fleet.minimize()` which make training graph + cannot be modified anymore.) + + Args: + optimizer (Optimizer): A Optimizer used for training. + """ + ASPHelper._step(self._optimizer) + + @dygraph_only + def state_dict(self): + r""" + This function is a decorator of `state_dict` function in `Optimizer`. + + Returns: + state_dict(dict) : dict contains all the Tensor used by optimizer + """ + state_dict = self._optimizer.state_dict() + asp_info = ASPHelper._get_program_asp_info( + paddle.static.default_main_program()) + for param_name, var in asp_info.mask_vars.items(): + state_dict.update({ASPHelper._get_mask_name(param_name): var}) + return state_dict + + @dygraph_only + def set_state_dict(self, state_dict): + r""" + This function is a decorator of `set_state_dict` function in `Optimizer`. + Args: + state_dict(dict) : Dict contains all the Tensor needed by optimizer + Return: + None + """ + asp_info = ASPHelper._get_program_asp_info( + paddle.static.default_main_program()) + for param_name, var in asp_info.mask_vars.items(): + param_mask_name = ASPHelper._get_mask_name(param_name) + assert param_mask_name in state_dict, \ + "The {} is not found.".format(param_mask_name) + var.set_value(state_dict[param_mask_name]) + asp_info.update_masks(param_name, var.numpy()) + return self._optimizer.set_state_dict(state_dict) diff --git a/python/paddle/fluid/contrib/sparsity/utils.py b/python/paddle/fluid/contrib/sparsity/utils.py index 8b8c043bc4bad71a92e71fac67b65383fe13b013..a28f7fc2b4ed67b562f9ba07e79d2bf551846242 100644 --- a/python/paddle/fluid/contrib/sparsity/utils.py +++ b/python/paddle/fluid/contrib/sparsity/utils.py @@ -94,13 +94,12 @@ def calculate_density(x): float: The density of :attr:`x`. Examples: .. code-block:: python - + import paddle import numpy as np - import paddle.static.sparsity as sparsity x = np.array([[0, 1, 3, 0], [1, 1, 0, 1]]) - sparsity.calculate_density(x) # 0.625 + paddle.incubate.asp.calculate_density(x) # 0.625 """ x_flattened = x.flatten() return float(np.nonzero(x_flattened)[0].size) / x_flattened.size diff --git a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt index b6b313465ab20a6d7683a7d11dc9fc7eeda373c4..76856d88e17899df0b5dc3321fe4b8401e94297f 100644 --- a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt @@ -1,8 +1,8 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp") -list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_amp") +list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_static") +list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_dynamic") list(REMOVE_ITEM TEST_OPS "test_fleet_with_asp_sharding") foreach(TEST_OP ${TEST_OPS}) @@ -10,9 +10,9 @@ foreach(TEST_OP ${TEST_OPS}) endforeach(TEST_OP) if(WITH_DISTRIBUTE) - py_test_modules(test_fleet_with_asp MODULES test_fleet_with_asp ENVS ${dist_ENVS}) if (WITH_GPU OR WITH_XPU OR WITH_ASCEND OR WITH_ASCEND_CL) - py_test_modules(test_fleet_with_asp_amp MODULES test_fleet_with_asp_amp ENVS ${dist_ENVS}) + py_test_modules(test_fleet_with_asp_dynamic MODULES test_fleet_with_asp_dynamic ENVS ${dist_ENVS}) + py_test_modules(test_fleet_with_asp_static MODULES test_fleet_with_asp_static ENVS ${dist_ENVS}) endif() endif() @@ -21,3 +21,8 @@ if((WITH_DISTRIBUTE) AND (NOT WIN32) AND (NOT APPLE)) py_test_modules(test_fleet_with_asp_sharding MODULES test_fleet_with_asp_sharding ENVS ${dist_ENVS}) endif() endif() + +set_tests_properties(test_asp_pruning_dynamic PROPERTIES TIMEOUT 30) +set_tests_properties(test_asp_pruning_static PROPERTIES TIMEOUT 30) +set_tests_properties(test_asp_optimize_dynamic PROPERTIES TIMEOUT 30) +set_tests_properties(test_asp_optimize_static PROPERTIES TIMEOUT 30) diff --git a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py b/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py index d41a7b2b842e800bdf762f7d92e9d0b8b32f6acb..e594bc5c34eb35d63f9b0210bb17bec7cbea9de2 100644 --- a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py +++ b/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py @@ -20,7 +20,6 @@ import threading, time import paddle import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np @@ -60,7 +59,7 @@ class TestASPHelperPruningBase(unittest.TestCase): loss = fluid.layers.mean( fluid.layers.cross_entropy( input=self.predict, label=self.label)) - optimizer = sparsity.decorate( + optimizer = paddle.incubate.asp.decorate( fluid.optimizer.SGD(learning_rate=0.01)) optimizer.minimize(loss, self.startup_program) @@ -75,7 +74,7 @@ class TestASPHelperPruningBase(unittest.TestCase): def __pruning_and_checking(self, exe, place, mask_func_name, check_func_name, with_mask): exe.run(self.startup_program) - sparsity.prune_model( + paddle.incubate.asp.prune_model( self.main_program, mask_algo=mask_func_name, with_mask=with_mask) for param in self.main_program.global_block().all_parameters(): if ASPHelper._is_supported_layer(self.main_program, param.name): diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py b/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py index a2b499a9e01c36eefcb9b6cb91956abc5ee0a99b..dca56076dbcebe7be2bfb9ffcce60a446a06155f 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_customized_pruning.py @@ -66,6 +66,97 @@ class TestASPAddSupportedLayer(unittest.TestCase): my_own_layer_name in supported_layers_and_prune_func_map) +class TestASPDynamicCustomerizedPruneFunc(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + class CustomerLayer(paddle.nn.Layer): + def __init__(self): + super(CustomerLayer, self).__init__() + + self.weight = self.create_parameter( + shape=[32, 32], attr=None, dtype='float32', is_bias=False) + self.linear1 = paddle.nn.Linear(32, 32) + self.linear2 = paddle.nn.Linear(32, 10) + + def forward(self, input_): + hidden = paddle.nn.functional.linear( + x=input_, weight=self.weight) + hidden = self.linear1(hidden) + out = self.linear2(hidden) + return out + + sparsity.add_supported_layer(CustomerLayer, my_own_pruning) + + self.layer = CustomerLayer() + self.customer_prefix = paddle.fluid.dygraph.layers._convert_camel_to_snake( + CustomerLayer.__name__) + self.supported_layer_count_ref = 3 + + def test_inference_pruning(self): + + sparsity.prune_model(self.layer, mask_algo="mask_1d", with_mask=False) + + supported_layer_count = 0 + for param in self.layer.parameters(): + mat = param.numpy() + + if sparsity.asp.ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + supported_layer_count += 1 + if (self.customer_prefix in param.name): + self.assertLessEqual( + np.sum(mat.flatten() - static_tensor.flatten()), 1e-4) + else: + self.assertTrue( + sparsity.check_sparsity( + mat.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertEqual(supported_layer_count, self.supported_layer_count_ref) + + def test_training_pruning(self): + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=self.layer.parameters()) + optimizer = sparsity.decorate(optimizer) + + sparsity.prune_model(self.layer, mask_algo="mask_1d", with_mask=True) + + supported_layer_count = 0 + for param in self.layer.parameters(): + mat = param.numpy() + + if sparsity.asp.ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + + mat_mask = sparsity.asp.ASPHelper._get_program_asp_info( + paddle.static.default_main_program()).mask_vars[ + param.name].numpy() + + supported_layer_count += 1 + if (self.customer_prefix in param.name): + self.assertLessEqual( + np.sum(mat.flatten() - static_tensor.flatten()), 1e-4) + self.assertLessEqual( + np.sum(mat_mask.flatten() - static_tensor_mask.flatten( + )), 1e-4) + else: + self.assertTrue( + sparsity.check_sparsity( + mat.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertTrue( + sparsity.check_sparsity( + mat_mask.T, + func_name=sparsity.CheckMethod.CHECK_1D, + n=2, + m=4)) + self.assertEqual(supported_layer_count, self.supported_layer_count_ref) + + class TestASPStaticCustomerizedPruneFunc(unittest.TestCase): def setUp(self): paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..e127dca22511665548b105d991e9f0d043846043 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_dynamic.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=2, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(1352, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction + + +class TestASPDynamicOptimize(unittest.TestCase): + def setUp(self): + + self.layer = MyLayer() + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + + self.optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=self.layer.parameters()) + + def test_is_supported_layers(self): + program = paddle.static.default_main_program() + + names = [ + 'embedding_0.w_0', 'fack_layer_0.w_0', 'conv2d_0.w_0', + 'conv2d_0.b_0', 'conv2d_1.w_0', 'conv2d_1.b_0', 'fc_0.w_0', + 'fc_0.b_0', 'fc_1.w_0', 'fc_1.b_0', 'linear_2.w_0', 'linear_2.b_0' + ] + ref = [ + False, False, True, False, True, False, True, False, True, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + paddle.incubate.asp.set_excluded_layers(['fc_1', 'conv2d_0']) + ref = [ + False, False, False, False, True, False, True, False, False, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + paddle.incubate.asp.reset_excluded_layers() + ref = [ + False, False, True, False, True, False, True, False, True, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + def test_decorate(self): + param_names = [param.name for param in self.layer.parameters()] + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + + program = paddle.static.default_main_program() + + for name in param_names: + mask_var = ASPHelper._get_program_asp_info(program).mask_vars.get( + name, None) + if ASPHelper._is_supported_layer(program, name): + self.assertTrue(mask_var is not None) + else: + self.assertTrue(mask_var is None) + + def test_asp_training(self): + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + + paddle.incubate.asp.prune_model(self.layer) + + imgs = paddle.to_tensor( + np.random.randn(32, 3, 24, 24), + dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(32, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + + output = self.layer(imgs) + loss = loss_fn(output, labels) + loss.backward() + self.optimizer.step() + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + def test_asp_training_with_amp(self): + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + + paddle.incubate.asp.prune_model(self.layer) + + imgs = paddle.to_tensor( + np.random.randn(32, 3, 24, 24), + dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(32, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + + with paddle.amp.auto_cast(enable=True): + output = self.layer(imgs) + loss = loss_fn(output, labels) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(self.optimizer, scaled) + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_static.py similarity index 89% rename from python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py rename to python/paddle/fluid/tests/unittests/asp/test_asp_optimize_static.py index 9e5e3c924f1a507a9b71ff68027283532824e9fb..b51e28cdcb9fc9ab15e97106a4379c3356991617 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize_static.py @@ -1,5 +1,5 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,21 +20,20 @@ import threading, time import paddle import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.static import sparsity from paddle.fluid.contrib.sparsity.asp import ASPHelper import numpy as np paddle.enable_static() -class TestASPHelper(unittest.TestCase): +class TestASPStaticOptimize(unittest.TestCase): def setUp(self): self.main_program = fluid.Program() self.startup_program = fluid.Program() def build_model(): img = fluid.data( - name='img', shape=[None, 3, 32, 32], dtype='float32') + name='img', shape=[None, 3, 24, 24], dtype='float32') label = fluid.data(name='label', shape=[None, 1], dtype='int64') hidden = fluid.layers.conv2d( input=img, num_filters=4, filter_size=3, padding=2, act="relu") @@ -87,7 +86,7 @@ class TestASPHelper(unittest.TestCase): self.assertTrue( ref[i] == ASPHelper._is_supported_layer(program, name)) - sparsity.set_excluded_layers(program, ['fc_1', 'conv2d_0']) + paddle.incubate.asp.set_excluded_layers(['fc_1', 'conv2d_0'], program) ref = [ False, False, False, False, True, False, True, False, False, False, True, False @@ -96,7 +95,7 @@ class TestASPHelper(unittest.TestCase): self.assertTrue( ref[i] == ASPHelper._is_supported_layer(program, name)) - sparsity.reset_excluded_layers(program) + paddle.incubate.asp.reset_excluded_layers(program) ref = [ False, False, True, False, True, False, True, False, True, False, True, False @@ -109,7 +108,7 @@ class TestASPHelper(unittest.TestCase): param_names = self.__get_param_names(self.main_program.global_block() .all_parameters()) with fluid.program_guard(self.main_program, self.startup_program): - self.optimizer = sparsity.decorate(self.optimizer) + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) self.optimizer.minimize(self.loss, self.startup_program) param_names_after_minimize = self.__get_param_names( self.main_program.global_block().all_parameters()) @@ -119,7 +118,7 @@ class TestASPHelper(unittest.TestCase): def test_asp_training(self): with fluid.program_guard(self.main_program, self.startup_program): - self.optimizer = sparsity.decorate(self.optimizer) + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) self.optimizer.minimize(self.loss, self.startup_program) place = paddle.CPUPlace() @@ -129,10 +128,10 @@ class TestASPHelper(unittest.TestCase): feeder = fluid.DataFeeder(feed_list=[self.img, self.label], place=place) exe.run(self.startup_program) - sparsity.prune_model(self.main_program) + paddle.incubate.asp.prune_model(self.main_program) - data = (np.random.randn(64, 3, 32, 32), np.random.randint( - 10, size=(64, 1))) + data = (np.random.randn(32, 3, 24, 24), np.random.randint( + 10, size=(32, 1))) exe.run(self.main_program, feed=feeder.feed([data])) for param in self.main_program.global_block().all_parameters(): @@ -149,7 +148,7 @@ class TestASPHelper(unittest.TestCase): with fluid.program_guard(self.main_program, self.startup_program): self.optimizer = fluid.contrib.mixed_precision.decorator.decorate( self.optimizer) - self.optimizer = sparsity.decorate(self.optimizer) + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) self.optimizer.minimize(self.loss, self.startup_program) exe = fluid.Executor(place) @@ -157,10 +156,10 @@ class TestASPHelper(unittest.TestCase): feed_list=[self.img, self.label], place=place) exe.run(self.startup_program) - sparsity.prune_model(self.main_program) + paddle.incubate.asp.prune_model(self.main_program) - data = (np.random.randn(64, 3, 32, 32), np.random.randint( - 10, size=(64, 1))) + data = (np.random.randn(32, 3, 24, 24), np.random.randint( + 10, size=(32, 1))) exe.run(self.main_program, feed=feeder.feed([data])) for param in self.main_program.global_block().all_parameters(): diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py deleted file mode 100644 index e99509187038c7c6f7737fc109633ada3d3f0da0..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import paddle -import unittest -from paddle.static import sparsity -from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase - -paddle.enable_static() - - -class TestASPHelperPruning2DBest(TestASPHelperPruningBase): - def test_2D_best_inference_pruning(self): - self.run_inference_pruning_test( - 'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - def test_2D_best_training_pruning(self): - self.run_training_pruning_test( - 'mask_2d_best', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py deleted file mode 100644 index 7ad6c3ae0227581006858e56cf7930e5775b353d..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import paddle -from paddle.static import sparsity -from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase - -paddle.enable_static() - - -class TestASPHelperPruning2DGreedy(TestASPHelperPruningBase): - def test_2D_greedy_inference_pruning(self): - self.run_inference_pruning_test( - 'mask_2d_greedy', - paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - def test_2D_greedy_training_pruning(self): - self.run_training_pruning_test( - 'mask_2d_greedy', - paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..b0fad0b64002a0bf9d74158f7b8879734d5b08d8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_dynamic.py @@ -0,0 +1,107 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.fluid import core +from paddle.fluid.contrib.sparsity.asp import ASPHelper + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=2, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(1352, 32) + self.linear2 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + prediction = self.linear2(hidden) + return prediction + + +class TestASPDynamicPruningBase(unittest.TestCase): + def setUp(self): + self.layer = MyLayer() + + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + + self.img = paddle.to_tensor( + np.random.uniform( + low=-0.5, high=0.5, size=(32, 3, 24, 24)), + dtype=np.float32, + place=place, + stop_gradient=False) + + self.set_config() + + def set_config(self): + self.mask_gen_func = 'mask_1d' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D + + def test_inference_pruning(self): + self.__pruning_and_checking(False) + + def test_training_pruning(self): + + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=self.layer.parameters()) + optimizer = paddle.incubate.asp.decorate(optimizer) + + self.__pruning_and_checking(True) + + def __pruning_and_checking(self, with_mask): + + paddle.incubate.asp.prune_model( + self.layer, mask_algo=self.mask_gen_func, with_mask=with_mask) + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, func_name=self.mask_check_func, n=2, m=4)) + + +class TestASPDynamicPruning1D(TestASPDynamicPruningBase): + def set_config(self): + self.mask_gen_func = 'mask_1d' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D + + +class TestASPDynamicPruning2DBest(TestASPDynamicPruningBase): + def set_config(self): + self.mask_gen_func = 'mask_2d_best' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D + + +class TestASPDynamicPruning2DGreedy(TestASPDynamicPruningBase): + def set_config(self): + self.mask_gen_func = 'mask_2d_greedy' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_static.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_static.py new file mode 100644 index 0000000000000000000000000000000000000000..a9986f24b0265fdd83bceb2b09d712f6ed731af1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_static.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import threading, time +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np + +paddle.enable_static() + + +class TestASPStaticPruningBase(unittest.TestCase): + def setUp(self): + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + + def build_model(): + img = fluid.data( + name='img', shape=[None, 3, 24, 24], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + hidden = fluid.layers.conv2d( + input=img, num_filters=2, filter_size=3, padding=2, act="relu") + hidden = fluid.layers.fc(input=hidden, size=32, act='softmax') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + return img, label, prediction + + with fluid.program_guard(self.main_program, self.startup_program): + self.img, self.label, self.predict = build_model() + + self.set_config() + + def set_config(self): + self.mask_gen_func = 'mask_1d' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D + + def test_inference_pruning(self): + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + + self.__pruning_and_checking(exe, place, False) + + def test_training_pruning(self): + with fluid.program_guard(self.main_program, self.startup_program): + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=self.predict, label=self.label)) + optimizer = paddle.incubate.asp.decorate( + fluid.optimizer.SGD(learning_rate=0.01)) + optimizer.minimize(loss, self.startup_program) + + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + + self.__pruning_and_checking(exe, place, True) + + def __pruning_and_checking(self, exe, place, with_mask): + exe.run(self.startup_program) + paddle.incubate.asp.prune_model( + self.main_program, + mask_algo=self.mask_gen_func, + with_mask=with_mask) + for param in self.main_program.global_block().all_parameters(): + if ASPHelper._is_supported_layer(self.main_program, param.name): + mat = np.array(fluid.global_scope().find_var(param.name) + .get_tensor()) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, func_name=self.mask_check_func, n=2, m=4)) + + +class TestASPStaticPruning1D(TestASPStaticPruningBase): + def set_config(self): + self.mask_gen_func = 'mask_1d' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D + + +class TestASPStaticPruning2DBest(TestASPStaticPruningBase): + def set_config(self): + self.mask_gen_func = 'mask_2d_best' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D + + +class TestASPStaticPruning2DGreedy(TestASPStaticPruningBase): + def set_config(self): + self.mask_gen_func = 'mask_2d_greedy' + self.mask_check_func = paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py b/python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py new file mode 100644 index 0000000000000000000000000000000000000000..653cbbf84091b1e9ced8efc5a85b16eee8c01f4c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_save_load.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.conv1 = paddle.nn.Conv2D( + in_channels=3, out_channels=4, kernel_size=3, padding=2) + self.linear1 = paddle.nn.Linear(4624, 32) + self.linear2 = paddle.nn.Linear(32, 32) + self.linear3 = paddle.nn.Linear(32, 10) + + def forward(self, img): + hidden = self.conv1(img) + hidden = paddle.flatten(hidden, start_axis=1) + hidden = self.linear1(hidden) + hidden = self.linear2(hidden) + prediction = self.linear3(hidden) + return prediction + + +class TestASPDynamicOptimize(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + self.layer = MyLayer() + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + + self.optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=self.layer.parameters()) + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + paddle.incubate.asp.prune_model(self.layer) + + def test_save_and_load(self): + path = "/tmp/paddle_asp_save_dy/" + net_path = path + "asp_net.pdparams" + opt_path = path + "asp_opt.pdopt" + + paddle.save(self.layer.state_dict(), net_path) + paddle.save(self.optimizer.state_dict(), opt_path) + + asp_info = ASPHelper._get_program_asp_info( + paddle.static.default_main_program()) + for param_name in asp_info.mask_vars: + mask = asp_info.mask_vars[param_name] + asp_info.update_mask_vars( + param_name, paddle.ones( + shape=mask.shape, dtype=mask.dtype)) + asp_info.update_masks(param_name, np.ones(shape=mask.shape)) + + net_state_dict = paddle.load(net_path) + opt_state_dict = paddle.load(opt_path) + + self.layer.set_state_dict(net_state_dict) + self.optimizer.set_state_dict(opt_state_dict) + + imgs = paddle.to_tensor( + np.random.randn(64, 3, 32, 32), + dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(64, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + + output = self.layer(imgs) + loss = loss_fn(output, labels) + loss.backward() + self.optimizer.step() + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +class TestASPStaticOptimize(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + + def build_model(): + img = fluid.data( + name='img', shape=[None, 3, 32, 32], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + hidden = fluid.layers.conv2d( + input=img, num_filters=4, filter_size=3, padding=2, act="relu") + hidden = fluid.layers.fc(input=hidden, size=32, act='relu') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + return img, label, prediction + + with fluid.program_guard(self.main_program, self.startup_program): + self.img, self.label, predict = build_model() + self.loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict, label=self.label)) + self.optimizer = fluid.optimizer.SGD(learning_rate=0.01) + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + self.optimizer.minimize(self.loss, self.startup_program) + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + self.exe = fluid.Executor(self.place) + self.exe.run(self.startup_program) + + paddle.incubate.asp.prune_model(self.main_program) + + def test_save_and_load(self): + path = "/tmp/paddle_asp_save_st/" + param_path = path + "asp.pdparams" + model_path = path + "asp.pdmodel" + + paddle.save(self.main_program.state_dict(), param_path) + paddle.save(self.main_program, model_path) + + prog = paddle.load(model_path) + + state_dict = paddle.load(param_path) + prog.set_state_dict(state_dict) + + feeder = fluid.DataFeeder( + feed_list=[self.img, self.label], place=self.place) + + data = (np.random.randn(64, 3, 32, 32), np.random.randint( + 10, size=(64, 1))) + self.exe.run(prog, feed=feeder.feed([data])) + + for param in prog.global_block().all_parameters(): + if ASPHelper._is_supported_layer(prog, param.name): + mat = np.array(fluid.global_scope().find_var(param.name) + .get_tensor()) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py index 4aac878763b6f6f7cab09ca5cdc3cfeab0f49d6d..67ec54367d38271c924f51ec9a981bb1bc35341d 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py @@ -18,7 +18,6 @@ from __future__ import print_function import unittest import threading, time import paddle -from paddle.static import sparsity import numpy as np @@ -41,9 +40,9 @@ class TestASPUtils(unittest.TestCase): x = np.array([[1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]]) - self.assertEqual(sparsity.calculate_density(x), 0.56) + self.assertEqual(paddle.incubate.asp.calculate_density(x), 0.56) x[:, 0] = 0.0 - self.assertEqual(sparsity.calculate_density(x), 0.4) + self.assertEqual(paddle.incubate.asp.calculate_density(x), 0.4) def test_check_mask_1d(self): x = np.array([[1.0, 0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], @@ -219,3 +218,7 @@ class TestASPUtils(unittest.TestCase): func_name=paddle.fluid.contrib.sparsity.CheckMethod.CHECK_2D, n=2, m=4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py deleted file mode 100644 index 074aedb947613cff93d407344fdc8a508c665d04..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.distributed.fleet as fleet -import paddle.distributed.fleet.base.role_maker as role_maker -import unittest -import paddle -import paddle.fluid as fluid -import paddle.fluid.core as core -import os -from paddle.static import sparsity -from paddle.fluid.contrib.sparsity.asp import ASPHelper -import numpy as np -cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') -if cuda_visible_devices is None or cuda_visible_devices == "": - os.environ['CUDA_VISIBLE_DEVICES'] = '0' -else: - os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0] - -paddle.enable_static() - - -class TestFleetWithASP(unittest.TestCase): - def setUp(self): - os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213" - os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213" - os.environ["PADDLE_TRAINERS_NUM"] = "1" - os.environ["PADDLE_TRAINER_ID"] = "0" - - def net(self, main_prog, startup_prog): - with fluid.program_guard(main_prog, startup_prog): - input_x = paddle.static.data( - name="x", shape=[-1, 32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64') - - fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh') - prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') - cost = fluid.layers.cross_entropy(input=prediction, label=input_y) - avg_cost = paddle.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.asp = True - return avg_cost, strategy, input_x, input_y - - def test_with_asp(self): - fleet.init(is_collective=True) - train_prog, startup_prog = fluid.Program(), fluid.Program() - avg_cost, strategy, input_x, input_y = self.net(train_prog, - startup_prog) - - with fluid.program_guard(train_prog, startup_prog): - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer( - optimizer, strategy=strategy) - optimizer.minimize(avg_cost) - - place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda( - ) else fluid.CPUPlace() - - exe = fluid.Executor(place) - feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place) - exe.run(startup_prog) - - sparsity.prune_model(train_prog) - - data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1))) - exe.run(train_prog, feed=feeder.feed([data])) - - for param in train_prog.global_block().all_parameters(): - if ASPHelper._is_supported_layer(train_prog, param.name): - mat = np.array(fluid.global_scope().find_var(param.name) - .get_tensor()) - self.assertTrue( - paddle.fluid.contrib.sparsity.check_sparsity( - mat.T, n=2, m=4)) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..3ced15bf15881f1d927a59817ef8be5428a5be84 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_dynamic.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker +import unittest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import os +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np +cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') +if cuda_visible_devices is None or cuda_visible_devices == "": + os.environ['CUDA_VISIBLE_DEVICES'] = '0' +else: + os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0] + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super(MyLayer, self).__init__() + self.linear1 = paddle.nn.Linear(32, 32) + self.linear2 = paddle.nn.Linear(32, 10) + + def forward(self, x): + hidden = self.linear1(x) + prediction = self.linear2(hidden) + return prediction + + +class TestFleetWithASPDynamic(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213" + os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213" + os.environ["PADDLE_TRAINERS_NUM"] = "1" + os.environ["PADDLE_TRAINER_ID"] = "0" + + self.layer = MyLayer() + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + + self.optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=self.layer.parameters()) + + def test_with_asp(self): + fleet.init(is_collective=True) + + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + paddle.incubate.asp.prune_model(self.layer) + + self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.layer = fleet.distributed_model(self.layer) + + imgs = paddle.to_tensor( + np.random.randn(64, 32), + dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(64, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + + output = self.layer(imgs) + loss = loss_fn(output, labels) + loss.backward() + self.optimizer.step() + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +class TestFleetWithASPAMPDynamic(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213" + os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213" + os.environ["PADDLE_TRAINERS_NUM"] = "1" + os.environ["PADDLE_TRAINER_ID"] = "0" + + self.layer = MyLayer() + + self.place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + self.place = paddle.CUDAPlace(0) + + self.optimizer = paddle.optimizer.SGD( + learning_rate=0.01, parameters=self.layer.parameters()) + + def test_with_asp(self): + fleet.init(is_collective=True) + + self.optimizer = paddle.incubate.asp.decorate(self.optimizer) + paddle.incubate.asp.prune_model(self.layer) + + self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.layer = fleet.distributed_model(self.layer) + + imgs = paddle.to_tensor( + np.random.randn(64, 32), + dtype='float32', + place=self.place, + stop_gradient=False) + labels = paddle.to_tensor( + np.random.randint( + 10, size=(64, 1)), + dtype='float32', + place=self.place, + stop_gradient=False) + + loss_fn = paddle.nn.MSELoss(reduction='mean') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + + with paddle.amp.auto_cast(enable=True): + output = self.layer(imgs) + loss = loss_fn(output, labels) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(self.optimizer, scaled) + self.optimizer.clear_grad() + + for param in self.layer.parameters(): + if ASPHelper._is_supported_layer( + paddle.static.default_main_program(), param.name): + mat = param.numpy() + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_static.py similarity index 67% rename from python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py rename to python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_static.py index a34d7e69872e2178182bdfd5b871915bdd5d60dd..2023c0051401fbbed0e604dab278f07ed1c90b54 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_amp.py +++ b/python/paddle/fluid/tests/unittests/asp/test_fleet_with_asp_static.py @@ -1,5 +1,5 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,7 +32,62 @@ else: paddle.enable_static() -class TestFleetWithASP(unittest.TestCase): +class TestFleetWithASPStatic(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213" + os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213" + os.environ["PADDLE_TRAINERS_NUM"] = "1" + os.environ["PADDLE_TRAINER_ID"] = "0" + + def net(self, main_prog, startup_prog): + with fluid.program_guard(main_prog, startup_prog): + input_x = paddle.static.data( + name="x", shape=[-1, 32], dtype='float32') + input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64') + + fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh') + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=input_y) + avg_cost = paddle.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.asp = True + return avg_cost, strategy, input_x, input_y + + def test_with_asp(self): + fleet.init(is_collective=True) + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy, input_x, input_y = self.net(train_prog, + startup_prog) + + with fluid.program_guard(train_prog, startup_prog): + optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer( + optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + place = fluid.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place) + exe.run(startup_prog) + + sparsity.prune_model(train_prog) + + data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1))) + exe.run(train_prog, feed=feeder.feed([data])) + + for param in train_prog.global_block().all_parameters(): + if ASPHelper._is_supported_layer(train_prog, param.name): + mat = np.array(fluid.global_scope().find_var(param.name) + .get_tensor()) + self.assertTrue( + paddle.fluid.contrib.sparsity.check_sparsity( + mat.T, n=2, m=4)) + + +class TestFleetWithASPAMPStatic(unittest.TestCase): def setUp(self): os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213" os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213" diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 0ae005430e03b046d609c393fcc0641a0d3db49e..28e03fdfd70e1e8fe18a083de49f51de2f76c3c8 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index ff7a167f1a670a1c7e64ef0ac0653092dfa20966..c354baf3b43b77218aa4a6988d0fba3163f18746 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -32,6 +32,7 @@ import paddle.incubate.autograd import paddle.incubate.autotune from . import nn #noqa: F401 +from . import asp #noqa: F401 __all__ = [ 'LookAhead', diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py b/python/paddle/incubate/asp/__init__.py similarity index 51% rename from python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py rename to python/paddle/incubate/asp/__init__.py index 7a3fa0244930c320b0dd8557f872410dbbb6ed65..59f794ef28aa41b61da84ee8edbe2f8ea582d893 100644 --- a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py +++ b/python/paddle/incubate/asp/__init__.py @@ -13,25 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest -import paddle -from paddle.static import sparsity -from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase - -paddle.enable_static() - - -class TestASPHelperPruning1D(TestASPHelperPruningBase): - def test_1D_inference_pruning(self): - self.run_inference_pruning_test( - 'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) - - def test_1D_training_pruning(self): - self.run_training_pruning_test( - 'mask_1d', paddle.fluid.contrib.sparsity.CheckMethod.CHECK_1D) - - -if __name__ == '__main__': - unittest.main() +from ...fluid.contrib.sparsity import calculate_density #noqa: F401 +from ...fluid.contrib.sparsity import decorate #noqa: F401 +from ...fluid.contrib.sparsity import prune_model #noqa: F401 +from ...fluid.contrib.sparsity import set_excluded_layers #noqa: F401 +from ...fluid.contrib.sparsity import reset_excluded_layers #noqa: F401 + +__all__ = [ #noqa + 'calculate_density', + 'decorate', + 'prune_model', + 'set_excluded_layers', + 'reset_excluded_layers' +] diff --git a/python/paddle/static/sparsity/__init__.py b/python/paddle/static/sparsity/__init__.py index 59f794ef28aa41b61da84ee8edbe2f8ea582d893..b4543b8d000fc10d4a866a16916a9e905a9eae26 100644 --- a/python/paddle/static/sparsity/__init__.py +++ b/python/paddle/static/sparsity/__init__.py @@ -16,8 +16,14 @@ from ...fluid.contrib.sparsity import calculate_density #noqa: F401 from ...fluid.contrib.sparsity import decorate #noqa: F401 from ...fluid.contrib.sparsity import prune_model #noqa: F401 -from ...fluid.contrib.sparsity import set_excluded_layers #noqa: F401 from ...fluid.contrib.sparsity import reset_excluded_layers #noqa: F401 +from ...fluid.contrib import sparsity #noqa: F401 + + +def set_excluded_layers(main_program, param_names): + sparsity.set_excluded_layers( + param_names=param_names, main_program=main_program) + __all__ = [ #noqa 'calculate_density', diff --git a/python/setup.py.in b/python/setup.py.in index c1a6e3d3947a9c64115a8af8ae2685a6df176df1..2a0d745729aabab2f8f37856f228854c8ce61b9f 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -281,6 +281,7 @@ packages=['paddle', 'paddle.incubate.tensor', 'paddle.incubate.multiprocessing', 'paddle.incubate.nn', + 'paddle.incubate.asp', 'paddle.incubate.passes', 'paddle.distribution', 'paddle.distributed.sharding',