Compute or get sensitivities of model in current pruner. It will return a cached sensitivities when all the arguments are "None".
Compute or get sensitivities of model in current pruner. It will return a cached sensitivities when all the arguments are "None".
...
@@ -88,7 +98,7 @@ class FilterPruner(Pruner):
...
@@ -88,7 +98,7 @@ class FilterPruner(Pruner):
eval_func(function, optional): The function to evaluate the model in current pruner. This function should have an empy arguments list and return a score with type "float32". Default: None.
eval_func(function, optional): The function to evaluate the model in current pruner. This function should have an empy arguments list and return a score with type "float32". Default: None.
sen_file(str, optional): The absolute path of file to save sensitivities into local filesystem. Default: None.
sen_file(str, optional): The absolute path of file to save sensitivities into local filesystem. Default: None.
target_vars(list, optional): The names of tensors whose sensitivity will be computed. "None" means all weights in convolution layer will be computed. Default: None.
target_vars(list, optional): The names of tensors whose sensitivity will be computed. "None" means all weights in convolution layer will be computed. Default: None.
skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None.
skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. Default: [].
Returns:
Returns:
dict: A dict storing sensitivities.
dict: A dict storing sensitivities.
...
@@ -102,6 +112,7 @@ class FilterPruner(Pruner):
...
@@ -102,6 +112,7 @@ class FilterPruner(Pruner):
ifnotself._status.is_ckp:
ifnotself._status.is_ckp:
returnself._status
returnself._status
skip_vars.extend(self.skip_vars)
self._cal_sensitive(
self._cal_sensitive(
self.model,
self.model,
eval_func,
eval_func,
...
@@ -186,9 +197,9 @@ class FilterPruner(Pruner):
...
@@ -186,9 +197,9 @@ class FilterPruner(Pruner):
Returns:
Returns:
tuple: A tuple with format ``(ratios, pruned_flops)`` . "ratios" is a dict whose key is name of tensor and value is ratio to be pruned. "pruned_flops" is the ratio of total pruned FLOPs in the model.
tuple: A tuple with format ``(ratios, pruned_flops)`` . "ratios" is a dict whose key is name of tensor and value is ratio to be pruned. "pruned_flops" is the ratio of total pruned FLOPs in the model.
0]],f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"