Unverified · Commit 524ac561, authored by whs, committed by GitHub

Fix skipping leaves option (#837) (#839)

* Fix skipping leaves option
* Add unit tests
Parent fceedb12
@@ -105,6 +105,9 @@ class PruneWorker(object):
     def _visit_and_search(self, var, axis, transforms):
         self._visit(var, axis)
+        if var.name() in self.skip_vars:
+            raise UnsupportOpError("Variable {} was skipped.".format(var.name(
+            )))
         pre_ops = var.inputs()
         for op in pre_ops:
             self._prune_op(op, var, axis, transforms)
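This guard is the heart of the fix: `_visit_and_search` now aborts with `UnsupportOpError` the moment a walk reaches a variable listed in `skip_vars`, instead of silently continuing past a leaf the caller asked to keep. A minimal sketch of how a caller observes the new behavior, modeled on the `TestSum` case added below (the constructor arguments and variable names follow that test and are illustrative, not a stable API):

    cls = PRUNE_WORKER.get("sum")        # look up the registered worker
    worker = cls(op, [], {}, True)       # op: the GraphWrapper op of the sum
    worker.skip_vars = [out_var.name()]  # declare the output a skipped leaf
    try:
        worker.prune(in_var, 0, [])      # the walk reaches out_var and aborts
    except UnsupportOpError as e:
        print(e)                         # "Variable ... was skipped."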
@@ -123,7 +126,6 @@ class PruneWorker(object):
         if op.type() in self.ops_unsupported:
             raise UnsupportOpError("Unsupported operator named {}".format(
                 op.type()))
-
         cls = PRUNE_WORKER.get(op.type())
         if cls is None:
             if op.type() in SKIPPED_OPS:
@@ -214,10 +216,7 @@ class conv2d(PruneWorker):
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 0)
             self.append_pruned_vars(filter_var, 0, pruned_idx)
-            for op in filter_var.outputs():
-                self._prune_op(op, filter_var, 0, pruned_idx)
+            self._visit_and_search(filter_var, 0, pruned_idx)

             if len(self.op.inputs("Bias")) > 0:
                 self.append_pruned_vars(
                     self.op.inputs("Bias")[0], channel_axis, pruned_idx)
@@ -240,8 +239,7 @@ class conv2d_transpose(PruneWorker):
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 0)
             self.append_pruned_vars(filter_var, 0, pruned_idx)
-            for op in filter_var.outputs():
-                self._prune_op(op, filter_var, 0, pruned_idx)
+            self._visit_and_search(filter_var, 0, pruned_idx)
         elif var in self.op.inputs("Filter"):
             _logger.warn("Skip pruning output channels of conv2d_transpose!")
@@ -252,20 +250,15 @@ class conv2d_transpose(PruneWorker):
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 1)
             self.append_pruned_vars(filter_var, 1, pruned_idx)
-            for op in filter_var.outputs():
-                self._prune_op(op, filter_var, 1, pruned_idx)
+            self._visit_and_search(filter_var, 1, pruned_idx)

             if len(self.op.inputs("Bias")) > 0:
                 self.append_pruned_vars(
                     self.op.inputs("Bias")[0], channel_axis, pruned_idx)

             output_var = self.op.outputs("Output")[0]
-            next_ops = output_var.outputs()
-            for op in next_ops:
-                self._prune_op(op, output_var, channel_axis, pruned_idx)
+            self._visit_and_search(output_var, channel_axis, pruned_idx)


 @PRUNE_WORKER.register
@@ -281,22 +274,15 @@ class batch_norm(PruneWorker):
         if var in self.op.outputs("Y"):
             in_var = self.op.inputs("X")[0]
-            self._visit(in_var, pruned_axis)
-            pre_ops = in_var.inputs()
-            for op in pre_ops:
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)

         for param in ["Scale", "Bias", "Mean", "Variance"]:
             param_var = self.op.inputs(param)[0]
-            for op in param_var.outputs():
-                self._prune_op(op, param_var, 0, pruned_idx)
+            self._visit_and_search(param_var, 0, pruned_idx)
             self.append_pruned_vars(param_var, 0, pruned_idx)

         out_var = self.op.outputs("Y")[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, pruned_idx)


 @PRUNE_WORKER.register
@@ -475,20 +461,13 @@ class sum(PruneWorker):
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.outputs("Out"):
             for in_var in self.op.inputs("X"):
-                pre_ops = in_var.inputs()
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                self._visit_and_search(in_var, pruned_axis, pruned_idx)
         elif var in self.op.inputs("X"):
             for in_var in self.op.inputs("X"):
                 if in_var != var:
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                    self._visit_and_search(in_var, pruned_axis, pruned_idx)

         out_var = self.op.outputs("Out")[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, pruned_idx)


 @PRUNE_WORKER.register
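Because all addends of a `sum` must keep identical shapes, pruning the output (or one input) has to propagate to every sibling input as well as to the output. For example, with `out = conv1 + conv2`, pruning channel axis 1 of `out` now reaches both convolution branches through `_visit_and_search`, so a sibling listed in `skip_vars` raises instead of being pruned out of sync; the `TestSum` case added below exercises exactly this graph.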
@@ -756,12 +735,10 @@ class scale(PruneWorker):
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.inputs("X"):
             out_var = self.op.outputs("Out")[0]
-            for op in out_var.outputs():
-                self._prune_op(op, out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(out_var, pruned_axis, pruned_idx)
         elif var in self.op.outputs("Out"):
             in_var = self.op.inputs("X")[0]
-            for op in in_var.inputs():
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)


 @PRUNE_WORKER.register
@@ -802,22 +779,15 @@ class affine_channel(PruneWorker):
         if var in self.op.outputs("Out"):
             in_var = self.op.inputs("X")[0]
-            self._visit(in_var, pruned_axis)
-            pre_ops = in_var.inputs()
-            for op in pre_ops:
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)

         for param in ["Scale", "Bias"]:
             param_var = self.op.inputs(param)[0]
-            for op in param_var.outputs():
-                self._prune_op(op, param_var, 0, pruned_idx)
+            self._visit_and_search(param_var, 0, pruned_idx)
             self.append_pruned_vars(param_var, 0, pruned_idx)

         out_var = self.op.outputs("Out")[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, pruned_idx)


 @PRUNE_WORKER.register
@@ -843,11 +813,8 @@ class flatten_contiguous_range(PruneWorker):
                 out_pruned_axis = start_axis + pruned_axis - stop_axis

             self._visit(in_var, pruned_axis)
-            self._visit(out_var, out_pruned_axis)
             transform = {'stride': stride}
-            next_ops = out_var.outputs()
-            for op in next_ops:
-                self._prune_op(op, out_var, out_pruned_axis,
-                               transforms + [transform])
+            self._visit_and_search(out_var, out_pruned_axis,
+                                   transforms + [transform])
...
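As a sanity check of the axis arithmetic retained above: flattening shape [2, 3, 4, 5] with start_axis=1 and stop_axis=2 yields [2, 12, 5], and pruning input axis 3 (the size-5 dim) maps to output axis start_axis + pruned_axis - stop_axis = 1 + 3 - 2 = 2, which is indeed the size-5 dim of the output.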
@@ -201,6 +201,50 @@ class TestSqueeze2(StaticCase):
         self.assertTrue(ret == {})
+
+class TestSum(StaticCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main_program, startup_program):
+                input = fluid.data(name="image", shape=[1, 3, 16, 16])
+                conv1 = conv_bn_layer(
+                    input, 8, 3, "conv1", act='relu')  # [1, 8, 1, 1]
+                conv2 = conv_bn_layer(
+                    input, 8, 3, "conv2", act='relu')  # [1, 8, 1, 1]
+                out = conv1 + conv2
+        graph = GraphWrapper(main_program)
+
+        cls = PRUNE_WORKER.get("sum")
+        out_var = graph.var(out.name)
+        in_var = graph.var(conv1.name)
+        op = out_var.inputs()[0]
+
+        # pruning out
+        pruned_params = []
+        ret = {}
+        worker = cls(op, pruned_params, {}, True)
+        worker.prune(out_var, 1, [])
+        for var, axis, _, _ in pruned_params:
+            ret[var.name()] = axis
+        self.assertTrue(ret == {
+            'conv1_weights': 0,
+            'conv1_bn_scale': 0,
+            'conv1_bn_offset': 0,
+            'conv1_bn_mean': 0,
+            'conv1_bn_variance': 0
+        })
+
+        # pruning inputs
+        pruned_params = []
+        worker = cls(op, pruned_params, {}, True)
+        worker.skip_vars = [out.name]
+        try:
+            worker.prune(in_var, 0, [])
+        except UnsupportOpError as e:
+            print(e)
+        self.assertTrue(pruned_params == [])
+
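The new case first prunes axis 1 of the sum's output and expects the walk to reach conv1's weights and batch-norm parameters at axis 0; it then marks the output as a skipped leaf and expects the walk from an input to abort with nothing pruned. To run just this case locally, something like the following should work (the test-module path is illustrative; adjust it to wherever this suite lives in the repo):

    python -m unittest tests.test_prune_walker.TestSum -v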
 class TestUnsupportAndDefault(StaticCase):
     def test_prune(self):
         main_program = fluid.Program()
@@ -324,7 +368,6 @@ class TestPruneWorker(unittest.TestCase):
                 if var.name() not in ret:
                     ret[var.name()] = []
                 ret[var.name()].append(axis)
-                print(f"excepted: {_ret}; but get {ret}")
             self.assertTrue(ret == _ret)
...