提交 a3b9d118 编写于 作者: C chenluyan

(1)add ops for prune_walker;(2)fix bugs of pruning mul

上级 241d4c8a
......@@ -550,7 +550,7 @@ class mul(PruneWorker):
self.pruned_params.append((param_var, 0, idx))
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self._prune_op(op, param_var, 0, idx)
@PRUNE_WORKER.register
......@@ -623,3 +623,57 @@ class affine_channel(PruneWorker):
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class roi_align(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(roi_align, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class lod_reset(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(lod_reset, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class gather(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(gather, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册