未验证 提交 8eacc16b 编写于 作者: W whs 提交者: GitHub

Skip leaves in graph when pruning (#734)

上级 0baa9b36
...@@ -258,7 +258,7 @@ class FilterPruner(Pruner): ...@@ -258,7 +258,7 @@ class FilterPruner(Pruner):
continue continue
if baseline is None: if baseline is None:
baseline = eval_func() baseline = eval_func()
plan = self.prune_var(var_name, dims, ratio, apply="lazy") plan = self.prune_var(var_name, dims, ratio)
pruned_metric = eval_func() pruned_metric = eval_func()
loss = (baseline - pruned_metric) / baseline loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format( _logger.info("pruned param: {}; {}; loss={}".format(
...@@ -333,13 +333,14 @@ class FilterPruner(Pruner): ...@@ -333,13 +333,14 @@ class FilterPruner(Pruner):
src_mask = copy.deepcopy(mask) src_mask = copy.deepcopy(mask)
var_shape = _detail.var.shape() var_shape = _detail.var.shape()
for tran in _detail.transform: for tran in _detail.transform:
src_mask = self._transform_mask(src_mask, tran) src_mask = self._transform_mask(src_mask, tran)
current_mask = src_mask current_mask = src_mask
groups = _detail.op.attr('groups') groups = _detail.op.attr('groups')
if groups is None or groups == 1: if groups is None or groups == 1:
assert len(current_mask) == var_shape[ assert len(current_mask) == var_shape[
_detail. _detail.
axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_name}; len(mask): {len(mask)}" axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_detail.name}; len(mask): {len(mask)}"
plan.add(_detail.name, plan.add(_detail.name,
PruningMask(_detail.axis, current_mask, pruned_ratio, PruningMask(_detail.axis, current_mask, pruned_ratio,
_detail.op)) _detail.op))
......
...@@ -128,7 +128,18 @@ class PruningCollections(object): ...@@ -128,7 +128,18 @@ class PruningCollections(object):
def __iter__(self): def __iter__(self):
return iter(self._collections) return iter(self._collections)
def create_pruning_collections(self, params, graph, skip_stranger=True): def _find_leaves(self, graph):
ret = []
for _var in graph.vars():
if len(_var.outputs()) == 0:
ret.append(_var.name())
return ret
def create_pruning_collections(self,
params,
graph,
skip_stranger=True,
skip_vars=None):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuples with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. A group is a list of tuples with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...@@ -153,6 +164,7 @@ class PruningCollections(object): ...@@ -153,6 +164,7 @@ class PruningCollections(object):
params(list): A list of convolution layer's parameter names. It will collect all the groups that contain any one of these parameters. params(list): A list of convolution layer's parameter names. It will collect all the groups that contain any one of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups. graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip the current tensor when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True. skip_stranger(bool): Whether to skip the current tensor when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True.
skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in given graph. '[]' means skipping nothing. Default: None.
Returns: Returns:
list<Group>: The groups. list<Group>: The groups.
...@@ -160,6 +172,13 @@ class PruningCollections(object): ...@@ -160,6 +172,13 @@ class PruningCollections(object):
""" """
if not isinstance(graph, GraphWrapper): if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph) graph = GraphWrapper(graph)
if skip_vars is None:
skip_vars = self._find_leaves(graph)
_logger.warning(
"Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
format(skip_vars))
visited = {} visited = {}
collections = [] collections = []
unsupported_warnings = set() unsupported_warnings = set()
...@@ -180,7 +199,6 @@ class PruningCollections(object): ...@@ -180,7 +199,6 @@ class PruningCollections(object):
pruned_params=pruned_params, pruned_params=pruned_params,
visited=visited, visited=visited,
skip_stranger=skip_stranger) skip_stranger=skip_stranger)
break
else: else:
cls = PRUNE_WORKER.get(target_op.type()) cls = PRUNE_WORKER.get(target_op.type())
if cls is None: if cls is None:
...@@ -191,6 +209,7 @@ class PruningCollections(object): ...@@ -191,6 +209,7 @@ class PruningCollections(object):
pruned_params=pruned_params, pruned_params=pruned_params,
visited=visited, visited=visited,
skip_stranger=skip_stranger) skip_stranger=skip_stranger)
worker.skip_vars = skip_vars
try: try:
visited_backup = copy.deepcopy(worker.visited) visited_backup = copy.deepcopy(worker.visited)
worker.prune(param, pruned_axis=0, pruned_idx=[]) worker.prune(param, pruned_axis=0, pruned_idx=[])
......
...@@ -50,7 +50,12 @@ class UnsupportOpError(Exception): ...@@ -50,7 +50,12 @@ class UnsupportOpError(Exception):
class PruneWorker(object): class PruneWorker(object):
def __init__(self, op, pruned_params, visited, skip_stranger=True): def __init__(self,
op,
pruned_params,
visited,
skip_stranger=True,
skip_vars=[]):
""" """
A wrapper of operator used to infer the information of all the related variables. A wrapper of operator used to infer the information of all the related variables.
...@@ -59,6 +64,7 @@ class PruneWorker(object): ...@@ -59,6 +64,7 @@ class PruneWorker(object):
pruned_params(list): The list to store the information of pruning that is inferred by the worker. pruned_params(list): The list to store the information of pruning that is inferred by the worker.
visited(dict): The auxiliary dict to record the visited operators and variables. The key is an encoded string of operator id and variable name. visited(dict): The auxiliary dict to record the visited operators and variables. The key is an encoded string of operator id and variable name.
skip_stranger(bool): Whether to raise an exception when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True. skip_stranger(bool): Whether to raise an exception when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True.
skip_vars(list<str>): The variables in 'skip_vars' and their relatives will be skipped. Default: [].
Return: An instance of PruneWorker. Return: An instance of PruneWorker.
""" """
...@@ -69,6 +75,7 @@ class PruneWorker(object): ...@@ -69,6 +75,7 @@ class PruneWorker(object):
self.ops_unsupported = os.getenv('OPS_UNSUPPORTED', None) self.ops_unsupported = os.getenv('OPS_UNSUPPORTED', None)
self.ops_unsupported = [] if self.ops_unsupported is None else self.ops_unsupported.strip( self.ops_unsupported = [] if self.ops_unsupported is None else self.ops_unsupported.strip(
).split(",") ).split(",")
self.skip_vars = skip_vars
def prune(self, var, pruned_axis, pruned_idx): def prune(self, var, pruned_axis, pruned_idx):
""" """
...@@ -79,6 +86,9 @@ class PruneWorker(object): ...@@ -79,6 +86,9 @@ class PruneWorker(object):
pruned_axis(int): The axis to be pruned of root variable. pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indices to be pruned in `pruned_axis` of root variable. pruned_idx(int): The indices to be pruned in `pruned_axis` of root variable.
""" """
if var.name() in self.skip_vars:
raise UnsupportOpError("Variable {} was skipped.".format(var.name(
)))
if self._visit(var, pruned_axis): if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx) self._prune(var, pruned_axis, pruned_idx)
...@@ -130,6 +140,7 @@ class PruneWorker(object): ...@@ -130,6 +140,7 @@ class PruneWorker(object):
f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n" f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
) )
worker = cls(op, self.pruned_params, self.visited, self.skip_stranger) worker = cls(op, self.pruned_params, self.visited, self.skip_stranger)
worker.skip_vars = self.skip_vars
worker.prune(var, pruned_axis, pruned_idx) worker.prune(var, pruned_axis, pruned_idx)
def append_pruned_vars(self, var, axis, transforms): def append_pruned_vars(self, var, axis, transforms):
...@@ -598,7 +609,7 @@ class depthwise_conv2d(PruneWorker): ...@@ -598,7 +609,7 @@ class depthwise_conv2d(PruneWorker):
_filter = self.op.inputs("Filter")[0] _filter = self.op.inputs("Filter")[0]
_out = self.op.outputs("Output")[0] _out = self.op.outputs("Output")[0]
_in_var = self.op.inputs("Input")[0] _in_var = self.op.inputs("Input")[0]
_groups = self.op.attr("groups")
data_format = self.op.attr("data_format") data_format = self.op.attr("data_format")
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
...@@ -608,15 +619,23 @@ class depthwise_conv2d(PruneWorker): ...@@ -608,15 +619,23 @@ class depthwise_conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis) pruned_axis)
# pruning number of filters # pruning number of filters
self.append_pruned_vars(_filter, 0, transforms) assert (_filter.shape()[0] % _groups == 0)
stride = _filter.shape()[0] / _groups
self.append_pruned_vars(_filter, 0, transforms + [{
"stride": stride
}])
# kernel_number * groups will be pruned by reducing groups # kernel_number * groups will be pruned by reducing groups
self.append_pruned_vars(_filter, 1, transforms) self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_filter, 0, transforms) self._visit_and_search(_filter, 0, transforms + [{
"stride": stride
}])
# It will not prune the number of kernels in depthwise conv2d, # It will not prune the number of kernels in depthwise conv2d,
# so it is not necessary to search the succeeding operators. # so it is not necessary to search the succeeding operators.
# self._visit_and_search(_filter, 1, transforms) # self._visit_and_search(_filter, 1, transforms)
self._visit(_filter, 1) self._visit(_filter, 1)
self._visit_and_search(_out, channel_axis, transforms) self._visit_and_search(_out, channel_axis, transforms + [{
"stride": stride
}])
elif var == _filter: elif var == _filter:
assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0." assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
self.append_pruned_vars(_filter, 1, transforms) self.append_pruned_vars(_filter, 1, transforms)
......
...@@ -169,6 +169,8 @@ class Pruner(): ...@@ -169,6 +169,8 @@ class Pruner():
for name, axis, pruned_idx, transforms in items: for name, axis, pruned_idx, transforms in items:
src = pruned_idx src = pruned_idx
for trans in transforms: for trans in transforms:
if 'src_start' not in trans:
continue
src_start = trans['src_start'] src_start = trans['src_start']
src_end = trans['src_end'] src_end = trans['src_end']
src_len = src_end - src_start src_len = src_end - src_start
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册