axis],f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_name}; len(mask): {len(mask)}"
axis],f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_detail.name}; len(mask): {len(mask)}"
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...
...
@@ -153,6 +164,7 @@ class PruningCollections(object):
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in given graph. '[]' means skipping nothing. Default: None.
Returns:
list<Group>: The groups.
...
...
@@ -160,6 +172,13 @@ class PruningCollections(object):
"""
ifnotisinstance(graph,GraphWrapper):
graph=GraphWrapper(graph)
ifskip_varsisNone:
skip_vars=self._find_leaves(graph)
_logger.warning(
"Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
format(skip_vars))
visited={}
collections=[]
unsupported_warnings=set()
...
...
@@ -180,7 +199,6 @@ class PruningCollections(object):
pruned_params=pruned_params,
visited=visited,
skip_stranger=skip_stranger)
break
else:
cls=PRUNE_WORKER.get(target_op.type())
ifclsisNone:
...
...
@@ -191,6 +209,7 @@ class PruningCollections(object):
A wrapper of operator used to infer the information of all the related variables.
...
...
@@ -59,6 +64,7 @@ class PruneWorker(object):
pruned_params(list): The list to store the information of pruning that infered by worker.
visited(dict): The auxiliary dict to record the visited operators and variables. The key is a encoded string of operator id and variable name.
skip_stranger(bool): Whether to raise exception when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default waorker. Default: True.
skip_vars(list<str>): The variables in 'skip_vars' and their relatives will be skipped. Default: [].