diff --git a/python/paddle/fluid/contrib/sparsity/__init__.py b/python/paddle/fluid/contrib/sparsity/__init__.py index f78ea1b1c38b85b04ab0e09757ec4d6eea5eaf4d..b36a79b8ca865e4b982ef0315f023525045fc069 100644 --- a/python/paddle/fluid/contrib/sparsity/__init__.py +++ b/python/paddle/fluid/contrib/sparsity/__init__.py @@ -15,7 +15,22 @@ from __future__ import print_function -from . import utils -from .utils import * +from .utils import calculate_density +from .utils import check_mask_1d +from .utils import get_mask_1d +from .utils import check_mask_2d +from .utils import get_mask_2d_greedy +from .utils import get_mask_2d_best +from .utils import create_mask +from .utils import check_sparsity +from .utils import MaskAlgo +from .utils import CheckMethod +from .asp import decorate, prune_model +from .asp import set_excluded_layers, reset_excluded_layers -__all__ = utils.__all__ +__all__ = [ + 'calculate_density', 'check_mask_1d', 'get_mask_1d', 'check_mask_2d', + 'get_mask_2d_greedy', 'get_mask_2d_best', 'create_mask', 'check_sparsity', + 'MaskAlgo', 'CheckMethod', 'decorate', 'prune_model', 'set_excluded_layers', + 'reset_excluded_layers' +] diff --git a/python/paddle/fluid/contrib/sparsity/asp.py b/python/paddle/fluid/contrib/sparsity/asp.py new file mode 100644 index 0000000000000000000000000000000000000000..fbabc73f37bce5ca42c292572ee4082e1706bea2 --- /dev/null +++ b/python/paddle/fluid/contrib/sparsity/asp.py @@ -0,0 +1,497 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Functions for Auto SParsity (ASP) training and inference. +""" + +import copy +import numpy as np +import paddle +from paddle.fluid import framework, global_scope, program_guard, layers +from paddle.fluid.initializer import ConstantInitializer +from paddle.fluid.contrib import sparsity +from paddle.fluid import core + +__all__ = [ + 'decorate', 'prune_model', 'set_excluded_layers', 'reset_excluded_layers' +] + + +def set_excluded_layers(main_program, param_names): + r""" + Set parameter name of layers which would not be pruned as sparse weights. + + Args: + main_program (Program, optional): Program with model definition and its parameters. + param_names (list): A list contains names of parameters. + """ + ASPHelper.set_excluded_layers( + main_program=main_program, param_names=param_names) + + +def reset_excluded_layers(main_program=None): + r""" + Reset exculded layers setting corresponding to :attr:`main_program`. If :attr:`main_program` + is None, then all configurations of excluded_layers would be cleaned. + + Args: + main_program (Program, optional): Program with model definition and its parameters. + """ + ASPHelper.reset_excluded_layers(main_program=main_program) + + +def decorate(optimizer): + r""" + Wrap the given optimizer as a OptimizerWithSparsityGuarantee, + which would insert necessary ops for ASP workflows when calling minimize() + + Args: + optimizer (Optimizer): A Optimizer used for training. + Returns: + OptimizerWithSparsityGuarantee: A wrapper for ASP to decorate `minimize` function of the given optimizer. + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.contrib import sparsity + + main_program = fluid.Program() + startup_program = fluid.Program() + + with fluid.program_guard(main_program, startup_program): + input_data = fluid.layers.data(name='data', shape=[None, 128]) + label = fluid.layers.data(name='label', shape=[None, 10]) + hidden = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None) + prob = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=10, act=None) + loss = fluid.layers.mean(fluid.layers.square_error_cost(prob, label)) + + optimizer = fluid.optimizer.SGD(learning_rate=0.1) + + optimizer = sparsity.decorate(optimizer) + optimizer.minimize(loss, startup_program) + + # When apply distributed training with Fleet + import paddle.distributed.fleet as fleet + + optimizer = fluid.optimizer.SGD(learning_rate=0.1) + optimizer = sparsity.decorate(optimizer) # Need to be called before `fleet.distributed_optimizer` + optimizer = fleet.distributed_optimizer(optimizer) + optimizer.minimize(loss, startup_program) + """ + return ASPHelper.decorate(optimizer) + + +def prune_model(place, + main_program=None, + n=2, + m=4, + func_name=sparsity.MaskAlgo.MASK_1D, + with_mask=True): + r""" + Pruning parameters of supported layers in :attr:`main_program` via + specified mask generation function given by :attr:`func_name`. This + function supports both training and inference controlled by :attr:`with_mask`. + If :attr:`with_mask` is True, it would also prune parameter related ASP mask Variables, + else only prunes parameters. + + *Note*: If parameters are supported and in FP16, please set :attr:`n`=2, :attr:`m`=4, + if they in FP32, then :attr:`n`=1, :attr:`m`=2` to further enable Sparse Tensor Core acceleration. + + *Note*: If calling this function with :attr:`with_mask`, it should call `OptimizerWithSparsityGuarantee.minimize` + and initialization (`exe.run(startup_program`)) before (For successfully obtain mask Variable). + Typically set `with_mask` as true for training (have called `OptimizerWithSparsityGuarantee.minimize`) and false for + inference only. To obtain OptimizerWithSparsityGuarantee, please see `sparsity.decoreate()`. + + Args: + place (fluid.CPUPlace()|fluid.CUDAPlace(N)): Device place for pruned parameter and mask Variables, and N means the GPU's id. It should be the same as created instance of Executor. + main_program (Program, optional): Program with model definition and its parameters. Default is `paddle.static.default_main_program() + n (int): n of `n:m` sparse pattern. + m (int): m of `n:m` sparse pattern. + func_name (MaskAlgo, optional): The function name to generate spase mask. Default is `MaskAlgo.MASK_1D`. All options please refer to `MaskAlgo`. + with_mask (bool, optional): To prune mask Variables related to parameters or not. Ture is purning also, False is not. Defalut is True. + Returns: + dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable. + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.contrib import sparsity + + main_program = fluid.Program() + startup_program = fluid.Program() + + place = fluid.CUDAPlace(0) + + with fluid.program_guard(main_program, startup_program): + input_data = fluid.layers.data(name='data', shape=[None, 128]) + label = fluid.layers.data(name='label', shape=[None, 10]) + hidden = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None) + prob = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=10, act=None) + loss = fluid.layers.mean(fluid.layers.square_error_cost(prob, label)) + + optimizer = decorate(fluid.optimizer.SGD(learning_rate=0.1)) + optimizer.minimize(optimizer, loss, main_program, startup_program) + + exe = fluid.Executor(place) + exe.run(startup_program) + + # Must call `exe.run(startup_program)` first before calling `sparsity.prune_model` + sparsity.prune_model(place, main_program, func_name=sparsity.MaskAlgo.MASK_2D_BEST) + """ + return ASPHelper.prune_model( + place=place, + main_program=main_program, + n=n, + m=m, + func_name=func_name, + with_mask=with_mask) + + +class ProgramASPInfo(object): + r""" + ProgramASPInfo is a container to keep ASP relevant information of Pragrom. It contains three inner-variables: + 1. __mask_vars (Dictionary): Key is parameter's name and vaule is its corresponding sparse mask Variable object, which is created by `ASPHelper.create_mask_variables`. + 2. __masks (Dictionary): Key is parameter's name and vaule is its corressponding sparse mask Numpy array, which is created by `ASPHelper.prune_model`. + 3. __excluded_layers (List): It stores name of layers which should not involve into ASP workflow. + """ + + def __init__(self): + self.__mask_vars = {} + self.__masks = {} + self.__excluded_layers = [] + + def update_mask_vars(self, param_name, var): + self.__mask_vars[param_name] = var + + def update_masks(self, param_name, var): + self.__masks[param_name] = var + + def update_excluded_layers(self, param_names): + self.__excluded_layers.extend(copy.deepcopy(param_names)) + + def reset_excluded_layers(self): + self.__excluded_layers = [] + + @property + def mask_vars(self): + return self.__mask_vars + + @property + def masks(self): + return self.__masks + + @property + def excluded_layers(self): + return self.__excluded_layers + + +class ASPHelper(object): + r""" + ASPHelper is a collection of Auto SParsity (ASP) functions to enable + + 1. training models with weights in 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 from scratch. + 2. pruning well-trained models into 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 for fine-tuning. + """ + + MASK_APPENDDED_NAME = '_asp_mask' + SUPPORTED_LAYERS = {'fc': 'w_0', 'linear': 'w_0', 'conv2d': 'w_0'} + + __asp_info = {} + + @classmethod + def set_excluded_layers(cls, main_program, param_names): + r""" + This is the implementation of `sparsity.set_excluded_layers`, for details please see explanation in `sparsity.set_excluded_layers`. + """ + asp_info = cls._get_program_asp_info(main_program) + asp_info.update_excluded_layers(param_names) + + @classmethod + def reset_excluded_layers(cls, main_program=None): + r""" + This is the implementation of `sparsity.reset_excluded_layers`, for details please see explanation in `sparsity.reset_excluded_layers`. + """ + if main_program is None: + for asp_info in cls.__asp_info: + asp_info.reset_excluded_layers() + else: + cls._get_program_asp_info(main_program).reset_excluded_layers() + + @staticmethod + def decorate(optimizer): + r""" + This is the implementation of `sparsity.decorate`, for details please see explanation in `sparsity.decorate`. + """ + return OptimizerWithSparsityGuarantee(optimizer) + + @classmethod + def prune_model(cls, + place, + main_program=None, + n=2, + m=4, + func_name=sparsity.MaskAlgo.MASK_1D, + with_mask=True): + r""" + This is the implementation of `sparsity.prune_model`, for details please see explanation in `sparsity.prune_model`. + """ + checked_func_name = sparsity.CheckMethod.get_checking_method(func_name) + + if main_program is None: + main_program = paddle.static.default_main_program() + + asp_info = cls._get_program_asp_info(main_program) + for param in main_program.global_block().all_parameters(): + if ASPHelper._is_supported_layer(main_program, param.name): + weight_tensor = global_scope().find_var(param.name).get_tensor() + weight_nparray = np.array(weight_tensor) + + # The double transpose ops here make sure pruning direction consistent with cuSparseLt. + # SPMMA in cuSparseLt: D = (AxB) + C, where matrix A (mxk) is sparse matrix. + # cuSparseLt would prune matrix A along k dimension. + # In sparse training, layer weight matriices is viewed sparse matrix A, so + # the math fomula should be 'Act(WX + b)'. However, default fomula in PaddlePaddle + # is 'Act(XW + b)'. For enabling SPMMA, weights and inputs should be transposed + # for computing, Act( (W^T X^T)^T + b). Therefore, we have to prune alog k dimension + # of W^T, which is m dimension of W. Moreove, all mask generating functions in + # sparsity/utils is row-major pruning. That is the reason we have to transpose weight + # matrices beforce invoking create_mask. Then we transpose the result maks to make + # sure its shape to be the same as the input weight. + weight_sparse_mask = sparsity.create_mask( + weight_nparray.T, func_name=func_name, n=n, m=m).T + weight_pruned_nparray = np.multiply(weight_nparray, + weight_sparse_mask) + weight_tensor.set(weight_pruned_nparray, place) + assert sparsity.check_sparsity(weight_pruned_nparray.T, n=n, m=m, func_name=checked_func_name), \ + 'Pruning {} weight matrix failure!!!'.format(param.name) + if with_mask: + weight_mask_param = global_scope().find_var( + ASPHelper._get_mask_name(param.name)) + assert weight_mask_param is not None, \ + 'Cannot find {} variable, please call ASPHelper.minimize' \ + ' and initialization (exe.run(startup_program)) first!'.format(ASPHelper._get_mask_name(param.name)) + weight_mask_tensor = weight_mask_param.get_tensor() + weight_mask_tensor.set(weight_sparse_mask, place) + asp_info.update_masks(param.name, weight_sparse_mask) + return asp_info.masks.copy() + + @staticmethod + def _get_mask_name(param_name): + r""" + Return mask name by given parameter name :attr:`param_name`. + + Args: + param_name (string): The name of parameter. + Returns: + string: The mask name of :attr:`param_name`. + """ + return param_name + ASPHelper.MASK_APPENDDED_NAME + + @staticmethod + def _get_not_ASP_relevant_vars(main_program): + r""" + Get all parameters's Variables in :attr:`main_program` but excluded ASP mask Variables. + + Args: + main_program (Program): Program with model definition and its parameters. + Returns: + list: A list of parameter Variables in :attr:`main_program` (excluded ASP mask Variables). + """ + var_list = [] + for param in main_program.global_block().all_parameters(): + if ASPHelper.MASK_APPENDDED_NAME not in param.name: + var_list.append(param) + return var_list + + @classmethod + def _get_program_asp_info(cls, main_program): + if not main_program in cls.__asp_info: + cls.__asp_info[main_program] = ProgramASPInfo() + return cls.__asp_info[main_program] + + @classmethod + def _is_supported_layer(cls, main_program, param_name): + r""" + Verify if given :attr:`param_name` is supported by ASP. + + Args: + param_name (string): The name of parameter. + Returns: + bool: True if it is supported, else False. + Examples: + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.contrib.sparsity.asp import ASPHelper + + main_program = fluid.Program() + startup_program = fluid.Program() + + with fluid.program_guard(main_program, startup_program): + input_data = fluid.layers.data(name='data', shape=[None, 128]) + fc = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None) + + for param in main_program.global_block().all_parameters(): + ASPHelper._is_supported_layer(main_program, param.name) + # fc_0.w_0 -> True + # fc_0.b_0 -> False + """ + if ASPHelper.MASK_APPENDDED_NAME in param_name: + return False + + for layer in cls._get_program_asp_info(main_program).excluded_layers: + if layer in param_name: + return False + + for name in ASPHelper.SUPPORTED_LAYERS: + if name in param_name and \ + ASPHelper.SUPPORTED_LAYERS[name] in param_name: + return True + return False + + @classmethod + def _minimize(cls, + optimizer, + loss, + main_program=None, + startup_program=None, + parameter_list=None, + no_grad_set=None): + r""" + This function is a decorator of `minimize` function in `Optimizer`. + There are three steps: + + 1. Call :attr:`optimizer`.minimize(:attr:`loss`) + 2. Create sparse mask Tensors according to supported layers in :attr:`main_program`. + 3. Insert masking ops in the end of parameters update. + + *Note*: Please use `ASP.decorate` instead when applying distributed training with `Fleet`. + (Due to there is a invisiable graphs optimization in `Fleet.minimize()` which make training graph + cannot be modified anymore.) + + Args: + optimizer (Optimizer): A Optimizer used for training. + loss (Variable): A Variable containing the value to minimize. + main_program (Program, optional): Program with model definition and its parameters. Default is `loss.block.program`. + startup_program (Program, optional): Program for initializing parameters in `parameter_list`. Default is `paddle.static.default_startup_program()`. + parameter_list (Iterable, optional): Iterable of `Variable` or `Variable.name` to update to minimize `loss`. The default value is None, at this time all parameters will be updated. + no_grad_set (set, optional): Set of `Variable or `Variable.name` that don't need to be updated. The default value is None. + Returns: + list: operators from :attr:`optimizer`.minimize(:attr:`loss`). + list: pairs of parameters and their gradients. + """ + if main_program is None: + main_program = loss.block.program + + if startup_program is None: + startup_program = paddle.static.default_startup_program() + + optimizer_ops, params_and_grads = optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set=no_grad_set) + cls._create_mask_variables(main_program, startup_program, + params_and_grads) + cls._insert_sparse_mask_ops(main_program, params_and_grads) + return optimizer_ops, params_and_grads + + @classmethod + def _create_mask_variables(cls, main_program, startup_program, + params_and_grads): + r""" + Create sparse mask Tensors according to supported layers in :attr:`main_program`. + This function is called in second step of `ASPHelper._minimize` + + Args: + main_program (Program): Program with model definition and its parameters. + startup_program (Program): Program for initializing parameters. + params_and_grads (list): Variable pairs of parameters and their gradients. + """ + asp_info = cls._get_program_asp_info(main_program) + with program_guard(main_program, startup_program): + for param_and_grad in params_and_grads: + if ASPHelper._is_supported_layer(main_program, + param_and_grad[0].name): + mask_param = layers.create_parameter( + name=param_and_grad[0].name + + ASPHelper.MASK_APPENDDED_NAME, + shape=param_and_grad[0].shape, + dtype=param_and_grad[0].dtype, + default_initializer=ConstantInitializer(value=1.0)) + mask_param.stop_gradient = True + mask_param.trainable = False + asp_info.update_mask_vars(param_and_grad[0].name, + mask_param) + + @classmethod + def _insert_sparse_mask_ops(cls, main_program, param_grads): + r""" + Insert masking ops in the end of parameters update. + This function is called in third step of `ASPHelper._minimize` + + Args: + main_program (Program): Program with model definition and its parameters. + params_and_grads (list): Variable pairs of parameters and their gradients. + """ + block = main_program.global_block() + asp_info = cls._get_program_asp_info(main_program) + for param_grad in param_grads: + if param_grad[0].name in asp_info.mask_vars: + block.append_op( + type='elementwise_mul', + inputs={ + "X": param_grad[0], + 'Y': asp_info.mask_vars[param_grad[0].name] + }, + outputs={'Out': param_grad[0]}, + attrs={'axis': -1, + 'use_mkldnn': False}) + + +class OptimizerWithSparsityGuarantee(object): + r""" + OptimizerWithSparsityGuarantee is a wrapper to decorate `minimize` function of given optimizer by `_minimize` of ASPHelper. + The decorated `minimize` function would do three things (exactly same as `ASPHelper._minimize`): + 1. Call `minimize` function of given optimizer. + 2. Call `ASPHelper._create_mask_variables` to create mask Variables. + 3. Call `ASPHelper._insert_sparse_mask_ops` to insert weight masking ops in the end of `loss`'s Program. + """ + + def __init__(self, optimizer): + self._optimizer = optimizer + self._learning_rate = optimizer._learning_rate + self._learning_rate_map = optimizer._learning_rate_map + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + r""" + This function is to call `ASPHelper.minimize()` and return its return + + Args: + loss (Variable): A Variable containing the value to minimize. + startup_program (Program, optional): Program for initializing parameters in `parameter_list`. Default is `paddle.static.default_startup_program()`. + parameter_list (Iterable, optional): Iterable of `Variable` or `Variable.name` to update to minimize `loss`. The default value is None, at this time all parameters will be updated. + no_grad_set (set, optional): Set of `Variable or `Variable.name` that don't need to be updated. The default value is None. + Returns: + list: operators from :attr:`optimizer`.minimize(:attr:`loss`). + list: pairs of parameters and their gradients. + """ + return ASPHelper._minimize( + self._optimizer, + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) diff --git a/python/paddle/fluid/contrib/sparsity/utils.py b/python/paddle/fluid/contrib/sparsity/utils.py index f1108c327407ff65596156668edd864715291894..bb030cbac1beaf814987e5cf6a21075ff21d58ee 100644 --- a/python/paddle/fluid/contrib/sparsity/utils.py +++ b/python/paddle/fluid/contrib/sparsity/utils.py @@ -27,7 +27,7 @@ from itertools import permutations import threading __all__ = [ - 'density', 'check_mask_1d', 'get_mask_1d', 'check_mask_2d', + 'calculate_density', 'check_mask_1d', 'get_mask_1d', 'check_mask_2d', 'get_mask_2d_greedy', 'get_mask_2d_best', 'create_mask', 'check_sparsity', 'MaskAlgo', 'CheckMethod' ] @@ -75,7 +75,7 @@ class CheckMethod(Enum): CheckMethod.get_checking_method(MaskAlgo.MASK_2D_BEST) # CheckMethod.CHECK_2D """ - assert type(mask_algo) == MaskAlgo, \ + assert isinstance(mask_algo, MaskAlgo), \ "mask_algo should be MaskAlgo type" if mask_algo == MaskAlgo.MASK_1D: return CheckMethod.CHECK_1D @@ -83,7 +83,7 @@ class CheckMethod(Enum): return CheckMethod.CHECK_2D -def density(x): +def calculate_density(x): r""" Return the density of the input tensor. @@ -99,15 +99,15 @@ def density(x): x = np.array([[0, 1, 3, 0], [1, 1, 0, 1]]) - sparsity.density(x) # 0.625 + sparsity.calculate_density(x) # 0.625 """ x_flattened = x.flatten() return float(np.nonzero(x_flattened)[0].size) / x_flattened.size -def reshape_1d(mat, m): +def _reshape_1d(mat, m): r""" - Reshape the input matrix to shape (-1, m). + Reshape the input 2D matrix to shape (-1, m). If the second dimension of :attr:`mat` is not a multiples of :attr:`m`, then this function would pad the remainder with 0 before reshaping. @@ -116,11 +116,13 @@ def reshape_1d(mat, m): remainder = mat.shape[1] % m Args: - mat (nparray): The input matrix. + mat (nparray): The input 2D matrix. m (int): The second dimension of reshaped matrix. Returns: tuple: A pair of the reshaped and padded matrix and the shape of padded matrix (non-reshaping). """ + assert len(mat.shape) == 2, "The input mat should be a 2D matrix!" + remainder = mat.shape[1] % m if mat.shape[1] % m > 0: mat_padded = np.zeros((mat.shape[0], mat.shape[1] + (m - remainder))) @@ -165,9 +167,9 @@ def check_mask_1d(mat, n, m): sparsity.check_mask_1d(x, 2, 4) # True """ if len(mat.shape) <= 1: - mat_flattern, shape = reshape_1d(mat.reshape(1, mat.shape[0]), m) + mat_flattern, shape = _reshape_1d(mat.reshape(1, mat.shape[0]), m) else: - mat_flattern, shape = reshape_1d(mat, m) + mat_flattern, shape = _reshape_1d(mat, m) for sub_mat in mat_flattern: if np.nonzero(sub_mat)[0].size > (m - n): @@ -202,7 +204,7 @@ def get_mask_1d(mat, n, m): # [0, 1, 0, 1]]) sparsity.check_mask_1d(mask, 2, 4) # True """ - mat_flattern, shape = reshape_1d(mat, m) + mat_flattern, shape = _reshape_1d(mat, m) mask_flattern = np.ones_like(mat_flattern) mask = np.ones_like(mat) @@ -215,9 +217,9 @@ def get_mask_1d(mat, n, m): return mask -def reshape_2d(mat, m): +def _reshape_2d(mat, m): r""" - Reshape the input matrix to shape (-1, :math:`m \times m`). + Reshape the input 2D matrix to shape (-1, :math:`m \times m`). In each dimension of :attr:`mat`, if it is not a multiples of :attr:`m`, then this function would pad the remainder with 0 before reshaping. @@ -227,11 +229,13 @@ def reshape_2d(mat, m): remainder_1 = mat.shape[1] % m Args: - mat (nparray): The input matrix. + mat (nparray): The input 2D matrix. m (int): The square root of second dimension of reshaped matrix. Returns: tuple: A pair of the reshaped and padded matrix and the shape of padded matrix (non-reshaping). """ + assert len(mat.shape) == 2, "The input mat should be a 2D matrix!" + remainder_0 = mat.shape[0] % m remainder_1 = mat.shape[1] % m @@ -297,7 +301,7 @@ def check_mask_2d(mat, n, m): [1, 1, 0, 1]]) sparsity.check_mask_2d(x, 2, 4) # True """ - mat_padded, shape = reshape_2d(mat, m) + mat_padded, shape = _reshape_2d(mat, m) for sub_mat in mat_padded: sub_mask = np.absolute(np.squeeze(sub_mat.reshape(m, m))) > 0 if (np.sum(np.sum(sub_mask, axis=1) > (m-n)) != 0) and \ @@ -338,7 +342,7 @@ def get_mask_2d_greedy(mat, n, m): # [0. 1. 1. 0.]]) sparsity.check_mask_2d(mask, 2, 4) # True """ - mat_padded, shape = reshape_2d(mat, m) + mat_padded, shape = _reshape_2d(mat, m) mask_padded = np.zeros_like(mat_padded).reshape(-1, m, m) for idx in range(len(mat_padded)): @@ -372,11 +376,11 @@ def get_mask_2d_greedy(mat, n, m): return mask[:mat.shape[0], :mat.shape[1]] -valid_2d_patterns_lock = threading.Lock() -valid_2d_patterns = {} +_valid_2d_patterns_lock = threading.Lock() +_valid_2d_patterns = {} -def compute_valid_2d_patterns(n, m): +def _compute_valid_2d_patterns(n, m): r""" Compute all vaild 2D `n:m` sparse patterns. @@ -389,12 +393,12 @@ def compute_valid_2d_patterns(n, m): Returns: dictionary: A dictionary with key: *m_n* (string) and value: all vaild 2D `n:m` sparse patterns. """ - global valid_2d_patterns_lock - global valid_2d_patterns + global _valid_2d_patterns_lock + global _valid_2d_patterns valid_key = '{}_{}'.format(m, n) - if valid_key in valid_2d_patterns: - return valid_2d_patterns[valid_key] + if valid_key in _valid_2d_patterns: + return _valid_2d_patterns[valid_key] else: patterns = np.zeros(m) patterns[:n] = 1 @@ -407,9 +411,9 @@ def compute_valid_2d_patterns(n, m): valid_patterns = np.empty((valid.shape[0], m, m)) valid_patterns[:] = patterns[valid[:]] - valid_2d_patterns_lock.acquire() - valid_2d_patterns[valid_key] = valid_patterns - valid_2d_patterns_lock.release() + _valid_2d_patterns_lock.acquire() + _valid_2d_patterns[valid_key] = valid_patterns + _valid_2d_patterns_lock.release() return valid_patterns @@ -446,9 +450,9 @@ def get_mask_2d_best(mat, n, m): print("L1 norm of `greedy` sparse matrix", np.multiply(mat, mask_greedy).sum()) # 56 print("L1 norm of `best` sparse matrix", np.multiply(mat, mask_best).sum()) # 61 """ - patterns = compute_valid_2d_patterns(n, m) + patterns = _compute_valid_2d_patterns(n, m) - mat_flattern, shape = reshape_2d(mat, m) + mat_flattern, shape = _reshape_2d(mat, m) mask_flattern = np.ones_like(mat_flattern).reshape(-1, m, m) pmax = np.argmax( np.matmul(mat_flattern, patterns.reshape(patterns.shape[0], m * m).T), @@ -504,30 +508,25 @@ def create_mask(tensor, func_name=MaskAlgo.MASK_1D, n=2, m=4): dtype = tensor.dtype t = tensor.astype(float) - assert type(func_name) == MaskAlgo, \ + assert isinstance(func_name, MaskAlgo), \ "func_name argumet of create_mask is only accepted as type MaskAlgo. " \ "But got {}".format(type(func_name)) func = getattr(sys.modules[__name__], func_name.value, None) if len(shape) == 1: t = t.reshape(1, shape[0]) - mask = func(t, n=n, m=m) - return mask.reshape(shape).astype(dtype) elif len(shape) == 2: t = t.reshape(shape[0], shape[1]) - mask = func(t, n=n, m=m) - return mask.reshape(shape).astype(dtype) elif len(shape) == 3: t = t.reshape(shape[0] * shape[1], shape[2]) - mask = func(t, n=n, m=m) - return mask.reshape(shape).astype(dtype) # 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op elif len(shape) == 4: t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) - mask = func(t, n=n, m=m) - return mask.reshape(shape).astype(dtype) else: - assert True, "The dimension of input tensor is not supported in create_mask, " \ - "Only dimension < 4 is supported but got {}".format(len(shape)) + raise ValueError("The dimension of input tensor is not supported in create_mask, " \ + "Only dimension < 4 is supported but got {}".format(len(shape))) + + mask = func(t, n=n, m=m) + return mask.reshape(shape).astype(dtype) def check_sparsity(tensor, func_name=CheckMethod.CHECK_1D, n=2, m=4): @@ -569,19 +568,15 @@ def check_sparsity(tensor, func_name=CheckMethod.CHECK_1D, n=2, m=4): func = getattr(sys.modules[__name__], func_name.value, None) if len(shape) == 1: t = t.reshape(1, shape[0]) - return func(t, n=n, m=m) elif len(shape) == 2: t = t.reshape(shape[0], shape[1]) - return func(t, n=n, m=m) elif len(shape) == 3: t = t.reshape(shape[0] * shape[1], shape[2]) - return func(t, n=n, m=m) # 4d-tensor conv (out, in, h, w) -> (out, in*h*w) in GemmConvKernel Op elif len(shape) == 4: t = t.reshape(shape[0], shape[1] * shape[2] * shape[3]) - return func(t, n=n, m=m) else: - assert True, "The dimension of input tensor is not supported in check_sparsity, " \ - "Only dimension < 4 is supported but got {}".format(len(shape)) + raise ValueError("The dimension of input tensor is not supported in create_mask, " \ + "Only dimension < 4 is supported but got {}".format(len(shape))) - return False + return func(t, n=n, m=m) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 144e568c55ca089e387d82b8613b500753ba5d89..03aaf7ed03e26dd1772cbf598c82daedcc716a07 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -661,6 +661,8 @@ if (WITH_MKLDNN) add_subdirectory(mkldnn) endif() +add_subdirectory(asp) + add_subdirectory(ir) if (WITH_TESTING) diff --git a/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f71e04c09aa38b8cf7b3a167b84d4dc0e6cc3ec7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/asp/__init__.py b/python/paddle/fluid/tests/unittests/asp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c551792f989c0611d7077beb4e0995fc2f06abe --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py b/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py new file mode 100644 index 0000000000000000000000000000000000000000..370d73cc35a43ad02a715e9765cbf5a88a9be535 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/asp_pruning_base.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import threading, time +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.contrib import sparsity +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np + +paddle.enable_static() + + +class TestASPHelperPruningBase(unittest.TestCase): + def setUp(self): + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + + def build_model(): + img = fluid.data( + name='img', shape=[None, 3, 32, 32], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + hidden = fluid.layers.conv2d( + input=img, num_filters=4, filter_size=3, padding=2, act="relu") + hidden = fluid.layers.fc(input=hidden, size=32, act='relu') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + return img, label, prediction + + with fluid.program_guard(self.main_program, self.startup_program): + self.img, self.label, self.predict = build_model() + + def run_inference_pruning_test(self, get_mask_gen_func, + get_mask_check_func): + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + + self.__pruning_and_checking(exe, place, get_mask_gen_func, + get_mask_check_func, False) + + def run_training_pruning_test(self, get_mask_gen_func, get_mask_check_func): + with fluid.program_guard(self.main_program, self.startup_program): + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=self.predict, label=self.label)) + optimizer = sparsity.decorate( + fluid.optimizer.SGD(learning_rate=0.01)) + optimizer.minimize(loss, self.startup_program) + + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + + self.__pruning_and_checking(exe, place, get_mask_gen_func, + get_mask_check_func, True) + + def __pruning_and_checking(self, exe, place, mask_func_name, + check_func_name, with_mask): + exe.run(self.startup_program) + sparsity.prune_model( + place, + self.main_program, + func_name=mask_func_name, + with_mask=with_mask) + for param in self.main_program.global_block().all_parameters(): + if ASPHelper._is_supported_layer(self.main_program, param.name): + mat = np.array(fluid.global_scope().find_var(param.name) + .get_tensor()) + self.assertTrue( + sparsity.check_sparsity( + mat.T, func_name=check_func_name, n=2, m=4)) diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..402861ad5d93120dd9328b25d2adab07504ff313 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_optimize.py @@ -0,0 +1,202 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import threading, time +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.contrib import sparsity +from paddle.fluid.contrib.sparsity.asp import ASPHelper +import numpy as np + +paddle.enable_static() + + +class TestASPHelper(unittest.TestCase): + def setUp(self): + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + + def build_model(): + img = fluid.data( + name='img', shape=[None, 3, 32, 32], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + hidden = fluid.layers.conv2d( + input=img, num_filters=4, filter_size=3, padding=2, act="relu") + hidden = fluid.layers.fc(input=hidden, size=32, act='relu') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + return img, label, prediction + + with fluid.program_guard(self.main_program, self.startup_program): + self.img, self.label, predict = build_model() + self.loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict, label=self.label)) + self.optimizer = fluid.optimizer.SGD(learning_rate=0.01) + + def test_get_not_ASP_relevant_vars(self): + def check_params(params, params_from_asp): + if len(params_from_asp) != len(params): + return False + + for i, p in enumerate(params_from_asp): + if p.name != params[i].name: + return False + return True + + params = self.main_program.global_block().all_parameters() + params_from_asp = ASPHelper._get_not_ASP_relevant_vars( + self.main_program) + self.assertTrue(check_params(params, params_from_asp)) + + with fluid.program_guard(self.main_program, self.startup_program): + ASPHelper._minimize(self.optimizer, self.loss, self.main_program, + self.startup_program) + params_from_asp_after_opt = ASPHelper._get_not_ASP_relevant_vars( + self.main_program) + self.assertTrue(check_params(params, params_from_asp_after_opt)) + + def test_is_supported_layers(self): + program = paddle.static.default_main_program() + + names = [ + 'embedding_0.w_0', 'fack_layer_0.w_0', 'conv2d_0.w_0', + 'conv2d_0.b_0', 'conv2d_1.w_0', 'conv2d_1.b_0', 'fc_0.w_0', + 'fc_0.b_0', 'fc_1.w_0', 'fc_1.b_0', 'linear_2.w_0', 'linear_2.b_0' + ] + ref = [ + False, False, True, False, True, False, True, False, True, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + sparsity.set_excluded_layers(program, ['fc_1', 'conv2d_0']) + ref = [ + False, False, False, False, True, False, True, False, False, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + sparsity.reset_excluded_layers(program) + ref = [ + False, False, True, False, True, False, True, False, True, False, + True, False + ] + for i, name in enumerate(names): + self.assertTrue( + ref[i] == ASPHelper._is_supported_layer(program, name)) + + def test_decorate(self): + param_names = self.__get_param_names(self.main_program.global_block() + .all_parameters()) + with fluid.program_guard(self.main_program, self.startup_program): + self.optimizer = sparsity.decorate(self.optimizer) + self.optimizer.minimize(self.loss, self.startup_program) + param_names_after_minimize = self.__get_param_names( + self.main_program.global_block().all_parameters()) + + self.__check_mask_variables_and_ops(param_names, + param_names_after_minimize) + + def test_asp_training(self): + with fluid.program_guard(self.main_program, self.startup_program): + self.optimizer = sparsity.decorate(self.optimizer) + self.optimizer.minimize(self.loss, self.startup_program) + + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[self.img, self.label], place=place) + + exe.run(self.startup_program) + sparsity.prune_model(place, self.main_program) + + data = (np.random.randn(64, 3, 32, 32), np.random.randint( + 10, size=(64, 1))) + exe.run(self.main_program, feed=feeder.feed([data])) + + for param in self.main_program.global_block().all_parameters(): + if ASPHelper._is_supported_layer(self.main_program, param.name): + mat = np.array(fluid.global_scope().find_var(param.name) + .get_tensor()) + self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + + def test_asp_training_with_amp(self): + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with fluid.program_guard(self.main_program, self.startup_program): + self.optimizer = fluid.contrib.mixed_precision.decorator.decorate( + self.optimizer) + self.optimizer = sparsity.decorate(self.optimizer) + self.optimizer.minimize(self.loss, self.startup_program) + + exe = fluid.Executor(place) + feeder = fluid.DataFeeder( + feed_list=[self.img, self.label], place=place) + + exe.run(self.startup_program) + sparsity.prune_model(place, self.main_program) + + data = (np.random.randn(64, 3, 32, 32), np.random.randint( + 10, size=(64, 1))) + exe.run(self.main_program, feed=feeder.feed([data])) + + for param in self.main_program.global_block().all_parameters(): + if ASPHelper._is_supported_layer(self.main_program, param.name): + mat = np.array(fluid.global_scope().find_var(param.name) + .get_tensor()) + self.assertTrue(sparsity.check_sparsity(mat.T, n=2, m=4)) + + def __get_param_names(self, params): + param_names = [] + for p in params: + param_names.append(p.name) + return param_names + + def __check_mask_variables_and_ops(self, param_names, + param_names_after_minimize): + for n in param_names: + self.assertFalse(ASPHelper._is_supported_layer(self.main_program, n) and \ + ASPHelper._get_mask_name(n) not in param_names_after_minimize) + + mask_names = [] + for n in param_names: + if ASPHelper._is_supported_layer(self.main_program, n): + mask_names.append(ASPHelper._get_mask_name(n)) + + masking_ops = [] + for op in self.main_program.global_block().ops: + if op.type == 'elementwise_mul' and \ + op.input('Y')[0] in mask_names: + masking_ops.append(op.input('Y')[0]) + + self.assertTrue(len(masking_ops) == len(mask_names)) + for n in masking_ops: + self.assertTrue(n in mask_names) + + for n in mask_names: + self.assertTrue(n in masking_ops) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4b2c002f5afaf390b42e13e5cf7f34906cd90a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_1d.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +from paddle.fluid.contrib import sparsity +from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase + +paddle.enable_static() + + +class TestASPHelperPruning1D(TestASPHelperPruningBase): + def test_1D_inference_pruning(self): + self.run_inference_pruning_test(sparsity.MaskAlgo.MASK_1D, + sparsity.CheckMethod.CHECK_1D) + + def test_1D_training_pruning(self): + self.run_training_pruning_test(sparsity.MaskAlgo.MASK_1D, + sparsity.CheckMethod.CHECK_1D) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8b1e4a06ae4c6954aba4f380361dfc7383eb9b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_best.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +from paddle.fluid.contrib import sparsity +from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase + +paddle.enable_static() + + +class TestASPHelperPruning2DBest(TestASPHelperPruningBase): + def test_2D_best_inference_pruning(self): + self.run_inference_pruning_test(sparsity.MaskAlgo.MASK_2D_BEST, + sparsity.CheckMethod.CHECK_2D) + + def test_2D_best_training_pruning(self): + self.run_training_pruning_test(sparsity.MaskAlgo.MASK_2D_BEST, + sparsity.CheckMethod.CHECK_2D) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py new file mode 100644 index 0000000000000000000000000000000000000000..4bdd310f0209a94f639b107e7279726e196e6a7d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_pruning_2d_greedy.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +from paddle.fluid.contrib import sparsity +from paddle.fluid.tests.unittests.asp.asp_pruning_base import TestASPHelperPruningBase + +paddle.enable_static() + + +class TestASPHelperPruning2DGreedy(TestASPHelperPruningBase): + def test_2D_greedy_inference_pruning(self): + self.run_inference_pruning_test(sparsity.MaskAlgo.MASK_2D_GREEDY, + sparsity.CheckMethod.CHECK_2D) + + def test_2D_greedy_training_pruning(self): + self.run_training_pruning_test(sparsity.MaskAlgo.MASK_2D_GREEDY, + sparsity.CheckMethod.CHECK_2D) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_asp_utils.py b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py similarity index 94% rename from python/paddle/fluid/tests/unittests/test_asp_utils.py rename to python/paddle/fluid/tests/unittests/asp/test_asp_utils.py index faffd477ae5661cee5e599b2044af4f42a96112f..387cb55e5c3cfd65c6e56433afb659dfe2f12bff 100644 --- a/python/paddle/fluid/tests/unittests/test_asp_utils.py +++ b/python/paddle/fluid/tests/unittests/asp/test_asp_utils.py @@ -39,9 +39,9 @@ class TestASPUtils(unittest.TestCase): x = np.array([[1.0, 1.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 1.0]]) - self.assertEqual(sparsity.density(x), 0.56) + self.assertEqual(sparsity.calculate_density(x), 0.56) x[:, 0] = 0.0 - self.assertEqual(sparsity.density(x), 0.4) + self.assertEqual(sparsity.calculate_density(x), 0.4) def test_check_mask_1d(self): x = np.array([[1.0, 0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 1.0], @@ -114,11 +114,11 @@ class TestASPUtils(unittest.TestCase): for _ in range(4): computing_thread = threading.Thread( target=paddle.fluid.contrib.sparsity.utils. - compute_valid_2d_patterns, + _compute_valid_2d_patterns, args=(2, 4)) computing_thread.start() time.sleep(3) - patterns_map = paddle.fluid.contrib.sparsity.utils.valid_2d_patterns + patterns_map = paddle.fluid.contrib.sparsity.utils._valid_2d_patterns reference_patterns = get_reference() reference_key = '4_2'