未验证 提交 09fba1bf 编写于 作者: W whs 提交者: GitHub

Fix skipping leaves option (#837)

* Fix skipping leaves option
* Add unitests
上级 b521fd55
......@@ -105,6 +105,9 @@ class PruneWorker(object):
def _visit_and_search(self, var, axis, transforms):
self._visit(var, axis)
if var.name() in self.skip_vars:
raise UnsupportOpError("Variable {} was skipped.".format(var.name(
)))
pre_ops = var.inputs()
for op in pre_ops:
self._prune_op(op, var, axis, transforms)
......@@ -123,7 +126,6 @@ class PruneWorker(object):
if op.type() in self.ops_unsupported:
raise UnsupportOpError("Unsupported operator named {}".format(
op.type()))
cls = PRUNE_WORKER.get(op.type())
if cls is None:
if op.type() in SKIPPED_OPS:
......@@ -214,10 +216,7 @@ class conv2d(PruneWorker):
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.append_pruned_vars(filter_var, 0, pruned_idx)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
self._visit_and_search(filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.append_pruned_vars(
self.op.inputs("Bias")[0], channel_axis, pruned_idx)
......@@ -240,8 +239,7 @@ class conv2d_transpose(PruneWorker):
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.append_pruned_vars(filter_var, 0, pruned_idx)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
self._visit_and_search(filter_var, 0, pruned_idx)
elif var in self.op.inputs("Filter"):
_logger.warn("Skip pruning output channels of conv2d_transpose!")
......@@ -252,20 +250,15 @@ class conv2d_transpose(PruneWorker):
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1)
self.append_pruned_vars(filter_var, 1, pruned_idx)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
self._visit_and_search(filter_var, 1, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.append_pruned_vars(
self.op.inputs("Bias")[0], channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
self._visit_and_search(output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -281,22 +274,15 @@ class batch_norm(PruneWorker):
if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
for param in ["Scale", "Bias", "Mean", "Variance"]:
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self._visit_and_search(param_var, 0, pruned_idx)
self.append_pruned_vars(param_var, 0, pruned_idx)
out_var = self.op.outputs("Y")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -475,20 +461,13 @@ class sum(PruneWorker):
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
for in_var in self.op.inputs("X"):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
elif var in self.op.inputs("X"):
for in_var in self.op.inputs("X"):
if in_var != var:
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -756,12 +735,10 @@ class scale(PruneWorker):
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
out_var = self.op.outputs("Out")[0]
for op in out_var.outputs():
self._prune_op(op, out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
for op in in_var.inputs():
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -802,22 +779,15 @@ class affine_channel(PruneWorker):
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
for param in ["Scale", "Bias"]:
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self._visit_and_search(param_var, 0, pruned_idx)
self.append_pruned_vars(param_var, 0, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -843,12 +813,9 @@ class flatten_contiguous_range(PruneWorker):
out_pruned_axis = start_axis + pruned_axis - stop_axis
self._visit(in_var, pruned_axis)
self._visit(out_var, out_pruned_axis)
transform = {'stride': stride}
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, out_pruned_axis,
transforms + [transform])
self._visit_and_search(out_var, out_pruned_axis,
transforms + [transform])
@PRUNE_WORKER.register
......
......@@ -201,6 +201,50 @@ class TestSqueeze2(StaticCase):
self.assertTrue(ret == {})
class TestSum(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1]
conv2 = conv_bn_layer(
input, 8, 3, "conv2", act='relu') #[1, 8, 1, 1]
out = conv1 + conv2
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("sum")
out_var = graph.var(out.name)
in_var = graph.var(conv1.name)
op = out_var.inputs()[0]
# pruning out
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(out_var, 1, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {
'conv1_weights': 0,
'conv1_bn_scale': 0,
'conv1_bn_offset': 0,
'conv1_bn_mean': 0,
'conv1_bn_variance': 0
})
# pruning inputs
pruned_params = []
worker = cls(op, pruned_params, {}, True)
worker.skip_vars = [out.name]
try:
worker.prune(in_var, 0, [])
except UnsupportOpError as e:
print(e)
self.assertTrue(pruned_params == [])
class TestUnsupportAndDefault(StaticCase):
def test_prune(self):
main_program = fluid.Program()
......@@ -324,7 +368,6 @@ class TestPruneWorker(unittest.TestCase):
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
print(f"excepted: {_ret}; but get {ret}")
self.assertTrue(ret == _ret)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册