未验证 提交 5312f46e 编写于 作者: W whs 提交者: GitHub

Fix prune worker. (#27)

上级 790a9ffb
...@@ -49,14 +49,18 @@ class PruneWorker(object): ...@@ -49,14 +49,18 @@ class PruneWorker(object):
pruned_axis(int): The axis to be pruned of root variable. pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable. pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable.
""" """
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)
def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.idx()), var.name()]) key = "_".join([str(self.op.idx()), var.name()])
if pruned_axis not in self.visited: if pruned_axis not in self.visited:
self.visited[pruned_axis] = {} self.visited[pruned_axis] = {}
if key in self.visited[pruned_axis]: if key in self.visited[pruned_axis]:
return return False
else: else:
self.visited[pruned_axis][key] = True self.visited[pruned_axis][key] = True
self._prune(var, pruned_axis, pruned_idx) return True
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
...@@ -83,7 +87,7 @@ class conv2d(PruneWorker): ...@@ -83,7 +87,7 @@ class conv2d(PruneWorker):
super(conv2d, self).__init__(op, pruned_params, visited) super(conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
data_format = sef.op.attr("data_format") data_format = self.op.attr("data_format")
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
channel_axis = 3 channel_axis = 3
...@@ -91,8 +95,7 @@ class conv2d(PruneWorker): ...@@ -91,8 +95,7 @@ class conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name()) pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()]) self._visit(filter_var, 1)
self.visited[1][key] = True
self.pruned_params.append((filter_var, 1, pruned_idx)) self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs(): for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx) self._prune_op(op, filter_var, 1, pruned_idx)
...@@ -110,16 +113,14 @@ class conv2d(PruneWorker): ...@@ -110,16 +113,14 @@ class conv2d(PruneWorker):
self.pruned_params.append( self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx)) (self.op.inputs("Bias"), channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0] output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()]) self._visit(output_var, channel_axis)
self.visited[channel_axis][key] = True
next_ops = output_var.outputs() next_ops = output_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx) self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 1: elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0] input_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), input_var.name()]) self._visit(input_var, channel_axis)
self.visited[channel_axis][key] = True
pre_ops = input_var.inputs() pre_ops = input_var.inputs()
for op in pre_ops: for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx) self._prune_op(op, input_var, channel_axis, pruned_idx)
...@@ -128,8 +129,7 @@ class conv2d(PruneWorker): ...@@ -128,8 +129,7 @@ class conv2d(PruneWorker):
pruned_axis, var.name()) pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
key = "_".join([str(self.op.idx()), filter_var.name()]) self._visit(filter_var, 0)
self.visited[0][key] = True
self.pruned_params.append((filter_var, 0, pruned_idx)) self.pruned_params.append((filter_var, 0, pruned_idx))
...@@ -158,8 +158,7 @@ class batch_norm(PruneWorker): ...@@ -158,8 +158,7 @@ class batch_norm(PruneWorker):
if var in self.op.outputs("Y"): if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0] in_var = self.op.inputs("X")[0]
key = "_".join([str(self.op.idx()), in_var.name()]) self._visit(in_var, pruned_axis)
self.visited[pruned_axis][key] = True
pre_ops = in_var.inputs() pre_ops = in_var.inputs()
for op in pre_ops: for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx) self._prune_op(op, in_var, pruned_axis, pruned_idx)
...@@ -171,8 +170,7 @@ class batch_norm(PruneWorker): ...@@ -171,8 +170,7 @@ class batch_norm(PruneWorker):
self.pruned_params.append((param_var, 0, pruned_idx)) self.pruned_params.append((param_var, 0, pruned_idx))
out_var = self.op.outputs("Y")[0] out_var = self.op.outputs("Y")[0]
key = "_".join([str(self.op.idx()), out_var.name()]) self._visit(out_var, pruned_axis)
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
...@@ -214,8 +212,7 @@ class elementwise_op(PruneWorker): ...@@ -214,8 +212,7 @@ class elementwise_op(PruneWorker):
self._prune_op(op, in_var, pruned_axis, pruned_idx) self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()]) self._visit(out_var, pruned_axis)
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
...@@ -253,8 +250,7 @@ class activation(PruneWorker): ...@@ -253,8 +250,7 @@ class activation(PruneWorker):
self._prune_op(op, in_var, pruned_axis, pruned_idx) self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs(self.output_name)[0] out_var = self.op.outputs(self.output_name)[0]
key = "_".join([str(self.op.idx()), out_var.name()]) self._visit(out_var, pruned_axis)
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
...@@ -317,8 +313,7 @@ class sum(PruneWorker): ...@@ -317,8 +313,7 @@ class sum(PruneWorker):
for op in pre_ops: for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx) self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()]) self._visit(out_var, pruned_axis)
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
...@@ -363,8 +358,7 @@ class concat(PruneWorker): ...@@ -363,8 +358,7 @@ class concat(PruneWorker):
start += v.shape()[pruned_axis] start += v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()]) self._visit(out_var, pruned_axis)
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={}) self._prune_op(op, out_var, pruned_axis, idx, visited={})
...@@ -373,8 +367,7 @@ class concat(PruneWorker): ...@@ -373,8 +367,7 @@ class concat(PruneWorker):
for op in v.inputs(): for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx) self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
key = "_".join([str(self.op.idx()), out_var.name()]) self._visit(out_var, pruned_axis)
self.visited[pruned_axis][key] = True
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
...@@ -386,7 +379,7 @@ class depthwise_conv2d(PruneWorker): ...@@ -386,7 +379,7 @@ class depthwise_conv2d(PruneWorker):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited) super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
data_format = sef.op.attr("data_format") data_format = self.op.attr("data_format")
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
channel_axis = 3 channel_axis = 3
...@@ -396,8 +389,7 @@ class depthwise_conv2d(PruneWorker): ...@@ -396,8 +389,7 @@ class depthwise_conv2d(PruneWorker):
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx)) self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()]) self._visit(filter_var, 0)
self.visited[0][key] = True
new_groups = filter_var.shape()[0] - len(pruned_idx) new_groups = filter_var.shape()[0] - len(pruned_idx)
self.op.set_attr("groups", new_groups) self.op.set_attr("groups", new_groups)
...@@ -425,8 +417,7 @@ class depthwise_conv2d(PruneWorker): ...@@ -425,8 +417,7 @@ class depthwise_conv2d(PruneWorker):
self._prune_op(op, var, 0, pruned_idx) self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0] output_var = self.op.outputs("Output")[0]
key = "_".join([str(self.op.idx()), output_var.name()]) self._visit(output_var, channel_axis)
self.visited[channel_axis][key] = True
next_ops = output_var.outputs() next_ops = output_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx) self._prune_op(op, output_var, channel_axis, pruned_idx)
...@@ -436,8 +427,7 @@ class depthwise_conv2d(PruneWorker): ...@@ -436,8 +427,7 @@ class depthwise_conv2d(PruneWorker):
assert pruned_axis == channel_axis assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx)) self.pruned_params.append((filter_var, 0, pruned_idx))
key = "_".join([str(self.op.idx()), filter_var.name()]) self._visit(filter_var, 0)
self.visited[0][key] = True
new_groups = filter_var.shape()[0] - len(pruned_idx) new_groups = filter_var.shape()[0] - len(pruned_idx)
op.set_attr("groups", new_groups) op.set_attr("groups", new_groups)
...@@ -450,8 +440,7 @@ class depthwise_conv2d(PruneWorker): ...@@ -450,8 +440,7 @@ class depthwise_conv2d(PruneWorker):
(self.op.inputs("Bias")[0], channel_axis, pruned_idx)) (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0] in_var = self.op.inputs("Input")[0]
key = "_".join([str(self.op.idx()), in_var.name()]) self._visit(in_var, channel_axis)
self.visited[channel_axis][key] = True
pre_ops = in_var.inputs() pre_ops = in_var.inputs()
for op in pre_ops: for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx) self._prune_op(op, in_var, channel_axis, pruned_idx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册