@@ -139,7 +139,8 @@ class PruningCollections(object):
...
@@ -139,7 +139,8 @@ class PruningCollections(object):
params,
params,
graph,
graph,
skip_stranger=True,
skip_stranger=True,
skip_vars=None):
skip_vars=None,
skip_leaves=True):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...
@@ -164,7 +165,8 @@ class PruningCollections(object):
...
@@ -164,7 +165,8 @@ class PruningCollections(object):
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in given graph. '[]' means skipping nothing. Default: None.
skip_vars(list<str>): Names of variables that will be skipped. Default: None.
skip_leaves(bool): Whether to skip the last convolution layers.
Returns:
Returns:
list<Group>: The groups.
list<Group>: The groups.
...
@@ -173,12 +175,12 @@ class PruningCollections(object):
...
@@ -173,12 +175,12 @@ class PruningCollections(object):
ifnotisinstance(graph,GraphWrapper):
ifnotisinstance(graph,GraphWrapper):
graph=GraphWrapper(graph)
graph=GraphWrapper(graph)
ifskip_varsisNone:
skip_vars=[]ifskip_varsisNoneelseskip_vars
skip_vars=self._find_leaves(graph)
ifskip_leaves:
leaves=self._find_leaves(graph)
skip_vars.extend(leaves)
_logger.warning(
_logger.warning(
"Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
"Leaves {} will be skipped when parsing graph.".format(leaves))
format(skip_vars))
visited={}
visited={}
collections=[]
collections=[]
unsupported_warnings=set()
unsupported_warnings=set()
...
@@ -234,7 +236,7 @@ class PruningCollections(object):
...
@@ -234,7 +236,7 @@ class PruningCollections(object):