diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index a4b237964ef0dd672238744faeca1a24a094fb7b..bd946ab26aac7d92407e36de9cfef6c814efe0e9 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -50,6 +50,7 @@ add_arg('pruning_steps', int, 100, "How many times you want to increase y add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15") add_arg('pruning_strategy', str, 'base', "Which training strategy to use in pruning, we only support base and gmp for now. Default: base") add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None") +add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False") # yapf: enable @@ -96,12 +97,14 @@ def create_unstructured_pruner(model, args, configs=None): mode=args.pruning_mode, ratio=args.ratio, threshold=args.threshold, - prune_params_type=args.prune_params_type) + prune_params_type=args.prune_params_type, + local_sparsity=args.local_sparsity) else: return GMPUnstructuredPruner( model, ratio=args.ratio, prune_params_type=args.prune_params_type, + local_sparsity=args.local_sparsity, configs=configs) @@ -270,7 +273,6 @@ def compress(args): train_reader_cost = 0.0 train_run_cost = 0.0 total_samples = 0 - reader_start = time.time() for i in range(args.last_epoch + 1, args.num_epochs): diff --git a/demo/unstructured_prune/train.py b/demo/unstructured_prune/train.py index 1c0362def8df981a8d1052846616c52607801149..3a0174793d9332973f2ccb12a0df77aae74cfac6 100644 --- a/demo/unstructured_prune/train.py +++ b/demo/unstructured_prune/train.py @@ -49,6 +49,7 @@ add_arg('tunning_epochs', int, 60, "The epoch numbers used to tune add_arg('pruning_steps', int, 120, "How many times you want to increase your ratio during training. Default: 120") add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15") add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None") +add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False") # yapf: enable model_list = models.__all__ @@ -96,13 +97,15 @@ def create_unstructured_pruner(train_program, args, place, configs): ratio=args.ratio, threshold=args.threshold, prune_params_type=args.prune_params_type, - place=place) + place=place, + local_sparsity=args.local_sparsity) else: return GMPUnstructuredPruner( train_program, ratio=args.ratio, prune_params_type=args.prune_params_type, place=place, + local_sparsity=args.local_sparsity, configs=configs) diff --git a/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst index ad0825e5662e984dd8943306361ef8f98d5ad440..9beb48cd8b9213fff7e9a3d59f33e76eee54b272 100644 --- a/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst +++ b/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst @@ -4,7 +4,7 @@ UnstructuredPruner ---------- -.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None) +.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False) `源代码 `_ @@ -44,6 +44,8 @@ UnstructuredPruner .. +- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。 + **返回:** 一个UnstructuredPruner类的实例。 **示例代码:** @@ -203,7 +205,7 @@ GMPUnstructuredPruner `源代码 `_ -.. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, configs=None) +.. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False, configs=None) 该类是UnstructuredPruner的一个子类,通过覆盖step()方法,优化了训练策略,使稀疏化训练更易恢复到稠密模型精度。其他方法均继承自父类。 @@ -213,6 +215,7 @@ GMPUnstructuredPruner - **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。 - **prune_params_type(str)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化层的参数。 - **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。 +- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。 - **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。各参数介绍如下: .. code-block:: python diff --git a/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst b/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst index 61fc3bea96370462678f63160aeed910898287c2..7a19b8db43033522651540c85362ae684b2db26f 100644 --- a/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst +++ b/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst @@ -4,7 +4,7 @@ UnstrucuturedPruner ---------- -.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None) +.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None, local_sparsity=False) `源代码 `_ @@ -43,6 +43,7 @@ UnstrucuturedPruner .. +- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。 **返回:** 一个UnstructuredPruner类的实例 @@ -280,7 +281,7 @@ UnstrucuturedPruner GMPUnstrucuturedPruner ---------- -.. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, configs=None) +.. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, local_sparsity=False, configs=None) `源代码 `_ @@ -294,6 +295,7 @@ GMPUnstrucuturedPruner - **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。 - **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化的参数。 - **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。 +- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。 - **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。具体描述如下: .. code-block:: python diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index 9c40a538a0dcf9dca8ab483b7c8eac5ffce61bd9..4e7ef182a7073f62f763338e3dd2bc3656ef32eb 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -24,6 +24,7 @@ class UnstructuredPruner(): - ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55 - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. + - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False """ def __init__(self, @@ -32,14 +33,19 @@ class UnstructuredPruner(): threshold=0.01, ratio=0.55, prune_params_type=None, - skip_params_func=None): + skip_params_func=None, + local_sparsity=False): assert mode in ('ratio', 'threshold' ), "mode must be selected from 'ratio' and 'threshold'" assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now." + if local_sparsity: + assert mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly." self.model = model self.mode = mode self.threshold = threshold self.ratio = ratio + self.local_sparsity = local_sparsity + self.thresholds = {} # Prority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params if skip_params_func is not None: @@ -91,11 +97,19 @@ class UnstructuredPruner(): continue t_param = param.value().get_tensor() v_param = np.array(t_param) - params_flatten.append(v_param.flatten()) - params_flatten = np.concatenate(params_flatten, axis=0) - total_length = params_flatten.size - self.threshold = np.sort(np.abs(params_flatten))[max( - 0, round(self.ratio * total_length) - 1)].item() + if self.local_sparsity: + flatten_v_param = v_param.flatten() + cur_length = flatten_v_param.size + cur_threshold = np.sort(np.abs(flatten_v_param))[max( + 0, round(self.ratio * cur_length) - 1)].item() + self.thresholds[param.name] = cur_threshold + else: + params_flatten.append(v_param.flatten()) + if not self.local_sparsity: + params_flatten = np.concatenate(params_flatten, axis=0) + total_length = params_flatten.size + self.threshold = np.sort(np.abs(params_flatten))[max( + 0, round(self.ratio * total_length) - 1)].item() def _update_masks(self): for name, sub_layer in self.model.named_sublayers(): @@ -105,7 +119,11 @@ class UnstructuredPruner(): if param.name in self.skip_params: continue mask = self.masks.get(param.name) - bool_tmp = (paddle.abs(param) >= self.threshold) + if self.local_sparsity: + bool_tmp = ( + paddle.abs(param) >= self.thresholds[param.name]) + else: + bool_tmp = (paddle.abs(param) >= self.threshold) paddle.assign(bool_tmp, output=mask) def summarize_weights(self, model, ratio=0.1): @@ -248,6 +266,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): - ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55 - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. + - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False - configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None .. code-block:: python @@ -268,11 +287,13 @@ class GMPUnstructuredPruner(UnstructuredPruner): ratio=0.55, prune_params_type=None, skip_params_func=None, + local_sparsity=False, configs=None): assert configs is not None, "Configs must be passed in for GMP pruner." super(GMPUnstructuredPruner, self).__init__( - model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func) + model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func, + local_sparsity) self.stable_iterations = configs.get('stable_iterations') self.pruning_iterations = configs.get('pruning_iterations') self.tunning_iterations = configs.get('tunning_iterations') diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py index 5f457b3419fda94675ce5d0c912fa519ea046a2c..cf9064aee5d490a2ebf2a8fe287f0cd9cd38a546 100644 --- a/paddleslim/prune/unstructured_pruner.py +++ b/paddleslim/prune/unstructured_pruner.py @@ -20,6 +20,7 @@ class UnstructuredPruner(): - place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None. - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None + - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False """ def __init__(self, @@ -30,15 +31,19 @@ class UnstructuredPruner(): scope=None, place=None, prune_params_type=None, - skip_params_func=None): + skip_params_func=None, + local_sparsity=False): self.mode = mode self.ratio = ratio self.threshold = threshold + self.local_sparsity = local_sparsity + self.thresholds = {} assert self.mode in [ 'ratio', 'threshold' ], "mode must be selected from 'ratio' and 'threshold'" assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now." - + if self.local_sparsity: + assert self.mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly." self.scope = paddle.static.global_scope() if scope == None else scope self.place = paddle.static.cpu_places()[0] if place is None else place @@ -156,14 +161,20 @@ class UnstructuredPruner(): continue t_param = self.scope.find_var(param).get_tensor() v_param = np.array(t_param) - params_flatten.append(v_param.flatten()) - params_flatten = np.concatenate(params_flatten, axis=0) - self.threshold = self._partition_sort(params_flatten) + if self.local_sparsity: + cur_threshold = self._partition_sort(v_param.flatten()) + self.thresholds[param] = cur_threshold + else: + params_flatten.append(v_param.flatten()) + if not self.local_sparsity: + params_flatten = np.concatenate(params_flatten, axis=0) + self.threshold = self._partition_sort(params_flatten) def _partition_sort(self, params): total_len = len(params) params_zeros = params[params == 0] params_nonzeros = params[params != 0] + if len(params_nonzeros) == 0: return 0 new_ratio = max((self.ratio * total_len - len(params_zeros)), 0) / len(params_nonzeros) return np.sort(np.abs(params_nonzeros))[max( @@ -177,7 +188,10 @@ class UnstructuredPruner(): t_param = self.scope.find_var(param).get_tensor() t_mask = self.scope.find_var(mask_name).get_tensor() v_param = np.array(t_param) - v_param[np.abs(v_param) < self.threshold] = 0 + if self.local_sparsity: + v_param[np.abs(v_param) < self.thresholds[param]] = 0 + else: + v_param[np.abs(v_param) < self.threshold] = 0 v_mask = (v_param != 0).astype(v_param.dtype) t_mask.set(v_mask, self.place) @@ -240,6 +254,7 @@ class UnstructuredPruner(): if 'norm' in op.type() and 'grad' not in op.type(): for input in op.all_inputs(): skip_params.add(input.name()) + print(skip_params) return skip_params def _get_skip_params_conv1x1(self, program): @@ -296,6 +311,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): - place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None. - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None + - local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False - configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below: .. code-block:: python @@ -317,12 +333,13 @@ class GMPUnstructuredPruner(UnstructuredPruner): place=None, prune_params_type=None, skip_params_func=None, + local_sparsity=False, configs=None): assert configs is not None, "Please pass in a valid config dictionary." super(GMPUnstructuredPruner, self).__init__( program, 'ratio', ratio, 0.0, scope, place, prune_params_type, - skip_params_func) + skip_params_func, local_sparsity) self.stable_iterations = configs.get('stable_iterations') self.pruning_iterations = configs.get('pruning_iterations') self.tunning_iterations = configs.get('tunning_iterations') diff --git a/tests/dygraph/test_unstructured_prune.py b/tests/dygraph/test_unstructured_prune.py index 33703ecb0f7a10aafe25b78ef219f4c938facac0..2a0f634248a8816724712ee20da7518feacb76b7 100644 --- a/tests/dygraph/test_unstructured_prune.py +++ b/tests/dygraph/test_unstructured_prune.py @@ -16,12 +16,14 @@ class TestUnstructuredPruner(unittest.TestCase): def _gen_model(self): self.net = mobilenet_v1(num_classes=10, pretrained=False) self.net_conv1x1 = mobilenet_v1(num_classes=10, pretrained=False) - self.pruner = UnstructuredPruner(self.net, mode='ratio', ratio=0.55) + self.pruner = UnstructuredPruner( + self.net, mode='ratio', ratio=0.55, local_sparsity=True) self.pruner_conv1x1 = UnstructuredPruner( self.net_conv1x1, mode='ratio', ratio=0.55, - prune_params_type='conv1x1_only') + prune_params_type='conv1x1_only', + local_sparsity=False) def test_prune(self): ori_sparsity = UnstructuredPruner.total_sparse(self.net) diff --git a/tests/test_unstructured_pruner.py b/tests/test_unstructured_pruner.py index ac23b74af91ab6c869076a987f834b92d946dd88..2f34f25b77648772a3f626a0c6edd5c51663efa6 100644 --- a/tests/test_unstructured_pruner.py +++ b/tests/test_unstructured_pruner.py @@ -42,13 +42,18 @@ class TestUnstructuredPruner(StaticCase): exe.run(self.startup_program, scope=self.scope) self.pruner = UnstructuredPruner( - self.main_program, 'ratio', scope=self.scope, place=place) + self.main_program, + 'ratio', + scope=self.scope, + place=place, + local_sparsity=True) self.pruner_conv1x1 = UnstructuredPruner( self.main_program, 'ratio', scope=self.scope, place=place, - prune_params_type='conv1x1_only') + prune_params_type='conv1x1_only', + local_sparsity=False) def test_unstructured_prune(self): for param in self.main_program.global_block().all_parameters():