未验证 提交 5ec83b62 编写于 作者: M minghaoBD 提交者: GitHub

[Unstructured_prune] add local_sparsity (#916)

上级 b9e5730b
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
UnstructuredPruner UnstructuredPruner
---------- ----------
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None) .. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
...@@ -44,6 +44,8 @@ UnstructuredPruner ...@@ -44,6 +44,8 @@ UnstructuredPruner
.. ..
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio' 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
**返回:** 一个UnstructuredPruner类的实例。 **返回:** 一个UnstructuredPruner类的实例。
**示例代码:** **示例代码:**
...@@ -203,7 +205,7 @@ GMPUnstructuredPruner ...@@ -203,7 +205,7 @@ GMPUnstructuredPruner
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
.. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, configs=None) .. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, local_sparsity=False, configs=None)
该类是UnstructuredPruner的一个子类,通过覆盖step()方法,优化了训练策略,使稀疏化训练更易恢复到稠密模型精度。其他方法均继承自父类。 该类是UnstructuredPruner的一个子类,通过覆盖step()方法,优化了训练策略,使稀疏化训练更易恢复到稠密模型精度。其他方法均继承自父类。
...@@ -213,6 +215,7 @@ GMPUnstructuredPruner ...@@ -213,6 +215,7 @@ GMPUnstructuredPruner
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。 - **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **prune_params_type(str)** - 用以指定哪些类型的参数参与稀疏。目前只支持None"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化层的参数。 - **prune_params_type(str)** - 用以指定哪些类型的参数参与稀疏。目前只支持None"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化层的参数。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。 - **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio' 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
- **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。各参数介绍如下: - **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。各参数介绍如下:
.. code-block:: python .. code-block:: python
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
UnstrucuturedPruner UnstrucuturedPruner
---------- ----------
.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None) .. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None, local_sparsity=False)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
...@@ -43,6 +43,7 @@ UnstrucuturedPruner ...@@ -43,6 +43,7 @@ UnstrucuturedPruner
.. ..
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
**返回:** 一个UnstructuredPruner类的实例 **返回:** 一个UnstructuredPruner类的实例
...@@ -280,7 +281,7 @@ UnstrucuturedPruner ...@@ -280,7 +281,7 @@ UnstrucuturedPruner
GMPUnstrucuturedPruner GMPUnstrucuturedPruner
---------- ----------
.. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, configs=None) .. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, local_sparsity=False, configs=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_ `源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
...@@ -294,6 +295,7 @@ GMPUnstrucuturedPruner ...@@ -294,6 +295,7 @@ GMPUnstrucuturedPruner
- **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。 - **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。
- **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化的参数。 - **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化的参数。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。 - **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
- **local_sparsity(bool)** - 剪裁比例(ratio)应用的范围:local_sparsity 开启时意味着每个参与剪裁的参数矩阵稀疏度均为 'ratio', 关闭时表示只保证模型整体稀疏度达到'ratio',但是每个参数矩阵的稀疏度可能存在差异。
- **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。具体描述如下: - **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。具体描述如下:
.. code-block:: python .. code-block:: python
......
...@@ -24,6 +24,7 @@ class UnstructuredPruner(): ...@@ -24,6 +24,7 @@ class UnstructuredPruner():
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55 - ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
""" """
def __init__(self, def __init__(self,
...@@ -32,14 +33,19 @@ class UnstructuredPruner(): ...@@ -32,14 +33,19 @@ class UnstructuredPruner():
threshold=0.01, threshold=0.01,
ratio=0.55, ratio=0.55,
prune_params_type=None, prune_params_type=None,
skip_params_func=None): skip_params_func=None,
local_sparsity=False):
assert mode in ('ratio', 'threshold' assert mode in ('ratio', 'threshold'
), "mode must be selected from 'ratio' and 'threshold'" ), "mode must be selected from 'ratio' and 'threshold'"
assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now." assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
if local_sparsity:
assert mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly."
self.model = model self.model = model
self.mode = mode self.mode = mode
self.threshold = threshold self.threshold = threshold
self.ratio = ratio self.ratio = ratio
self.local_sparsity = local_sparsity
self.thresholds = {}
# Prority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params # Prority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params
if skip_params_func is not None: if skip_params_func is not None:
...@@ -89,18 +95,30 @@ class UnstructuredPruner(): ...@@ -89,18 +95,30 @@ class UnstructuredPruner():
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
t_param = param.value().get_tensor() t_param = param.value().get_tensor()
v_param = np.array(t_param) v_param = np.array(t_param)
params_flatten.append(v_param.flatten()) if self.local_sparsity:
params_flatten = np.concatenate(params_flatten, axis=0) flatten_v_param = v_param.flatten()
total_length = params_flatten.size cur_length = flatten_v_param.size
self.threshold = np.sort(np.abs(params_flatten))[max( cur_threshold = np.sort(np.abs(flatten_v_param))[max(
0, round(self.ratio * total_length) - 1)].item() 0, round(self.ratio * cur_length) - 1)].item()
self.thresholds[param.name] = cur_threshold
else:
params_flatten.append(v_param.flatten())
if not self.local_sparsity:
params_flatten = np.concatenate(params_flatten, axis=0)
total_length = params_flatten.size
self.threshold = np.sort(np.abs(params_flatten))[max(
0, round(self.ratio * total_length) - 1)].item()
def _update_masks(self): def _update_masks(self):
for name, sub_layer in self.model.named_sublayers(): for name, sub_layer in self.model.named_sublayers():
if not self._should_prune_layer(sub_layer): continue if not self._should_prune_layer(sub_layer): continue
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
mask = self.masks.get(param.name) mask = self.masks.get(param.name)
bool_tmp = (paddle.abs(param) >= self.threshold) if self.local_sparsity:
bool_tmp = (
paddle.abs(param) >= self.thresholds[param.name])
else:
bool_tmp = (paddle.abs(param) >= self.threshold)
paddle.assign(bool_tmp, output=mask) paddle.assign(bool_tmp, output=mask)
def summarize_weights(self, model, ratio=0.1): def summarize_weights(self, model, ratio=0.1):
...@@ -243,6 +261,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -243,6 +261,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55 - ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
- configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None - configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None
.. code-block:: python .. code-block:: python
...@@ -263,11 +282,13 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -263,11 +282,13 @@ class GMPUnstructuredPruner(UnstructuredPruner):
ratio=0.55, ratio=0.55,
prune_params_type=None, prune_params_type=None,
skip_params_func=None, skip_params_func=None,
local_sparsity=False,
configs=None): configs=None):
assert configs is not None, "Configs must be passed in for GMP pruner." assert configs is not None, "Configs must be passed in for GMP pruner."
super(GMPUnstructuredPruner, self).__init__( super(GMPUnstructuredPruner, self).__init__(
model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func) model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func,
local_sparsity)
self.stable_iterations = configs.get('stable_iterations') self.stable_iterations = configs.get('stable_iterations')
self.pruning_iterations = configs.get('pruning_iterations') self.pruning_iterations = configs.get('pruning_iterations')
self.tunning_iterations = configs.get('tunning_iterations') self.tunning_iterations = configs.get('tunning_iterations')
......
...@@ -20,6 +20,7 @@ class UnstructuredPruner(): ...@@ -20,6 +20,7 @@ class UnstructuredPruner():
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None. - place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
""" """
def __init__(self, def __init__(self,
...@@ -30,15 +31,19 @@ class UnstructuredPruner(): ...@@ -30,15 +31,19 @@ class UnstructuredPruner():
scope=None, scope=None,
place=None, place=None,
prune_params_type=None, prune_params_type=None,
skip_params_func=None): skip_params_func=None,
local_sparsity=False):
self.mode = mode self.mode = mode
self.ratio = ratio self.ratio = ratio
self.threshold = threshold self.threshold = threshold
self.local_sparsity = local_sparsity
self.thresholds = {}
assert self.mode in [ assert self.mode in [
'ratio', 'threshold' 'ratio', 'threshold'
], "mode must be selected from 'ratio' and 'threshold'" ], "mode must be selected from 'ratio' and 'threshold'"
assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now." assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
if self.local_sparsity:
assert self.mode == 'ratio', "We don't support local_sparsity==True and mode=='threshold' at the same time, please change the inputs accordingly."
self.scope = paddle.static.global_scope() if scope == None else scope self.scope = paddle.static.global_scope() if scope == None else scope
self.place = paddle.static.cpu_places()[0] if place is None else place self.place = paddle.static.cpu_places()[0] if place is None else place
...@@ -156,14 +161,20 @@ class UnstructuredPruner(): ...@@ -156,14 +161,20 @@ class UnstructuredPruner():
continue continue
t_param = self.scope.find_var(param).get_tensor() t_param = self.scope.find_var(param).get_tensor()
v_param = np.array(t_param) v_param = np.array(t_param)
params_flatten.append(v_param.flatten()) if self.local_sparsity:
params_flatten = np.concatenate(params_flatten, axis=0) cur_threshold = self._partition_sort(v_param.flatten())
self.threshold = self._partition_sort(params_flatten) self.thresholds[param] = cur_threshold
else:
params_flatten.append(v_param.flatten())
if not self.local_sparsity:
params_flatten = np.concatenate(params_flatten, axis=0)
self.threshold = self._partition_sort(params_flatten)
def _partition_sort(self, params): def _partition_sort(self, params):
total_len = len(params) total_len = len(params)
params_zeros = params[params == 0] params_zeros = params[params == 0]
params_nonzeros = params[params != 0] params_nonzeros = params[params != 0]
if len(params_nonzeros) == 0: return 0
new_ratio = max((self.ratio * total_len - len(params_zeros)), new_ratio = max((self.ratio * total_len - len(params_zeros)),
0) / len(params_nonzeros) 0) / len(params_nonzeros)
return np.sort(np.abs(params_nonzeros))[max( return np.sort(np.abs(params_nonzeros))[max(
...@@ -177,7 +188,10 @@ class UnstructuredPruner(): ...@@ -177,7 +188,10 @@ class UnstructuredPruner():
t_param = self.scope.find_var(param).get_tensor() t_param = self.scope.find_var(param).get_tensor()
t_mask = self.scope.find_var(mask_name).get_tensor() t_mask = self.scope.find_var(mask_name).get_tensor()
v_param = np.array(t_param) v_param = np.array(t_param)
v_param[np.abs(v_param) < self.threshold] = 0 if self.local_sparsity:
v_param[np.abs(v_param) < self.thresholds[param]] = 0
else:
v_param[np.abs(v_param) < self.threshold] = 0
v_mask = (v_param != 0).astype(v_param.dtype) v_mask = (v_param != 0).astype(v_param.dtype)
t_mask.set(v_mask, self.place) t_mask.set(v_mask, self.place)
...@@ -240,6 +254,7 @@ class UnstructuredPruner(): ...@@ -240,6 +254,7 @@ class UnstructuredPruner():
if 'norm' in op.type() and 'grad' not in op.type(): if 'norm' in op.type() and 'grad' not in op.type():
for input in op.all_inputs(): for input in op.all_inputs():
skip_params.add(input.name()) skip_params.add(input.name())
print(skip_params)
return skip_params return skip_params
def _get_skip_params_conv1x1(self, program): def _get_skip_params_conv1x1(self, program):
...@@ -296,6 +311,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -296,6 +311,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None. - place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None - prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None - skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
- local_sparsity(bool): Whether to enable local sparsity. Local sparsity means all the weight matrices have the same sparsity. And the global sparsity only ensures the whole model's sparsity is equal to the passed-in 'ratio'. Default: False
- configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below: - configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below:
.. code-block:: python .. code-block:: python
...@@ -317,12 +333,13 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -317,12 +333,13 @@ class GMPUnstructuredPruner(UnstructuredPruner):
place=None, place=None,
prune_params_type=None, prune_params_type=None,
skip_params_func=None, skip_params_func=None,
local_sparsity=False,
configs=None): configs=None):
assert configs is not None, "Please pass in a valid config dictionary." assert configs is not None, "Please pass in a valid config dictionary."
super(GMPUnstructuredPruner, self).__init__( super(GMPUnstructuredPruner, self).__init__(
program, 'ratio', ratio, 0.0, scope, place, prune_params_type, program, 'ratio', ratio, 0.0, scope, place, prune_params_type,
skip_params_func) skip_params_func, local_sparsity)
self.stable_iterations = configs.get('stable_iterations') self.stable_iterations = configs.get('stable_iterations')
self.pruning_iterations = configs.get('pruning_iterations') self.pruning_iterations = configs.get('pruning_iterations')
self.tunning_iterations = configs.get('tunning_iterations') self.tunning_iterations = configs.get('tunning_iterations')
......
...@@ -16,12 +16,14 @@ class TestUnstructuredPruner(unittest.TestCase): ...@@ -16,12 +16,14 @@ class TestUnstructuredPruner(unittest.TestCase):
def _gen_model(self): def _gen_model(self):
self.net = mobilenet_v1(num_classes=10, pretrained=False) self.net = mobilenet_v1(num_classes=10, pretrained=False)
self.net_conv1x1 = mobilenet_v1(num_classes=10, pretrained=False) self.net_conv1x1 = mobilenet_v1(num_classes=10, pretrained=False)
self.pruner = UnstructuredPruner(self.net, mode='ratio', ratio=0.55) self.pruner = UnstructuredPruner(
self.net, mode='ratio', ratio=0.55, local_sparsity=True)
self.pruner_conv1x1 = UnstructuredPruner( self.pruner_conv1x1 = UnstructuredPruner(
self.net_conv1x1, self.net_conv1x1,
mode='ratio', mode='ratio',
ratio=0.55, ratio=0.55,
prune_params_type='conv1x1_only') prune_params_type='conv1x1_only',
local_sparsity=False)
def test_prune(self): def test_prune(self):
ori_sparsity = UnstructuredPruner.total_sparse(self.net) ori_sparsity = UnstructuredPruner.total_sparse(self.net)
......
...@@ -42,13 +42,18 @@ class TestUnstructuredPruner(StaticCase): ...@@ -42,13 +42,18 @@ class TestUnstructuredPruner(StaticCase):
exe.run(self.startup_program, scope=self.scope) exe.run(self.startup_program, scope=self.scope)
self.pruner = UnstructuredPruner( self.pruner = UnstructuredPruner(
self.main_program, 'ratio', scope=self.scope, place=place) self.main_program,
'ratio',
scope=self.scope,
place=place,
local_sparsity=True)
self.pruner_conv1x1 = UnstructuredPruner( self.pruner_conv1x1 = UnstructuredPruner(
self.main_program, self.main_program,
'ratio', 'ratio',
scope=self.scope, scope=self.scope,
place=place, place=place,
prune_params_type='conv1x1_only') prune_params_type='conv1x1_only',
local_sparsity=False)
def test_unstructured_prune(self): def test_unstructured_prune(self):
for param in self.main_program.global_block().all_parameters(): for param in self.main_program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册