f"Can not find collection with master ['name': {var_name}, 'axis': {pruned_axis}]"
)
returnplan
_logger.info(
f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
)
mask=self.cal_mask(pruned_ratio,collection)
        for _detail in collection.all_pruning_details():
            # Variables can be pruned on multiple axes.
            src_mask = copy.deepcopy(mask)
            var_shape = _detail.var.shape()
            for trans in _detail.transform:
                src_mask = self._transform_mask(src_mask, trans)
            current_mask = src_mask
            assert len(current_mask) == var_shape[
                _detail.axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_detail.name}; len(mask): {len(mask)}"
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
Returns:
list<Group>: The groups.
"""
    if not isinstance(graph, GraphWrapper):
        graph = GraphWrapper(graph)
    visited = {}
    collections = []
    unsupported_warnings = set()
    for _param in params:
        pruned_params = []
        param = graph.var(_param)
        if param is None:
            _logger.warning(
f"Couldn't find relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correct mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
            continue
        target_op = param.outputs()[0]
        if target_op.type() == 'conditional_block':
            for op in param.outputs():
                if op.type() in PRUNE_WORKER._module_dict.keys():
                    cls = PRUNE_WORKER.get(op.type())
                    worker = cls(op,
                                 pruned_params=pruned_params,
                                 visited=visited,
                                 skip_stranger=skip_stranger)
                    break
        else:
            cls = PRUNE_WORKER.get(target_op.type())
            if cls is None:
_logger.warning("No worker for operator: {}".format(
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from ..core import GraphWrapper
from ..common import get_logger
from .prune_walker import PRUNE_WORKER

__all__ = ["collect_convs"]

_logger = get_logger(__name__, level=logging.INFO)


def collect_convs(params, graph, visited={}):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
Returns:
list<list<tuple>>: The groups.
"""
    if not isinstance(graph, GraphWrapper):
        graph = GraphWrapper(graph)
    groups = []
    for _param in params:
        pruned_params = []
        param = graph.var(_param)
        if param is None:
            _logger.warning(
f"Cann't found relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correctly mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
            groups.append([])
            continue
        target_op = param.outputs()[0]
        if target_op.type() == 'conditional_block':
            for op in param.outputs():
                if op.type() in PRUNE_WORKER._module_dict.keys():
                    cls = PRUNE_WORKER.get(op.type())
                    walker = cls(op,
                                 pruned_params=pruned_params,
                                 visited=visited)
                    break
        else:
            cls = PRUNE_WORKER.get(target_op.type())
            if cls is None:
_logger.info("No walker for operator: {}".format(target_op.type(
    A wrapper of an operator used to infer the information of all the related variables.

    Args:
        op(Operator): The operator to be pruned.
        pruned_params(list): The list used to store the pruning information inferred by the worker.
        visited(dict): The auxiliary dict used to record the visited operators and variables. The key is an encoded string of the operator id and variable name.
        skip_stranger(bool): Whether to raise an exception when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means all unregistered operators are visited by the default worker. Default: True.

    Returns:
        tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to back up the values of parameters. ``param_shape_backup`` is a dict to back up the shapes of parameters.
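
    Examples:
        The following is a minimal sketch of how a pruning worker is looked up in the registry and constructed. It mirrors the lookup performed by `collect_convs` in this module; `target_op` is assumed to be an operator wrapper obtained from a `GraphWrapper`.

        .. code-block:: python

            cls = PRUNE_WORKER.get(target_op.type())
            if cls is not None:
                worker = cls(target_op,
                             pruned_params=[],
                             visited={},
                             skip_stranger=True)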
"""
        self.pruned_list = []
        graph = GraphWrapper(program.clone())
        param_backup = {} if param_backup else None
        param_shape_backup = {} if param_shape_backup else None

        pruned_params = []
        visited = {}
        for param, ratio in zip(params, ratios):
            _logger.info("pruning: {}".format(param))
            if graph.var(param) is None:
                _logger.warn(
                    "Variable[{}] to be pruned is not in current graph.".format(
                        param))
                continue
            group = collect_convs([param], graph,
                                  visited)[0]  # [(name, axis, pruned_idx)]
            if group is None or len(group) == 0:
                continue
            assert (
                not self.pruned_weights), "The weights have been pruned once."