#NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in the Adam optimizer. Its shape is [1], so there is no need to prune this one-value variable.
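A minimal sketch of that check, assuming a traversal loop over an operator's input variables (the loop and the `op.all_inputs()` accessor are assumptions about the surrounding code, not shown in this diff):

    # Sketch only: `op` and `all_inputs()` are assumed context, not part of this diff.
    for var_tmp in op.all_inputs():
        if var_tmp.shape == [1]:
            # One-value variables such as `beta1_pow_acc` in the Adam
            # optimizer carry optimizer state, not prunable weights.
            continue
        # ... continue collecting prunable parameters ...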
@@ -139,7 +139,8 @@ class PruningCollections(object):
                  params,
                  graph,
                  skip_stranger=True,
-                 skip_vars=None):
+                 skip_vars=None,
+                 skip_leaves=True):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...
...
@@ -164,7 +165,8 @@ class PruningCollections(object):
             params(list): A list of convolution layers' parameter names. All groups that contain any one of these parameters will be collected.
             graph(paddle.static.Program | GraphWrapper): The graph used to search for the groups.
             skip_stranger(bool): Whether to skip the current tensor when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True.
-            skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in the given graph. '[]' means skipping nothing. Default: None.
+            skip_vars(list<str>): Names of variables that will be skipped. Default: None.
+            skip_leaves(bool): Whether to skip the last convolution layers. Default: True.
         Returns:
             list<Group>: The groups.
...
...
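For illustration, a hypothetical call against the documented interface. The import path, the callable form, and the toy program are assumptions made for this sketch; only the argument semantics come from the docstring above.

    # Hypothetical usage; the import path and call form are assumptions.
    import paddle
    from paddleslim.prune.collections import PruningCollections

    paddle.enable_static()
    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        x = paddle.static.data(name="x", shape=[1, 3, 16, 16], dtype="float32")
        y = paddle.static.nn.conv2d(x, num_filters=8, filter_size=3)

    groups = PruningCollections()(
        params=["conv2d_0.w_0"],  # collect every group containing this parameter
        graph=main_prog,          # a paddle.static.Program or GraphWrapper
        skip_stranger=True,       # skip tensors behind unregistered operators
        skip_vars=None,           # no extra variables skipped by name
        skip_leaves=True)         # also skip the last convolution layers
    for group in groups:
        for param_name, axis in group:
            print(param_name, axis)  # the axis `param_name` would be pruned on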
@@ -173,12 +175,12 @@ class PruningCollections(object):
         if not isinstance(graph, GraphWrapper):
             graph = GraphWrapper(graph)
-        if skip_vars is None:
-            skip_vars = self._find_leaves(graph)
+        skip_vars = [] if skip_vars is None else skip_vars
+        if skip_leaves:
+            leaves = self._find_leaves(graph)
+            skip_vars.extend(leaves)
             _logger.warning(
-                "Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
-                format(skip_vars))
+                "Leaves {} will be skipped when parsing graph.".format(leaves))
         visited = {}
         collections = []
         unsupported_warnings = set()
...
...
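The docstring line removed above documented the old convention: passing skip_vars=[] was the way to keep leaves in the search. After this change that intent is expressed by the new flag instead. A before/after sketch, with the same assumed call form as above:

    # Old convention (removed): an empty list meant "skip nothing".
    #     groups = PruningCollections()(params, graph, skip_vars=[])
    # New convention: leaf skipping is an explicit, independent flag.
    groups = PruningCollections()(params, graph, skip_vars=None, skip_leaves=False)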
@@ -234,7 +236,7 @@ class PruningCollections(object):