Unverified · Commit 8eacc16b · authored by whs, committed by GitHub

Skip leaves in graph when pruning (#734)

Parent 0baa9b36
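Reviewer note: this change makes the pruning-collection parser skip leaf variables (variables consumed by no operator, e.g. fetch targets) by default, and adds a `skip_vars` argument to control the behavior. The sketch below is a self-contained illustration of the leaf rule used by the new `_find_leaves` helper; the toy `Var` class is an illustrative stand-in, not PaddleSlim's graph API.

```python
# Toy stand-in for the graph API; only name() and outputs() matter here.
class Var:
    def __init__(self, name, consumers):
        self._name = name
        self._consumers = consumers

    def name(self):
        return self._name

    def outputs(self):
        # Operators that consume this variable.
        return self._consumers


def find_leaves(graph_vars):
    # A variable with no consuming operator is a leaf and gets skipped.
    return [v.name() for v in graph_vars if len(v.outputs()) == 0]


print(find_leaves([Var("conv1.w", ["conv2d_0"]), Var("fetch_out", [])]))
# ['fetch_out']
```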
@@ -258,7 +258,7 @@ class FilterPruner(Pruner):
                     continue
                 if baseline is None:
                     baseline = eval_func()
-                plan = self.prune_var(var_name, dims, ratio, apply="lazy")
+                plan = self.prune_var(var_name, dims, ratio)
                 pruned_metric = eval_func()
                 loss = (baseline - pruned_metric) / baseline
                 _logger.info("pruned param: {}; {}; loss={}".format(
@@ -333,13 +333,14 @@ class FilterPruner(Pruner):
             src_mask = copy.deepcopy(mask)
             var_shape = _detail.var.shape()
             for tran in _detail.transform:
                 src_mask = self._transform_mask(src_mask, tran)
             current_mask = src_mask
-            assert len(current_mask) == var_shape[
-                _detail.
-                axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_name}; len(mask): {len(mask)}"
+            groups = _detail.op.attr('groups')
+            if groups is None or groups == 1:
+                assert len(current_mask) == var_shape[
+                    _detail.
+                    axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_detail.name}; len(mask): {len(mask)}"
             plan.add(_detail.name,
                      PruningMask(_detail.axis, current_mask, pruned_ratio,
                                  _detail.op))
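Reviewer note on the two hunks above: the first hunk drops the `apply="lazy"` argument from the sensitivity loop, which prunes one variable at a ratio, re-evaluates, and records the relative metric drop; the second guards the mask-length assert because, for grouped convolutions, a stride-transformed mask no longer matches the pruned dimension one-to-one. A minimal sketch of the sensitivity measure, with made-up metric values (the real `eval_func` returns a task metric such as accuracy):

```python
def sensitivity_loss(baseline, pruned_metric):
    # Relative drop of the evaluation metric caused by pruning one variable.
    return (baseline - pruned_metric) / baseline

# E.g. accuracy falling from 0.90 to 0.81 is a 10% relative loss.
assert abs(sensitivity_loss(0.90, 0.81) - 0.10) < 1e-9
```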
@@ -128,7 +128,18 @@ class PruningCollections(object):
     def __iter__(self):
         return iter(self._collections)

-    def create_pruning_collections(self, params, graph, skip_stranger=True):
+    def _find_leaves(self, graph):
+        ret = []
+        for _var in graph.vars():
+            if len(_var.outputs()) == 0:
+                ret.append(_var.name())
+        return ret
+
+    def create_pruning_collections(self,
+                                   params,
+                                   graph,
+                                   skip_stranger=True,
+                                   skip_vars=None):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
@@ -153,6 +164,7 @@ class PruningCollections(object):
             params(list): A list of convolution layers' parameter names. All groups containing any of these parameters will be collected.
             graph(paddle.static.Program | GraphWrapper): The graph used to search for the groups.
             skip_stranger(bool): Whether to skip the current tensor when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means unregistered operators are visited by the default worker. Default: True.
+            skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in the given graph; '[]' means skipping nothing. Default: None.

         Returns:
             list<Group>: The groups.
@@ -160,6 +172,13 @@ class PruningCollections(object):
         """
         if not isinstance(graph, GraphWrapper):
             graph = GraphWrapper(graph)
+
+        if skip_vars is None:
+            skip_vars = self._find_leaves(graph)
+            _logger.warning(
+                "Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
+                format(skip_vars))
+
         visited = {}
         collections = []
         unsupported_warnings = set()
@@ -180,7 +199,6 @@ class PruningCollections(object):
                              pruned_params=pruned_params,
                              visited=visited,
                              skip_stranger=skip_stranger)
-                break
             else:
                 cls = PRUNE_WORKER.get(target_op.type())
                 if cls is None:
@@ -191,6 +209,7 @@ class PruningCollections(object):
                          pruned_params=pruned_params,
                          visited=visited,
                          skip_stranger=skip_stranger)
+            worker.skip_vars = skip_vars
             try:
                 visited_backup = copy.deepcopy(worker.visited)
                 worker.prune(param, pruned_axis=0, pruned_idx=[])
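Reviewer note: the collection builder now forwards `skip_vars` to each `PruneWorker`, and a worker aborts a traversal branch by raising `UnsupportOpError` when it reaches a skipped variable (see the guard added to `prune()` below); the `try` block above already catches that. A minimal mirror of the control flow, independent of PaddleSlim:

```python
class UnsupportOpError(Exception):
    pass


def prune(var_name, skip_vars):
    # Mirrors the new guard at the top of PruneWorker.prune().
    if var_name in skip_vars:
        raise UnsupportOpError("Variable {} was skipped.".format(var_name))
    # ... traversal of the variable's operators would continue here ...


try:
    prune("fetch_out", skip_vars=["fetch_out"])
except UnsupportOpError as e:
    print(e)  # Variable fetch_out was skipped.
```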
@@ -50,7 +50,12 @@ class UnsupportOpError(Exception):

 class PruneWorker(object):
-    def __init__(self, op, pruned_params, visited, skip_stranger=True):
+    def __init__(self,
+                 op,
+                 pruned_params,
+                 visited,
+                 skip_stranger=True,
+                 skip_vars=[]):
         """
         A wrapper of an operator used to infer the pruning information of all the related variables.
@@ -59,6 +64,7 @@ class PruneWorker(object):
             pruned_params(list): The list that stores the pruning information inferred by the worker.
             visited(dict): An auxiliary dict that records the visited operators and variables. Each key is an encoded string of an operator id and a variable name.
             skip_stranger(bool): Whether to raise an exception when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means unregistered operators are visited by the default worker. Default: True.
+            skip_vars(list<str>): The variables in 'skip_vars' and their relatives will be skipped. Default: [].

         Returns: An instance of PruneWorker.
         """
@@ -69,6 +75,7 @@ class PruneWorker(object):
         self.ops_unsupported = os.getenv('OPS_UNSUPPORTED', None)
         self.ops_unsupported = [] if self.ops_unsupported is None else self.ops_unsupported.strip(
         ).split(",")
+        self.skip_vars = skip_vars
     def prune(self, var, pruned_axis, pruned_idx):
         """
@@ -79,6 +86,9 @@ class PruneWorker(object):
             pruned_axis(int): The axis of the root variable to be pruned.
             pruned_idx(int): The indices to be pruned in `pruned_axis` of the root variable.
         """
+        if var.name() in self.skip_vars:
+            raise UnsupportOpError("Variable {} was skipped.".format(var.name(
+            )))
         if self._visit(var, pruned_axis):
             self._prune(var, pruned_axis, pruned_idx)
@@ -130,6 +140,7 @@ class PruneWorker(object):
             f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
         )
         worker = cls(op, self.pruned_params, self.visited, self.skip_stranger)
+        worker.skip_vars = self.skip_vars
         worker.prune(var, pruned_axis, pruned_idx)

     def append_pruned_vars(self, var, axis, transforms):
@@ -598,7 +609,7 @@ class depthwise_conv2d(PruneWorker):
         _filter = self.op.inputs("Filter")[0]
         _out = self.op.outputs("Output")[0]
         _in_var = self.op.inputs("Input")[0]
+        _groups = self.op.attr("groups")
         data_format = self.op.attr("data_format")
         channel_axis = 1
         if data_format == "NHWC":
@@ -608,15 +619,23 @@ class depthwise_conv2d(PruneWorker):
             assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
                 pruned_axis)
             # pruning number of filters
-            self.append_pruned_vars(_filter, 0, transforms)
+            assert (_filter.shape()[0] % _groups == 0)
+            stride = _filter.shape()[0] / _groups
+            self.append_pruned_vars(_filter, 0, transforms + [{
+                "stride": stride
+            }])
             # kernel_number * groups will be pruned by reducing groups
             self.append_pruned_vars(_filter, 1, transforms)
-            self._visit_and_search(_filter, 0, transforms)
+            self._visit_and_search(_filter, 0, transforms + [{
+                "stride": stride
+            }])
+            # The number of kernels will not be pruned in depthwise conv2d,
+            # so it is not necessary to search the succeeding operators.
+            # self._visit_and_search(_filter, 1, transforms)
+            self._visit(_filter, 1)
-            self._visit_and_search(_out, channel_axis, transforms)
+            self._visit_and_search(_out, channel_axis, transforms + [{
+                "stride": stride
+            }])
         elif var == _filter:
             assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
             self.append_pruned_vars(_filter, 1, transforms)
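Reviewer note: in the depthwise_conv2d worker above, pruning input channel i must remove stride = _filter.shape()[0] / _groups consecutive filters, which is what the appended {"stride": stride} transform expresses. The sketch below shows the intended index expansion; it illustrates the intent, not PaddleSlim's actual `_transform_mask` implementation.

```python
def expand_with_stride(pruned_idx, stride):
    # Each pruned input channel i maps to filter rows [i*s, (i+1)*s).
    s = int(stride)
    expanded = []
    for i in pruned_idx:
        expanded.extend(range(i * s, (i + 1) * s))
    return expanded


# A depthwise conv with 8 filters and groups=4 has stride 8 / 4 = 2,
# so pruning input channel 1 removes filter rows 2 and 3.
assert expand_with_stride([1], stride=8 / 4) == [2, 3]
```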
@@ -169,6 +169,8 @@ class Pruner():
         for name, axis, pruned_idx, transforms in items:
             src = pruned_idx
             for trans in transforms:
+                if 'src_start' not in trans:
+                    continue
                 src_start = trans['src_start']
                 src_end = trans['src_end']
                 src_len = src_end - src_start
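Reviewer note: since stride transforms carry no 'src_start'/'src_end' keys, the index-remapping loop in Pruner now skips them instead of raising KeyError. A simplified sketch of the guarded loop (the window arithmetic is reduced to a plain shift here; the real mapping also handles repeated target windows):

```python
def map_indices(pruned_idx, transforms):
    src = pruned_idx
    for trans in transforms:
        if 'src_start' not in trans:
            continue  # e.g. {"stride": 2} entries are ignored here
        src_start = trans['src_start']
        src_end = trans['src_end']
        target_start = trans.get('target_start', 0)
        # Keep indices inside the source window, shifted to the target.
        src = [i - src_start + target_start
               for i in src if src_start <= i < src_end]
    return src


assert map_indices([3, 5], [{"stride": 2},
                            {"src_start": 2, "src_end": 6}]) == [1, 3]
```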