提交 53977ef4 编写于 作者: W wanghaoshuang

Fix backward searching in pruner.

上级 a8ffa837
...@@ -54,6 +54,9 @@ class VarWrapper(object): ...@@ -54,6 +54,9 @@ class VarWrapper(object):
""" """
return self._var.name return self._var.name
def __repr__(self):
return self._var.name
def shape(self): def shape(self):
""" """
Get the shape of the varibale. Get the shape of the varibale.
...@@ -131,6 +134,11 @@ class OpWrapper(object): ...@@ -131,6 +134,11 @@ class OpWrapper(object):
""" """
return self._op.type return self._op.type
def __repr__(self):
return "op[id: {}, type: {}; inputs: {}]".format(self.idx(),
self.type(),
self.all_inputs())
def is_bwd_op(self): def is_bwd_op(self):
""" """
Whether this operator is backward op. Whether this operator is backward op.
......
...@@ -528,33 +528,41 @@ class Pruner(): ...@@ -528,33 +528,41 @@ class Pruner():
Returns: Returns:
list<VarWrapper>: A list of operators. list<VarWrapper>: A list of operators.
""" """
_logger.debug("######################search: {}######################".
format(op_node))
visited = [op_node.idx()] visited = [op_node.idx()]
stack = [] stack = []
brothers = [] brothers = []
for op in graph.next_ops(op_node): for op in graph.next_ops(op_node):
if (op.type() != 'conv2d') and (op.type() != 'fc') and ( if ("conv2d" not in op.type()) and (op.type() != 'fc') and (
not op.is_bwd_op()): not op.is_bwd_op()) and (not op.is_opt_op()):
stack.append(op) stack.append(op)
visited.append(op.idx()) visited.append(op.idx())
while len(stack) > 0: while len(stack) > 0:
top_op = stack.pop() top_op = stack.pop()
if top_op.type().startswith("elementwise_"): for parent in graph.pre_ops(top_op):
for parent in graph.pre_ops(top_op): if parent.idx() not in visited and (
if parent.idx() not in visited and ( not parent.is_bwd_op()) and (not parent.is_opt_op()):
not parent.is_bwd_op()): _logger.debug("----------go back from {} to {}----------".
if ((parent.type() == 'conv2d') or format(top_op, parent))
(parent.type() == 'fc')): if (('conv2d' in parent.type()) or
brothers.append(parent) (parent.type() == 'fc')):
else: brothers.append(parent)
stack.append(parent) else:
visited.append(parent.idx()) stack.append(parent)
visited.append(parent.idx())
for child in graph.next_ops(top_op): for child in graph.next_ops(top_op):
if (child.type() != 'conv2d') and (child.type() != 'fc') and ( if ('conv2d' not in child.type()
) and (child.type() != 'fc') and (
child.idx() not in visited) and ( child.idx() not in visited) and (
not child.is_bwd_op()): not child.is_bwd_op()) and (not child.is_opt_op()):
stack.append(child) stack.append(child)
visited.append(child.idx()) visited.append(child.idx())
_logger.debug("brothers: {}".format(brothers))
_logger.debug(
"######################Finish search######################".format(
op_node))
return brothers return brothers
def _cal_pruned_idx(self, name, param, ratio, axis): def _cal_pruned_idx(self, name, param, ratio, axis):
......
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from prune import Pruner from paddleslim.prune import Pruner
from layers import conv_bn_layer from layers import conv_bn_layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册