diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py index 52075c9a47d34723d0f90b8c69b982a610aeb2f7..fa9690f35215fad1f54aa246e83af5119ad4c408 100644 --- a/paddleslim/prune/group_param.py +++ b/paddleslim/prune/group_param.py @@ -19,7 +19,7 @@ from .prune_walker import conv2d as conv2d_walker __all__ = ["collect_convs"] -def collect_convs(params, graph): +def collect_convs(params, graph, visited={}): """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. @@ -52,7 +52,6 @@ def collect_convs(params, graph): graph = GraphWrapper(graph) groups = [] for param in params: - visited = {} pruned_params = [] param = graph.var(param) conv_op = param.outputs()[0] diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index 8169c56ba30cb3ad7c604fce389ef34028a4ae7e..90ff075e23cf96e3e0b23e4728297641df44a3c6 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -90,7 +90,7 @@ class Pruner(): visited = {} pruned_params = [] for param, ratio in zip(params, ratios): - group = collect_convs([param], graph)[0] # [(name, axis)] + group = collect_convs([param], graph, visited)[0] # [(name, axis)] if only_graph and self.idx_selector.__name__ == "default_idx_selector": param_v = graph.var(param) @@ -126,8 +126,9 @@ class Pruner(): pruned_axis]) param = graph.var(param_name) if not lazy: - _logger.debug("{}\t{}\t{}".format(param.name( - ), pruned_axis, len(pruned_idx))) + _logger.debug("{}\t{}\t{}\t{}".format( + param.name(), pruned_axis, + param.shape()[pruned_axis], len(pruned_idx))) if param_shape_backup is not None: origin_shape = copy.deepcopy(param.shape()) param_shape_backup[param.name()] = origin_shape