Unverified commit f3fcd836, authored by whs, committed by GitHub

Fix visiting strategy while pruning network (#235)

Parent 5b3bd165
@@ -19,7 +19,7 @@
 from .prune_walker import conv2d as conv2d_walker
 __all__ = ["collect_convs"]
-def collect_convs(params, graph):
+def collect_convs(params, graph, visited={}):
     """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
     A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
@@ -52,7 +52,6 @@ def collect_convs(params, graph):
     graph = GraphWrapper(graph)
     groups = []
     for param in params:
-        visited = {}
         pruned_params = []
         param = graph.var(param)
         conv_op = param.outputs()[0]
@@ -90,7 +90,7 @@ class Pruner():
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
-            group = collect_convs([param], graph)[0]  # [(name, axis)]
+            group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
             if only_graph and self.idx_selector.__name__ == "default_idx_selector":
                 param_v = graph.var(param)
@@ -126,8 +126,9 @@ class Pruner():
                     pruned_axis])
                 param = graph.var(param_name)
                 if not lazy:
-                    _logger.debug("{}\t{}\t{}".format(param.name(
-                    ), pruned_axis, len(pruned_idx)))
+                    _logger.debug("{}\t{}\t{}\t{}".format(
+                        param.name(), pruned_axis,
+                        param.shape()[pruned_axis], len(pruned_idx)))
                 if param_shape_backup is not None:
                     origin_shape = copy.deepcopy(param.shape())
                     param_shape_backup[param.name()] = origin_shape
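For context, the change shares a single `visited` dict across all parameters handled in one pruning pass, instead of resetting it for every parameter inside `collect_convs`. The sketch below is a minimal illustration of that calling pattern, not code from the commit: it assumes `collect_convs` is importable from `paddleslim.prune` (as suggested by this module's `__all__`), the helper name `group_all` is hypothetical, and `params`/`graph` stand in for whatever the caller actually provides.

from paddleslim.prune import collect_convs  # assumed export, per __all__ above

def group_all(params, graph):
    # Hypothetical helper: collect the pruning group for each parameter name
    # in `params`, sharing one `visited` dict across calls so operators
    # already grouped while walking an earlier parameter are not revisited.
    visited = {}  # shared across the whole pass, not reset per parameter
    groups = []
    for param in params:
        # Each group is a list of (param_name, axis) tuples, per the
        # collect_convs docstring quoted in the diff above.
        groups.append(collect_convs([param], graph, visited)[0])
    return groups

This mirrors how `Pruner.prune` now passes its own `visited` dict into `collect_convs` inside the loop over `zip(params, ratios)`.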