未验证 提交 8bbb37ee 编写于 作者: W whs 提交者: GitHub

[cherry-pick]Fix pruning walker (#278)

上级 d3d94d15
......@@ -23,6 +23,8 @@ _logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker')
SKIP_OPS = ["conditional_block"]
class PruneWorker(object):
def __init__(self, op, pruned_params=[], visited={}):
......@@ -72,6 +74,9 @@ class PruneWorker(object):
self.visited = visited
cls = PRUNE_WORKER.get(op.type())
if cls is None:
if op.type() in SKIP_OPS:
_logger.warn("Skip operator [{}]".format(op.type()))
return
_logger.warn(
"{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
format(op.type()))
......@@ -149,6 +154,71 @@ class conv2d(PruneWorker):
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class conv2d_transpose(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(conv2d_transpose, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
if var in self.op.inputs("Input"):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
if pruned_axis == 1:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 0:
input_var = self.op.inputs("Input")[0]
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class batch_norm(PruneWorker):
def __init__(self, op, pruned_params, visited):
......@@ -267,7 +337,7 @@ class default_walker(PruneWorker):
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.all_outputs():
for in_var in self.op.inputs():
for in_var in self.op.all_inputs():
if len(in_var.shape()) == len(var.shape()):
pre_ops = in_var.inputs()
for op in pre_ops:
......@@ -549,3 +619,33 @@ class adam(PruneWorker):
self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
moment2_var = self.op.inputs("Moment2")[0]
self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
@PRUNE_WORKER.register
class affine_channel(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(affine_channel, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Out")) and (
var not in self.op.inputs("X")):
return
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
for param in ["Scale", "Bias"]:
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx))
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -90,6 +90,11 @@ class Pruner():
visited = {}
pruned_params = []
for param, ratio in zip(params, ratios):
if graph.var(param) is None:
_logger.warn(
"Variable[{}] to be pruned is not in current graph.".
format(param))
continue
group = collect_convs([param], graph, visited)[0] # [(name, axis)]
if group is None or len(group) == 0:
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册