axis],f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_name}; len(mask): {len(mask)}"
axis],f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_detail.name}; len(mask): {len(mask)}"
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...
@@ -153,6 +164,7 @@ class PruningCollections(object):
...
@@ -153,6 +164,7 @@ class PruningCollections(object):
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in given graph. '[]' means skipping nothing. Default: None.
Returns:
Returns:
list<Group>: The groups.
list<Group>: The groups.
...
@@ -160,6 +172,13 @@ class PruningCollections(object):
...
@@ -160,6 +172,13 @@ class PruningCollections(object):
"""
"""
ifnotisinstance(graph,GraphWrapper):
ifnotisinstance(graph,GraphWrapper):
graph=GraphWrapper(graph)
graph=GraphWrapper(graph)
ifskip_varsisNone:
skip_vars=self._find_leaves(graph)
_logger.warning(
"Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
format(skip_vars))
visited={}
visited={}
collections=[]
collections=[]
unsupported_warnings=set()
unsupported_warnings=set()
...
@@ -180,7 +199,6 @@ class PruningCollections(object):
...
@@ -180,7 +199,6 @@ class PruningCollections(object):
pruned_params=pruned_params,
pruned_params=pruned_params,
visited=visited,
visited=visited,
skip_stranger=skip_stranger)
skip_stranger=skip_stranger)
break
else:
else:
cls=PRUNE_WORKER.get(target_op.type())
cls=PRUNE_WORKER.get(target_op.type())
ifclsisNone:
ifclsisNone:
...
@@ -191,6 +209,7 @@ class PruningCollections(object):
...
@@ -191,6 +209,7 @@ class PruningCollections(object):
A wrapper of operator used to infer the information of all the related variables.
A wrapper of operator used to infer the information of all the related variables.
...
@@ -59,6 +64,7 @@ class PruneWorker(object):
...
@@ -59,6 +64,7 @@ class PruneWorker(object):
pruned_params(list): The list to store the information of pruning that infered by worker.
pruned_params(list): The list to store the information of pruning that infered by worker.
visited(dict): The auxiliary dict to record the visited operators and variables. The key is a encoded string of operator id and variable name.
visited(dict): The auxiliary dict to record the visited operators and variables. The key is a encoded string of operator id and variable name.
skip_stranger(bool): Whether to raise exception when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default waorker. Default: True.
skip_stranger(bool): Whether to raise exception when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default waorker. Default: True.
skip_vars(list<str>): The variables in 'skip_vars' and their relatives will be skipped. Default: [].