Unverified · Commit b81f27a1, authored by whs, committed by GitHub

Fix pruning for yolov4 (#313)

Parent 44e359c4
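In outline, this commit threads a `pruned_idx` field (used downstream as a per-parameter channel offset) through the whole pruning pipeline: `collect_convs` now emits `(name, axis, pruned_idx)` triples, the criterion functions carry the field through as `(name, axis, score, pruned_idx)`, and `default_idx_selector` shifts its selected indices by that offset. A minimal sketch of the tuple-format change, with hypothetical parameter names:

    # Sketch only; parameter names are invented for illustration.
    # Before: groups were [(name, axis)]; criteria consumed (name, value, axis).
    group_old = [("conv1.weights", 0), ("conv2.weights", 1)]
    # After: every entry carries pruned_idx (seeded as [0], i.e. offset 0).
    group_new = [("conv1.weights", 0, [0]), ("conv2.weights", 1, [0])]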
@@ -43,11 +43,11 @@ def l1_norm(group, graph):
         list: A list of tuple storing l1-norm on given axis.
     """
     scores = []
-    for name, value, axis in group:
+    for name, value, axis, pruned_idx in group:
         reduce_dims = [i for i in range(len(value.shape)) if i != axis]
         score = np.sum(np.abs(value), axis=tuple(reduce_dims))
-        scores.append((name, axis, score))
+        scores.append((name, axis, score, pruned_idx))
     return scores
@@ -55,7 +55,7 @@ def l1_norm(group, graph):
 @CRITERION.register
 def geometry_median(group, graph):
     scores = []
-    name, value, axis = group[0]
+    name, value, axis, _ = group[0]
     assert (len(value.shape) == 4)

     def get_distance_sum(value, out_idx):
@@ -73,8 +73,8 @@ def geometry_median(group, graph):
     tmp = np.array(dist_sum_list)
-    for name, value, axis in group:
-        scores.append((name, axis, tmp))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, tmp, idx))
     return scores
@@ -97,7 +97,7 @@ def bn_scale(group, graph):
     assert (isinstance(graph, GraphWrapper))
     # step1: Get first convolution
-    conv_weight, value, axis = group[0]
+    conv_weight, value, axis, _ = group[0]
     param_var = graph.var(conv_weight)
     conv_op = param_var.outputs()[0]
@@ -111,12 +111,12 @@ def bn_scale(group, graph):
     # steps3: Find scale of bn
     score = None
-    for name, value, aixs in group:
+    for name, value, aixs, _ in group:
         if bn_scale_param == name:
             score = np.abs(value.reshape([-1]))
     scores = []
-    for name, value, axis in group:
-        scores.append((name, axis, score))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, score, idx))
     return scores
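The criterion changes above are pure plumbing: each function unpacks the extra `pruned_idx` element and re-attaches it to the scores it returns. A self-contained rerun of the patched `l1_norm` logic on fake data (NumPy only; names and shapes are invented):

    import numpy as np

    def l1_norm_sketch(group):
        # Mirrors the patched l1_norm: pruned_idx rides along with each score.
        scores = []
        for name, value, axis, pruned_idx in group:
            reduce_dims = [i for i in range(len(value.shape)) if i != axis]
            score = np.sum(np.abs(value), axis=tuple(reduce_dims))
            scores.append((name, axis, score, pruned_idx))
        return scores

    group = [("conv1.weights", np.ones([4, 3, 3, 3]), 0, [0])]
    print(l1_norm_sketch(group))  # one score of 27.0 per output channel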
@@ -57,21 +57,21 @@ def collect_convs(params, graph, visited={}):
         conv_op = param.outputs()[0]
         walker = conv2d_walker(
             conv_op, pruned_params=pruned_params, visited=visited)
-        walker.prune(param, pruned_axis=0, pruned_idx=[])
+        walker.prune(param, pruned_axis=0, pruned_idx=[0])
         groups.append(pruned_params)
     visited = set()
     uniq_groups = []
     for group in groups:
         repeat_group = False
         simple_group = []
-        for param, axis, _ in group:
+        for param, axis, pruned_idx in group:
             param = param.name()
             if axis == 0:
                 if param in visited:
                     repeat_group = True
                 else:
                     visited.add(param)
-            simple_group.append((param, axis))
+            simple_group.append((param, axis, pruned_idx))
         if not repeat_group:
             uniq_groups.append(simple_group)
...
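Note that the walker is now seeded with `pruned_idx=[0]` instead of `[]`, and `collect_convs` keeps that list in each group entry instead of discarding it. Downstream code reads the first element as a channel offset, so the seed `[0]` means "no shift" for the root convolution. A hedged sketch of the resulting group format (hypothetical names; non-zero offsets would be filled in by the prune walkers):

    simple_group = [
        ("conv_root.weights", 0, [0]),  # root conv: offset 0
        ("conv_next.weights", 1, [0]),  # related param: offset set by its walker
    ]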
@@ -52,7 +52,7 @@ def default_idx_selector(group, ratio):
         list: pruned indexes
     """
-    name, axis, score = group[
+    name, axis, score, _ = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
@@ -60,8 +60,9 @@ def default_idx_selector(group, ratio):
     pruned_idx = sorted_idx[:pruned_num]
     idxs = []
-    for name, axis, score in group:
-        idxs.append((name, axis, pruned_idx))
+    for name, axis, score, offsets in group:
+        r_idx = [i + offsets[0] for i in pruned_idx]
+        idxs.append((name, axis, r_idx))
     return idxs
...
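`default_idx_selector` still ranks channels once, by the first convolution's score, but now shifts the chosen indices by each entry's offset (`offsets[0]`) before emitting them. A small worked example, assuming a walker recorded an offset of 8 for one parameter (say, the second half of a concatenated variable):

    pruned_idx = [1, 3]  # channels chosen from the first conv's score
    offsets = [8]        # hypothetical offset recorded by a walker
    r_idx = [i + offsets[0] for i in pruned_idx]
    print(r_idx)         # [9, 11]: prune channels 9 and 11 of this parameter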
@@ -77,9 +77,10 @@ class PruneWorker(object):
         if op.type() in SKIP_OPS:
             _logger.warn("Skip operator [{}]".format(op.type()))
             return
-        _logger.warn(
-            "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
-            format(op.type()))
+        # _logger.warn(
+        #     "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
+        #     format(op.type()))
         cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
@@ -263,6 +264,8 @@ class elementwise_op(PruneWorker):
                 if name == "Y":
                     actual_axis = pruned_axis - axis
                 in_var = self.op.inputs(name)[0]
+                if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
+                    continue
                 pre_ops = in_var.inputs()
                 for op in pre_ops:
                     self._prune_op(op, in_var, actual_axis, pruned_idx)
@@ -270,15 +273,17 @@ class elementwise_op(PruneWorker):
         else:
             if var in self.op.inputs("X"):
                 in_var = self.op.inputs("Y")[0]
-                if in_var.is_parameter():
-                    self.pruned_params.append(
-                        (in_var, pruned_axis - axis, pruned_idx))
-                pre_ops = in_var.inputs()
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
+                if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                    if in_var.is_parameter():
+                        self.pruned_params.append(
+                            (in_var, pruned_axis - axis, pruned_idx))
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis - axis,
+                                       pruned_idx)
             elif var in self.op.inputs("Y"):
                 in_var = self.op.inputs("X")[0]
-                pre_ops = in_var.inputs()
-                pruned_axis = pruned_axis + axis
-                for op in pre_ops:
+                if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                    pre_ops = in_var.inputs()
+                    pruned_axis = pruned_axis + axis
+                    for op in pre_ops:
...
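The new guards in `elementwise_op` skip any input of shape `[1]`: a scalar operand is broadcast across every channel, so there is no per-channel slice to prune and walking through it would wrongly propagate pruning. A minimal sketch of the predicate, using a plain tuple in place of the wrapper's `in_var.shape()`:

    def is_broadcast_scalar(shape):
        # Mirrors the check: len(in_var.shape()) == 1 and in_var.shape()[0] == 1
        return len(shape) == 1 and shape[0] == 1

    print(is_broadcast_scalar((1,)))   # True  -> leave the operand untouched
    print(is_broadcast_scalar((64,)))  # False -> keep pruning along the channel axis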
@@ -90,12 +90,14 @@ class Pruner():
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
+            _logger.info("pruning: {}".format(param))
             if graph.var(param) is None:
                 _logger.warn(
                     "Variable[{}] to be pruned is not in current graph.".
                     format(param))
                 continue
-            group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
+            group = collect_convs([param], graph,
+                                  visited)[0]  # [(name, axis, pruned_idx)]
             if group is None or len(group) == 0:
                 continue
             if only_graph and self.idx_selector.__name__ == "default_idx_selector":
@@ -103,30 +105,33 @@ class Pruner():
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, axis in group:
+                for name, axis, _ in group:
                     pruned_params.append((name, axis, pruned_idx))
             else:
                 assert ((not self.pruned_weights),
                         "The weights have been pruned once.")
                 group_values = []
-                for name, axis in group:
+                for name, axis, pruned_idx in group:
                     values = np.array(scope.find_var(name).get_tensor())
-                    group_values.append((name, values, axis))
-                scores = self.criterion(group_values,
-                                        graph)  # [(name, axis, score)]
+                    group_values.append((name, values, axis, pruned_idx))
+                scores = self.criterion(
+                    group_values, graph)  # [(name, axis, score, pruned_idx)]
                 pruned_params.extend(self.idx_selector(scores, ratio))

         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
+            print("{}\t{}\t{}".format(param, pruned_axis, len(pruned_idx)))
             if param not in merge_pruned_params:
                 merge_pruned_params[param] = {}
             if pruned_axis not in merge_pruned_params[param]:
                 merge_pruned_params[param][pruned_axis] = []
             merge_pruned_params[param][pruned_axis].append(pruned_idx)
+        print("param name: stage.0.conv_layer.conv.weights; idx: {}".format(
+            merge_pruned_params["stage.0.conv_layer.conv.weights"][1]))
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
                 pruned_idx = np.concatenate(merge_pruned_params[param_name][
...
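After selection, `Pruner` merges index lists per `(param, axis)` pair and concatenates them before pruning; the two `print` calls above appear to be debug output left in for one specific YOLOv4 parameter. A self-contained sketch of the merge step with invented entries:

    import numpy as np

    pruned_params = [("convA.weights", 0, [1, 3]), ("convA.weights", 0, [9, 11])]
    merged = {}
    for param, axis, idx in pruned_params:
        merged.setdefault(param, {}).setdefault(axis, []).append(idx)
    for param, by_axis in merged.items():
        for axis, idx_lists in by_axis.items():
            print(param, axis, np.concatenate(idx_lists))  # convA.weights 0 [ 1  3  9 11]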