Commit cab7f5c0 authored by: W wanghaoshuang

Merge branch 'develop' into 'develop'

Fix pruning of concat operator.

See merge request !63
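
Note: the change below corrects how pruned channel indices are propagated through `concat`: the offset of the pruned input within the concatenated output is added to the indices, and the concat's successors are then re-searched and pruned with the corrected indices. For orientation, here is a minimal, self-contained sketch of that index correction (the helper `shift_pruned_idxs` and the example shapes are illustrative only, not part of the PaddleSlim API):

```python
# Minimal sketch of the index correction applied by this merge request:
# channels pruned from one input of a concat op land at shifted positions
# in the concatenated output, so downstream parameters must be pruned at
# offset-corrected indices. Helper name and shapes are hypothetical.
def shift_pruned_idxs(pruned_idxs, input_channel_counts, concat_idx):
    """Map pruned channel indices of input `concat_idx` into the channel
    space of the concat output (concatenation along axis 1)."""
    offset = sum(input_channel_counts[:concat_idx])
    return [i + offset for i in pruned_idxs]

# Example: pruning channels [0, 2] of the second input (inputs carry 16 and
# 32 channels) affects channels [16, 18] of the concat output.
assert shift_pruned_idxs([0, 2], [16, 32], 1) == [16, 18]
```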
@@ -130,8 +130,16 @@ class Pruner():
                         param.name() not in param_backup):
                     param_backup[param.name()] = copy.deepcopy(
                         np.array(param_t))
-                pruned_param = self._prune_tensor(
-                    np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
+                try:
+                    pruned_param = self._prune_tensor(
+                        np.array(param_t),
+                        pruned_idx,
+                        pruned_axis=0,
+                        lazy=lazy)
+                except IndexError as e:
+                    _logger.error("Pruning {}, but get [{}]".format(param.name(
+                    ), e))
                 param_t.set(pruned_param, place)
                 ori_shape = param.shape()
                 if param_shape_backup is not None and (
@@ -171,7 +179,6 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[pruned_axis]:
             return
-
         if only_graph:
             pruned_num = len(pruned_idx)
             for param in params:
@@ -210,40 +217,55 @@ class Pruner():
                 ), ori_shape, new_shape))
             self.pruned_list[pruned_axis].append(param.name())
 
-    def _forward_search_related_op(self, graph, param):
+    def _forward_search_related_op(self, graph, node):
         """
         Forward search operators that will be affected by pruning of param.
         Args:
             graph(GraphWrapper): The graph to be searched.
-            param(VarWrapper): The current pruned parameter.
+            node(VarWrapper|OpWrapper): The current pruned parameter or operator.
         Returns:
             list<OpWrapper>: A list of operators.
         """
-        assert isinstance(param, VarWrapper)
         visited = {}
         for op in graph.ops():
             visited[op.idx()] = False
         stack = []
-        for op in graph.ops():
-            if (not op.is_bwd_op()) and (param in op.all_inputs()):
-                stack.append(op)
         visit_path = []
+        if isinstance(node, VarWrapper):
+            for op in graph.ops():
+                if (not op.is_bwd_op()) and (node in op.all_inputs()):
+                    next_ops = self._get_next_unvisited_op(graph, visited, op)
+                    # visit_path.append(op)
+                    visited[op.idx()] = True
+                    for next_op in next_ops:
+                        if visited[next_op.idx()] == False:
+                            stack.append(next_op)
+                            visit_path.append(next_op)
+                            visited[next_op.idx()] = True
+        elif isinstance(node, OpWrapper):
+            next_ops = self._get_next_unvisited_op(graph, visited, node)
+            for next_op in next_ops:
+                if visited[next_op.idx()] == False:
+                    stack.append(next_op)
+                    visit_path.append(next_op)
+                    visited[next_op.idx()] = True
         while len(stack) > 0:
-            top_op = stack[len(stack) - 1]
-            if visited[top_op.idx()] == False:
-                visit_path.append(top_op)
-                visited[top_op.idx()] = True
+            #top_op = stack[len(stack) - 1]
+            top_op = stack.pop(0)
             next_ops = None
-            if top_op.type() == "conv2d" and param not in top_op.all_inputs():
+            if top_op.type() in ["conv2d", "deformable_conv"]:
                 next_ops = None
-            elif top_op.type() == "mul":
+            elif top_op.type() in ["mul", "concat"]:
                 next_ops = None
             else:
                 next_ops = self._get_next_unvisited_op(graph, visited, top_op)
-            if next_ops == None:
-                stack.pop()
-            else:
-                stack += next_ops
+            if next_ops != None:
+                for op in next_ops:
+                    if visited[op.idx()] == False:
+                        stack.append(op)
+                        visit_path.append(op)
+                        visited[op.idx()] = True
         return visit_path
 
     def _get_next_unvisited_op(self, graph, visited, top_op):
@@ -261,7 +283,7 @@ class Pruner():
         for op in graph.next_ops(top_op):
             if (visited[op.idx()] == False) and (not op.is_bwd_op()):
                 next_ops.append(op)
-        return next_ops if len(next_ops) > 0 else None
+        return next_ops
 
     def _get_accumulator(self, graph, param):
         """
@@ -317,7 +339,8 @@ class Pruner():
         if param.name() in self.pruned_list[0]:
             return
         related_ops = self._forward_search_related_op(graph, param)
-
+        for op in related_ops:
+            _logger.debug("relate op: {};".format(op))
         if ratio is None:
             assert pruned_idxs is not None
             self._prune_parameter_by_idx(
@@ -339,17 +362,20 @@ class Pruner():
                 only_graph=only_graph,
                 param_backup=param_backup,
                 param_shape_backup=param_shape_backup)
-        corrected_idxs = pruned_idxs[:]
-
-        for idx, op in enumerate(related_ops):
-            if op.type() == "conv2d" and (param not in op.all_inputs()):
+        self._prune_ops(related_ops, pruned_idxs, graph, scope, place, lazy,
+                        only_graph, param_backup, param_shape_backup)
+
+    def _prune_ops(self, ops, pruned_idxs, graph, scope, place, lazy,
+                   only_graph, param_backup, param_shape_backup):
+        for idx, op in enumerate(ops):
+            if op.type() in ["conv2d", "deformable_conv"]:
                 for in_var in op.all_inputs():
                     if graph.is_parameter(in_var):
                         conv_param = in_var
                         self._prune_parameter_by_idx(
                             scope, [conv_param] + self._get_accumulator(
                                 graph, conv_param),
-                            corrected_idxs,
+                            pruned_idxs,
                             pruned_axis=1,
                             place=place,
                             lazy=lazy,
@@ -363,7 +389,7 @@ class Pruner():
                         self._prune_parameter_by_idx(
                             scope, [conv_param] + self._get_accumulator(
                                 graph, conv_param),
-                            corrected_idxs,
+                            pruned_idxs,
                             pruned_axis=0,
                             place=place,
                             lazy=lazy,
@@ -397,7 +423,7 @@ class Pruner():
                 idx = []
                 feature_map_size = fc_input.shape()[2] * fc_input.shape()[3]
                 range_idx = np.array(range(feature_map_size))
-                for i in corrected_idxs:
+                for i in pruned_idxs:
                     idx += list(range_idx + i * feature_map_size)
                 corrected_idxs = idx
                 self._prune_parameter_by_idx(
@@ -412,23 +438,37 @@ class Pruner():
             elif op.type() == "concat":
                 concat_inputs = op.all_inputs()
-                last_op = related_ops[idx - 1]
-                for out_var in last_op.all_outputs():
-                    if out_var in concat_inputs:
-                        concat_idx = concat_inputs.index(out_var)
+                last_op = ops[idx - 1]
+                concat_idx = None
+                for last_op in reversed(ops):
+                    for out_var in last_op.all_outputs():
+                        if out_var in concat_inputs:
+                            concat_idx = concat_inputs.index(out_var)
+                            break
+                    if concat_idx is not None:
+                        break
                 offset = 0
                 for ci in range(concat_idx):
                     offset += concat_inputs[ci].shape()[1]
                 corrected_idxs = [x + offset for x in pruned_idxs]
+                related_ops = self._forward_search_related_op(graph, op)
+                for op in related_ops:
+                    _logger.debug("concat relate op: {};".format(op))
+                self._prune_ops(related_ops, corrected_idxs, graph, scope,
+                                place, lazy, only_graph, param_backup,
+                                param_shape_backup)
             elif op.type() == "batch_norm":
                 bn_inputs = op.all_inputs()
-                mean = bn_inputs[2]
+                in_num = len(bn_inputs)
+                beta = bn_inputs[0]
+                mean = bn_inputs[1]
+                alpha = bn_inputs[2]
                 variance = bn_inputs[3]
-                alpha = bn_inputs[0]
-                beta = bn_inputs[1]
                 self._prune_parameter_by_idx(
                     scope, [mean] + self._get_accumulator(graph, mean),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -437,7 +477,7 @@ class Pruner():
                     param_shape_backup=param_shape_backup)
                 self._prune_parameter_by_idx(
                     scope, [variance] + self._get_accumulator(graph, variance),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -446,7 +486,7 @@ class Pruner():
                     param_shape_backup=param_shape_backup)
                 self._prune_parameter_by_idx(
                     scope, [alpha] + self._get_accumulator(graph, alpha),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -455,7 +495,7 @@ class Pruner():
                     param_shape_backup=param_shape_backup)
                 self._prune_parameter_by_idx(
                     scope, [beta] + self._get_accumulator(graph, beta),
-                    corrected_idxs,
+                    pruned_idxs,
                     pruned_axis=0,
                     place=place,
                     lazy=lazy,
@@ -491,6 +531,7 @@ class Pruner():
         self.pruned_list = [[], []]
         for param, ratio in zip(params, ratios):
             assert isinstance(param, str) or isinstance(param, unicode)
+            _logger.info("pruning param: {}".format(param))
             param = graph.var(param)
             self._forward_pruning_ralated_params(
                 graph,
@@ -504,9 +545,10 @@ class Pruner():
                 param_shape_backup=param_shape_backup)
             ops = param.outputs()
             for op in ops:
-                if op.type() == 'conv2d':
+                if op.type() in ['conv2d', 'deformable_conv']:
                     brother_ops = self._search_brother_ops(graph, op)
                     for broher in brother_ops:
+                        _logger.debug("pruning brother: {}".format(broher))
                         for p in graph.get_param_by_op(broher):
                             self._forward_pruning_ralated_params(
                                 graph,
@@ -534,8 +576,11 @@ class Pruner():
         stack = []
         brothers = []
         for op in graph.next_ops(op_node):
-            if ("conv2d" not in op.type()) and (op.type() != 'fc') and (
-                    not op.is_bwd_op()) and (not op.is_opt_op()):
+            if ("conv2d" not in op.type()) and (
+                    "concat" not in op.type()) and (
+                        "deformable_conv" not in op.type()) and (
+                            op.type() != 'fc') and (
+                                not op.is_bwd_op()) and (not op.is_opt_op()):
                 stack.append(op)
                 visited.append(op.idx())
         while len(stack) > 0:
@@ -546,6 +591,7 @@ class Pruner():
                 _logger.debug("----------go back from {} to {}----------".
                               format(top_op, parent))
                 if (('conv2d' in parent.type()) or
+                        ("deformable_conv" in parent.type()) or
                         (parent.type() == 'fc')):
                     brothers.append(parent)
                 else:
@@ -553,10 +599,13 @@ class Pruner():
                     visited.append(parent.idx())
             for child in graph.next_ops(top_op):
-                if ('conv2d' not in child.type()
-                    ) and (child.type() != 'fc') and (
-                        child.idx() not in visited) and (
-                            not child.is_bwd_op()) and (not child.is_opt_op()):
+                if ('conv2d' not in child.type()) and (
+                        "concat" not in child.type()) and (
+                            'deformable_conv' not in child.type()) and (
+                                child.type() != 'fc') and (
+                                    child.idx() not in visited) and (
+                                        not child.is_bwd_op()) and (
+                                            not child.is_opt_op()):
                     stack.append(child)
                     visited.append(child.idx())
         _logger.debug("brothers: {}".format(brothers))
......
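
For orientation: the revised `_forward_search_related_op` pops from the front of its work list (`stack.pop(0)`), so the forward search proceeds breadth-first and stops expanding at `conv2d`, `deformable_conv`, `mul`, and `concat` operators; a `concat` branch is then re-searched and pruned separately via `_prune_ops` with offset-corrected indices. Below is a minimal sketch of that traversal over a toy adjacency-list graph (the dictionary-based graph and the function name are assumptions for illustration, not the real `GraphWrapper` API):

```python
from collections import deque

# Toy breadth-first forward search mirroring the revised traversal order:
# the work list is a FIFO queue, and "stop" op types are recorded in the
# visit path but not expanded further.
def forward_search(graph, start, stop_types):
    visited = {start}
    visit_path = []
    queue = deque()
    for nxt in graph[start]["next"]:      # seed with the start op's successors
        visited.add(nxt)
        visit_path.append(nxt)
        queue.append(nxt)
    while queue:
        op = queue.popleft()              # FIFO, like stack.pop(0) in the patch
        if graph[op]["type"] in stop_types:
            continue                      # conv/mul/concat ops end this branch
        for nxt in graph[op]["next"]:
            if nxt not in visited:
                visited.add(nxt)
                visit_path.append(nxt)
                queue.append(nxt)
    return visit_path

g = {
    "conv_a": {"type": "conv2d", "next": ["bn"]},
    "bn": {"type": "batch_norm", "next": ["concat"]},
    "concat": {"type": "concat", "next": ["conv_b"]},
    "conv_b": {"type": "conv2d", "next": []},
}
print(forward_search(g, "conv_a", {"conv2d", "deformable_conv", "mul", "concat"}))
# ['bn', 'concat'] -- the search stops at the concat; its branch would then be
# handled with offset-corrected indices, as in the concat case above.
```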