diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py index 72de894a2e4345c32e7a4eee2f35249b77c2f467..dc01846a10feb8bf212f9e35b9cd585df47ba739 100644 --- a/paddleslim/core/graph_wrapper.py +++ b/paddleslim/core/graph_wrapper.py @@ -54,6 +54,9 @@ class VarWrapper(object): """ return self._var.name + def __repr__(self): + return self._var.name + def shape(self): """ Get the shape of the varibale. @@ -131,6 +134,11 @@ class OpWrapper(object): """ return self._op.type + def __repr__(self): + return "op[id: {}, type: {}; inputs: {}]".format(self.idx(), + self.type(), + self.all_inputs()) + def is_bwd_op(self): """ Whether this operator is backward op. diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index 0fdde525a793b90df63f3245ac5215365dd7ccf4..e2b6a7e1d28078abef97c5fa53b215b098f18cca 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -528,33 +528,41 @@ class Pruner(): Returns: list: A list of operators. """ + _logger.debug("######################search: {}######################". + format(op_node)) visited = [op_node.idx()] stack = [] brothers = [] for op in graph.next_ops(op_node): - if (op.type() != 'conv2d') and (op.type() != 'fc') and ( - not op.is_bwd_op()): + if ("conv2d" not in op.type()) and (op.type() != 'fc') and ( + not op.is_bwd_op()) and (not op.is_opt_op()): stack.append(op) visited.append(op.idx()) while len(stack) > 0: top_op = stack.pop() - if top_op.type().startswith("elementwise_"): - for parent in graph.pre_ops(top_op): - if parent.idx() not in visited and ( - not parent.is_bwd_op()): - if ((parent.type() == 'conv2d') or - (parent.type() == 'fc')): - brothers.append(parent) - else: - stack.append(parent) - visited.append(parent.idx()) + for parent in graph.pre_ops(top_op): + if parent.idx() not in visited and ( + not parent.is_bwd_op()) and (not parent.is_opt_op()): + _logger.debug("----------go back from {} to {}----------". + format(top_op, parent)) + if (('conv2d' in parent.type()) or + (parent.type() == 'fc')): + brothers.append(parent) + else: + stack.append(parent) + visited.append(parent.idx()) for child in graph.next_ops(top_op): - if (child.type() != 'conv2d') and (child.type() != 'fc') and ( + if ('conv2d' not in child.type() + ) and (child.type() != 'fc') and ( child.idx() not in visited) and ( - not child.is_bwd_op()): + not child.is_bwd_op()) and (not child.is_opt_op()): stack.append(child) visited.append(child.idx()) + _logger.debug("brothers: {}".format(brothers)) + _logger.debug( + "######################Finish search######################".format( + op_node)) return brothers def _cal_pruned_idx(self, name, param, ratio, axis): diff --git a/tests/test_prune.py b/tests/test_prune.py index 93609367351618ce375f164a1dca284e85369e4c..3fdaa867e350af876648871f83fe70cc83b548b6 100644 --- a/tests/test_prune.py +++ b/tests/test_prune.py @@ -15,7 +15,7 @@ import sys sys.path.append("../") import unittest import paddle.fluid as fluid -from prune import Pruner +from paddleslim.prune import Pruner from layers import conv_bn_layer