提交 53977ef4 编写于 作者: W wanghaoshuang

Fix backward searching in pruner.

上级 a8ffa837
......@@ -54,6 +54,9 @@ class VarWrapper(object):
"""
return self._var.name
def __repr__(self):
return self._var.name
def shape(self):
"""
Get the shape of the varibale.
......@@ -131,6 +134,11 @@ class OpWrapper(object):
"""
return self._op.type
def __repr__(self):
return "op[id: {}, type: {}; inputs: {}]".format(self.idx(),
self.type(),
self.all_inputs())
def is_bwd_op(self):
"""
Whether this operator is backward op.
......
......@@ -528,33 +528,41 @@ class Pruner():
Returns:
list<VarWrapper>: A list of operators.
"""
_logger.debug("######################search: {}######################".
format(op_node))
visited = [op_node.idx()]
stack = []
brothers = []
for op in graph.next_ops(op_node):
if (op.type() != 'conv2d') and (op.type() != 'fc') and (
not op.is_bwd_op()):
if ("conv2d" not in op.type()) and (op.type() != 'fc') and (
not op.is_bwd_op()) and (not op.is_opt_op()):
stack.append(op)
visited.append(op.idx())
while len(stack) > 0:
top_op = stack.pop()
if top_op.type().startswith("elementwise_"):
for parent in graph.pre_ops(top_op):
if parent.idx() not in visited and (
not parent.is_bwd_op()):
if ((parent.type() == 'conv2d') or
(parent.type() == 'fc')):
brothers.append(parent)
else:
stack.append(parent)
visited.append(parent.idx())
for parent in graph.pre_ops(top_op):
if parent.idx() not in visited and (
not parent.is_bwd_op()) and (not parent.is_opt_op()):
_logger.debug("----------go back from {} to {}----------".
format(top_op, parent))
if (('conv2d' in parent.type()) or
(parent.type() == 'fc')):
brothers.append(parent)
else:
stack.append(parent)
visited.append(parent.idx())
for child in graph.next_ops(top_op):
if (child.type() != 'conv2d') and (child.type() != 'fc') and (
if ('conv2d' not in child.type()
) and (child.type() != 'fc') and (
child.idx() not in visited) and (
not child.is_bwd_op()):
not child.is_bwd_op()) and (not child.is_opt_op()):
stack.append(child)
visited.append(child.idx())
_logger.debug("brothers: {}".format(brothers))
_logger.debug(
"######################Finish search######################".format(
op_node))
return brothers
def _cal_pruned_idx(self, name, param, ratio, axis):
......
......@@ -15,7 +15,7 @@ import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from prune import Pruner
from paddleslim.prune import Pruner
from layers import conv_bn_layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册