未验证 提交 f3fcd836 编写于 作者: W whs 提交者: GitHub

Fix visiting strategy while pruning network (#235)

上级 5b3bd165
...@@ -19,7 +19,7 @@ from .prune_walker import conv2d as conv2d_walker ...@@ -19,7 +19,7 @@ from .prune_walker import conv2d as conv2d_walker
__all__ = ["collect_convs"] __all__ = ["collect_convs"]
def collect_convs(params, graph): def collect_convs(params, graph, visited={}):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...@@ -52,7 +52,6 @@ def collect_convs(params, graph): ...@@ -52,7 +52,6 @@ def collect_convs(params, graph):
graph = GraphWrapper(graph) graph = GraphWrapper(graph)
groups = [] groups = []
for param in params: for param in params:
visited = {}
pruned_params = [] pruned_params = []
param = graph.var(param) param = graph.var(param)
conv_op = param.outputs()[0] conv_op = param.outputs()[0]
......
...@@ -90,7 +90,7 @@ class Pruner(): ...@@ -90,7 +90,7 @@ class Pruner():
visited = {} visited = {}
pruned_params = [] pruned_params = []
for param, ratio in zip(params, ratios): for param, ratio in zip(params, ratios):
group = collect_convs([param], graph)[0] # [(name, axis)] group = collect_convs([param], graph, visited)[0] # [(name, axis)]
if only_graph and self.idx_selector.__name__ == "default_idx_selector": if only_graph and self.idx_selector.__name__ == "default_idx_selector":
param_v = graph.var(param) param_v = graph.var(param)
...@@ -126,8 +126,9 @@ class Pruner(): ...@@ -126,8 +126,9 @@ class Pruner():
pruned_axis]) pruned_axis])
param = graph.var(param_name) param = graph.var(param_name)
if not lazy: if not lazy:
_logger.debug("{}\t{}\t{}".format(param.name( _logger.debug("{}\t{}\t{}\t{}".format(
), pruned_axis, len(pruned_idx))) param.name(), pruned_axis,
param.shape()[pruned_axis], len(pruned_idx)))
if param_shape_backup is not None: if param_shape_backup is not None:
origin_shape = copy.deepcopy(param.shape()) origin_shape = copy.deepcopy(param.shape())
param_shape_backup[param.name()] = origin_shape param_shape_backup[param.name()] = origin_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册