Commit 1cfa65ec authored by W wanghaoshuang

Fix pruning of concat operator.

Parent 95a95672
@@ -130,8 +130,16 @@ class Pruner():
                     param.name() not in param_backup):
                 param_backup[param.name()] = copy.deepcopy(
                     np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
+            try:
+                pruned_param = self._prune_tensor(
+                    np.array(param_t),
+                    pruned_idx,
+                    pruned_axis=0,
+                    lazy=lazy)
+            except IndexError as e:
+                _logger.error("Pruning {}, but get [{}]".format(param.name(
+                ), e))
             param_t.set(pruned_param, place)
             ori_shape = param.shape()
             if param_shape_backup is not None and (
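The new try/except turns an out-of-range pruning index into a logged error instead of a crash: if `_prune_tensor` drops the listed indices along `pruned_axis`, an index beyond the tensor's extent on that axis is what raises the IndexError. A rough stand-alone equivalent with NumPy (`prune_tensor` below is an illustrative stand-in, not the library's `_prune_tensor`):

    import numpy as np

    def prune_tensor(tensor, pruned_idx, pruned_axis=0):
        # Drop the listed slices along pruned_axis; an out-of-range index
        # raises IndexError, which the patch above now catches and logs.
        return np.delete(tensor, pruned_idx, axis=pruned_axis)

    weights = np.random.rand(8, 4, 3, 3)        # e.g. conv filters [out, in, kh, kw]
    print(prune_tensor(weights, [1, 5]).shape)  # (6, 4, 3, 3)
    # prune_tensor(weights, [9])                # IndexError: index 9 out of bounds for axis 0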
@@ -171,7 +179,6 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[pruned_axis]:
             return
-
         if only_graph:
             pruned_num = len(pruned_idx)
             for param in params:
@@ -210,40 +217,55 @@ class Pruner():
                 ), ori_shape, new_shape))
             self.pruned_list[pruned_axis].append(param.name())

-    def _forward_search_related_op(self, graph, param):
+    def _forward_search_related_op(self, graph, node):
         """
         Forward search operators that will be affected by pruning of param.
         Args:
             graph(GraphWrapper): The graph to be searched.
-            param(VarWrapper): The current pruned parameter.
+            node(VarWrapper|OpWrapper): The current pruned parameter or operator.
         Returns:
             list<OpWrapper>: A list of operators.
         """
-        assert isinstance(param, VarWrapper)
         visited = {}
         for op in graph.ops():
             visited[op.idx()] = False
         stack = []
-        for op in graph.ops():
-            if (not op.is_bwd_op()) and (param in op.all_inputs()):
-                stack.append(op)
         visit_path = []
+        if isinstance(node, VarWrapper):
+            for op in graph.ops():
+                if (not op.is_bwd_op()) and (node in op.all_inputs()):
+                    next_ops = self._get_next_unvisited_op(graph, visited, op)
+                    # visit_path.append(op)
+                    visited[op.idx()] = True
+                    for next_op in next_ops:
+                        if visited[next_op.idx()] == False:
+                            stack.append(next_op)
+                            visit_path.append(next_op)
+                            visited[next_op.idx()] = True
+        elif isinstance(node, OpWrapper):
+            next_ops = self._get_next_unvisited_op(graph, visited, node)
+            for next_op in next_ops:
+                if visited[next_op.idx()] == False:
+                    stack.append(next_op)
+                    visit_path.append(next_op)
+                    visited[next_op.idx()] = True
         while len(stack) > 0:
-            top_op = stack[len(stack) - 1]
-            if visited[top_op.idx()] == False:
-                visit_path.append(top_op)
-                visited[top_op.idx()] = True
+            #top_op = stack[len(stack) - 1]
+            top_op = stack.pop(0)
             next_ops = None
-            if top_op.type() == "conv2d" and param not in top_op.all_inputs():
+            if top_op.type() in ["conv2d", "deformable_conv"]:
                 next_ops = None
-            elif top_op.type() == "mul":
+            elif top_op.type() in ["mul", "concat"]:
                 next_ops = None
             else:
                 next_ops = self._get_next_unvisited_op(graph, visited, top_op)
-            if next_ops == None:
-                stack.pop()
-            else:
-                stack += next_ops
+            if next_ops != None:
+                for op in next_ops:
+                    if visited[op.idx()] == False:
+                        stack.append(op)
+                        visit_path.append(op)
+                        visited[op.idx()] = True
         return visit_path

     def _get_next_unvisited_op(self, graph, visited, top_op):
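The reworked search is effectively a breadth-first traversal: `stack.pop(0)` turns the list into a FIFO queue, the `visited` map guarantees each op enters `visit_path` at most once, and ops whose parameters fix their own channel count (conv2d, deformable_conv, mul, concat) are kept in the path but their successors are not expanded. A minimal sketch of that walk over a toy adjacency map (the GraphWrapper/OpWrapper API is simplified into plain dicts here):

    from collections import deque

    def forward_search(graph, types, start_op, stop_types):
        # graph: {op: [successor ops]}; types: {op: op type string}.
        # Mirrors the patched method: the op that consumes the pruned variable
        # is marked visited but not returned; every reachable op is visited
        # once; ops in stop_types are returned but not expanded further.
        visited = {start_op}
        queue = deque()
        visit_path = []
        for nxt in graph.get(start_op, []):
            if nxt not in visited:
                visited.add(nxt)
                queue.append(nxt)
                visit_path.append(nxt)
        while queue:
            op = queue.popleft()            # FIFO, like stack.pop(0) above
            if types[op] in stop_types:
                continue                    # conv2d/deformable_conv/mul/concat end this branch
            for nxt in graph.get(op, []):
                if nxt not in visited:
                    visited.add(nxt)
                    queue.append(nxt)
                    visit_path.append(nxt)
        return visit_path

    graph = {"conv1": ["bn1"], "bn1": ["relu1"], "relu1": ["conv2"], "conv2": []}
    types = {"conv1": "conv2d", "bn1": "batch_norm", "relu1": "relu", "conv2": "conv2d"}
    print(forward_search(graph, types, "conv1",
                         {"conv2d", "deformable_conv", "mul", "concat"}))
    # ['bn1', 'relu1', 'conv2'] -- the walk stops once it reaches the next conv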
@@ -261,7 +283,7 @@ class Pruner():
         for op in graph.next_ops(top_op):
             if (visited[op.idx()] == False) and (not op.is_bwd_op()):
                 next_ops.append(op)
-        return next_ops if len(next_ops) > 0 else None
+        return next_ops

     def _get_accumulator(self, graph, param):
         """
@@ -317,7 +339,8 @@ class Pruner():
         if param.name() in self.pruned_list[0]:
             return
         related_ops = self._forward_search_related_op(graph, param)
+        for op in related_ops:
+            _logger.debug("relate op: {};".format(op))
         if ratio is None:
             assert pruned_idxs is not None
             self._prune_parameter_by_idx(
@@ -339,17 +362,20 @@ class Pruner():
                 only_graph=only_graph,
                 param_backup=param_backup,
                 param_shape_backup=param_shape_backup)
-        corrected_idxs = pruned_idxs[:]
+        self._prune_ops(related_ops, pruned_idxs, graph, scope, place, lazy,
+                        only_graph, param_backup, param_shape_backup)
-        for idx, op in enumerate(related_ops):
-            if op.type() == "conv2d" and (param not in op.all_inputs()):
+
+    def _prune_ops(self, ops, pruned_idxs, graph, scope, place, lazy,
+                   only_graph, param_backup, param_shape_backup):
+        for idx, op in enumerate(ops):
+            if op.type() in ["conv2d", "deformable_conv"]:
                 for in_var in op.all_inputs():
                     if graph.is_parameter(in_var):
                         conv_param = in_var
                 self._prune_parameter_by_idx(
                     scope, [conv_param] + self._get_accumulator(
                         graph, conv_param),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=1,
                     place=place,
                     lazy=lazy,
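Pulling the per-op handling out into `_prune_ops` makes it reusable from the concat branch further down. For a downstream conv2d/deformable_conv found by the forward search, the filter tensor is laid out as [out_channels, in_channels, kh, kw], so channels pruned from the producer's output map to axis 1 of the consumer's weights (hence `pruned_axis=1` above). A rough illustration with NumPy, using made-up shapes:

    import numpy as np

    producer_w = np.random.rand(8, 3, 3, 3)     # conv1 filters: [out, in, kh, kw]
    consumer_w = np.random.rand(16, 8, 3, 3)    # conv2 consumes conv1's 8 channels
    pruned_idx = [1, 5]                         # output channels removed from conv1

    producer_w = np.delete(producer_w, pruned_idx, axis=0)   # pruned_axis=0
    consumer_w = np.delete(consumer_w, pruned_idx, axis=1)   # pruned_axis=1
    print(producer_w.shape, consumer_w.shape)   # (6, 3, 3, 3) (16, 6, 3, 3)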
@@ -363,7 +389,7 @@ class Pruner():
                 self._prune_parameter_by_idx(
                     scope, [conv_param] + self._get_accumulator(
                         graph, conv_param),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -397,7 +423,7 @@ class Pruner():
                 idx = []
                 feature_map_size = fc_input.shape()[2] * fc_input.shape()[3]
                 range_idx = np.array(range(feature_map_size))
-                for i in corrected_idxs:
+                for i in pruned_idxs:
                     idx += list(range_idx + i * feature_map_size)
                 corrected_idxs = idx
                 self._prune_parameter_by_idx(
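When the pruned feature map is flattened into a `mul` (fully connected) input, one pruned channel corresponds to a whole block of rows of the FC weight: channel i covers flattened positions [i*H*W, (i+1)*H*W). The loop above builds exactly that expansion; roughly, with made-up sizes:

    import numpy as np

    h, w = 4, 4                                 # spatial size of the flattened feature map
    feature_map_size = h * w
    pruned_channels = [1, 3]

    idx = []
    range_idx = np.array(range(feature_map_size))
    for i in pruned_channels:
        idx += list(range_idx + i * feature_map_size)
    # channel 1 -> rows 16..31, channel 3 -> rows 48..63 of the FC weight (axis 0)
    print(idx[:4], len(idx))                    # [16, 17, 18, 19] 32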
@@ -412,23 +438,37 @@ class Pruner():
             elif op.type() == "concat":
                 concat_inputs = op.all_inputs()
-                last_op = related_ops[idx - 1]
-                for out_var in last_op.all_outputs():
-                    if out_var in concat_inputs:
-                        concat_idx = concat_inputs.index(out_var)
+                last_op = ops[idx - 1]
+                concat_idx = None
+                for last_op in reversed(ops):
+                    for out_var in last_op.all_outputs():
+                        if out_var in concat_inputs:
+                            concat_idx = concat_inputs.index(out_var)
+                            break
+                    if concat_idx is not None:
+                        break
                 offset = 0
                 for ci in range(concat_idx):
                     offset += concat_inputs[ci].shape()[1]
                 corrected_idxs = [x + offset for x in pruned_idxs]
+                related_ops = self._forward_search_related_op(graph, op)
+
+                for op in related_ops:
+                    _logger.debug("concat relate op: {};".format(op))
+
+                self._prune_ops(related_ops, corrected_idxs, graph, scope,
+                                place, lazy, only_graph, param_backup,
+                                param_shape_backup)
             elif op.type() == "batch_norm":
                 bn_inputs = op.all_inputs()
-                mean = bn_inputs[2]
+                in_num = len(bn_inputs)
+                beta = bn_inputs[0]
+                mean = bn_inputs[1]
+                alpha = bn_inputs[2]
                 variance = bn_inputs[3]
-                alpha = bn_inputs[0]
-                beta = bn_inputs[1]
                 self._prune_parameter_by_idx(
                     scope, [mean] + self._get_accumulator(graph, mean),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
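This is the heart of the fix: a concat joins its inputs along the channel axis, so indices that are local to one input must be shifted by the channel counts of every input that precedes it before they are handed to the ops behind the concat (found by the fresh `_forward_search_related_op(graph, op)` call and pruned via `_prune_ops`). A small worked example of the offset correction, with made-up channel counts:

    # Three inputs concatenated along axis 1 with 4, 6 and 8 channels.
    concat_input_channels = [4, 6, 8]
    concat_idx = 1                 # the pruned tensor is the second concat input
    pruned_idxs = [0, 2, 5]        # channels pruned inside that input

    offset = 0
    for ci in range(concat_idx):
        offset += concat_input_channels[ci]   # channels contributed before our input

    corrected_idxs = [x + offset for x in pruned_idxs]
    print(corrected_idxs)          # [4, 6, 9] -- positions in the 18-channel concat output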
@@ -437,7 +477,7 @@ class Pruner():
                     param_shape_backup=param_shape_backup)
                 self._prune_parameter_by_idx(
                     scope, [variance] + self._get_accumulator(graph, variance),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -446,7 +486,7 @@ class Pruner():
                     param_shape_backup=param_shape_backup)
                 self._prune_parameter_by_idx(
                     scope, [alpha] + self._get_accumulator(graph, alpha),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -455,7 +495,7 @@ class Pruner():
                     param_shape_backup=param_shape_backup)
                 self._prune_parameter_by_idx(
                     scope, [beta] + self._get_accumulator(graph, beta),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
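batch_norm keeps four per-channel vectors (scale, shift, moving mean, moving variance), and all of them have to be pruned with the same channel indices as the conv output they normalize; the four calls above do exactly that on axis 0. Roughly:

    import numpy as np

    channels, pruned_idxs = 8, [1, 5]
    scale, shift = np.ones(channels), np.zeros(channels)
    mean, variance = np.zeros(channels), np.ones(channels)

    scale, shift, mean, variance = (
        np.delete(v, pruned_idxs, axis=0) for v in (scale, shift, mean, variance))
    print(scale.shape)             # (6,) -- every per-channel vector shrinks together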
@@ -491,6 +531,7 @@ class Pruner():
         self.pruned_list = [[], []]
         for param, ratio in zip(params, ratios):
             assert isinstance(param, str) or isinstance(param, unicode)
+            _logger.info("pruning param: {}".format(param))
             param = graph.var(param)
             self._forward_pruning_ralated_params(
                 graph,
@@ -504,9 +545,10 @@ class Pruner():
                 param_shape_backup=param_shape_backup)
             ops = param.outputs()
             for op in ops:
-                if op.type() == 'conv2d':
+                if op.type() in ['conv2d', 'deformable_conv']:
                     brother_ops = self._search_brother_ops(graph, op)
                     for broher in brother_ops:
+                        _logger.debug("pruning brother: {}".format(broher))
                         for p in graph.get_param_by_op(broher):
                             self._forward_pruning_ralated_params(
                                 graph,
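Roughly speaking, "brother" ops are the other convolutions (or fc layers) whose outputs are later combined element-wise with the pruned conv's output, for example the two branches of a residual add: if one branch loses output channels, the other must lose the same ones, or the later add sees mismatched shapes. A toy check of that constraint:

    import numpy as np

    pruned_idxs = [2, 7]
    branch_a = np.delete(np.random.rand(1, 16, 8, 8), pruned_idxs, axis=1)
    branch_b = np.delete(np.random.rand(1, 16, 8, 8), pruned_idxs, axis=1)
    out = branch_a + branch_b      # elementwise add still works: both are (1, 14, 8, 8)
    print(out.shape)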
@@ -534,8 +576,11 @@ class Pruner():
         stack = []
         brothers = []
         for op in graph.next_ops(op_node):
-            if ("conv2d" not in op.type()) and (op.type() != 'fc') and (
-                    not op.is_bwd_op()) and (not op.is_opt_op()):
+            if ("conv2d" not in op.type()) and (
+                    "concat" not in op.type()) and (
+                    "deformable_conv" not in op.type()) and (
+                    op.type() != 'fc') and (
+                    not op.is_bwd_op()) and (not op.is_opt_op()):
                 stack.append(op)
                 visited.append(op.idx())
         while len(stack) > 0:
@@ -546,6 +591,7 @@ class Pruner():
                     _logger.debug("----------go back from {} to {}----------".
                                   format(top_op, parent))
                     if (('conv2d' in parent.type()) or
+                            ("deformable_conv" in parent.type()) or
                             (parent.type() == 'fc')):
                         brothers.append(parent)
                     else:
@@ -553,10 +599,13 @@ class Pruner():
                         visited.append(parent.idx())
             for child in graph.next_ops(top_op):
-                if ('conv2d' not in child.type()
-                    ) and (child.type() != 'fc') and (
-                        child.idx() not in visited) and (
-                            not child.is_bwd_op()) and (not child.is_opt_op()):
+                if ('conv2d' not in child.type()) and (
+                        "concat" not in child.type()) and (
+                        'deformable_conv' not in child.type()) and (
+                        child.type() != 'fc') and (
+                        child.idx() not in visited) and (
+                        not child.is_bwd_op()) and (
+                        not child.is_opt_op()):
                     stack.append(child)
                     visited.append(child.idx())
         _logger.debug("brothers: {}".format(brothers))