未验证 提交 5312f46e 编写于 作者: W whs 提交者: GitHub

Fix prune worker. (#27)

上级 790a9ffb
......@@ -49,14 +49,18 @@ class PruneWorker(object):
pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable.
"""
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)
def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.idx()), var.name()])
if pruned_axis not in self.visited:
self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]:
return
return False
else:
self.visited[pruned_axis][key] = True
self._prune(var, pruned_axis, pruned_idx)
return True
def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.')
......@@ -83,7 +87,7 @@ class conv2d(PruneWorker):
super(conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = sef.op.attr("data_format")
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
......@@ -91,8 +95,7 @@ class conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[1][key] = True
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
......@@ -110,16 +113,14 @@ class conv2d(PruneWorker):
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()])
self.visited[channel_axis][key] = True
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), input_var.name()])
self.visited[channel_axis][key] = True
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
......@@ -128,8 +129,7 @@ class conv2d(PruneWorker):
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx))
......@@ -158,8 +158,7 @@ class batch_norm(PruneWorker):
if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
key = "_".join([str(self.op.idx()), in_var.name()])
self.visited[pruned_axis][key] = True
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
......@@ -171,8 +170,7 @@ class batch_norm(PruneWorker):
self.pruned_params.append((param_var, 0, pruned_idx))
out_var = self.op.outputs("Y")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -214,8 +212,7 @@ class elementwise_op(PruneWorker):
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -253,8 +250,7 @@ class activation(PruneWorker):
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs(self.output_name)[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -317,8 +313,7 @@ class sum(PruneWorker):
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -363,8 +358,7 @@ class concat(PruneWorker):
start += v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={})
......@@ -373,8 +367,7 @@ class concat(PruneWorker):
for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()])
self.visited[pruned_axis][key] = True
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -386,7 +379,7 @@ class depthwise_conv2d(PruneWorker):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = sef.op.attr("data_format")
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
......@@ -396,8 +389,7 @@ class depthwise_conv2d(PruneWorker):
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups)
......@@ -425,8 +417,7 @@ class depthwise_conv2d(PruneWorker):
self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()])
self.visited[channel_axis][key] = True
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
......@@ -436,8 +427,7 @@ class depthwise_conv2d(PruneWorker):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()])
self.visited[0][key] = True
self._visit(filter_var, 0)
new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups)
......@@ -450,8 +440,7 @@ class depthwise_conv2d(PruneWorker):
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), in_var.name()])
self.visited[channel_axis][key] = True
self._visit(in_var, channel_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册