未验证 提交 b81f27a1 编写于 作者: W whs 提交者: GitHub

Fix pruning for yolov4 (#313)

上级 44e359c4
......@@ -43,11 +43,11 @@ def l1_norm(group, graph):
list: A list of tuple storing l1-norm on given axis.
"""
scores = []
for name, value, axis in group:
for name, value, axis, pruned_idx in group:
reduce_dims = [i for i in range(len(value.shape)) if i != axis]
score = np.sum(np.abs(value), axis=tuple(reduce_dims))
scores.append((name, axis, score))
scores.append((name, axis, score, pruned_idx))
return scores
......@@ -55,7 +55,7 @@ def l1_norm(group, graph):
@CRITERION.register
def geometry_median(group, graph):
scores = []
name, value, axis = group[0]
name, value, axis, _ = group[0]
assert (len(value.shape) == 4)
def get_distance_sum(value, out_idx):
......@@ -73,8 +73,8 @@ def geometry_median(group, graph):
tmp = np.array(dist_sum_list)
for name, value, axis in group:
scores.append((name, axis, tmp))
for name, value, axis, idx in group:
scores.append((name, axis, tmp, idx))
return scores
......@@ -97,7 +97,7 @@ def bn_scale(group, graph):
assert (isinstance(graph, GraphWrapper))
# step1: Get first convolution
conv_weight, value, axis = group[0]
conv_weight, value, axis, _ = group[0]
param_var = graph.var(conv_weight)
conv_op = param_var.outputs()[0]
......@@ -111,12 +111,12 @@ def bn_scale(group, graph):
# steps3: Find scale of bn
score = None
for name, value, aixs in group:
for name, value, aixs, _ in group:
if bn_scale_param == name:
score = np.abs(value.reshape([-1]))
scores = []
for name, value, axis in group:
scores.append((name, axis, score))
for name, value, axis, idx in group:
scores.append((name, axis, score, idx))
return scores
......@@ -57,21 +57,21 @@ def collect_convs(params, graph, visited={}):
conv_op = param.outputs()[0]
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[])
walker.prune(param, pruned_axis=0, pruned_idx=[0])
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
repeat_group = False
simple_group = []
for param, axis, _ in group:
for param, axis, pruned_idx in group:
param = param.name()
if axis == 0:
if param in visited:
repeat_group = True
else:
visited.add(param)
simple_group.append((param, axis))
simple_group.append((param, axis, pruned_idx))
if not repeat_group:
uniq_groups.append(simple_group)
......
......@@ -52,7 +52,7 @@ def default_idx_selector(group, ratio):
list: pruned indexes
"""
name, axis, score = group[
name, axis, score, _ = group[
0] # sort channels by the first convolution's score
sorted_idx = score.argsort()
......@@ -60,8 +60,9 @@ def default_idx_selector(group, ratio):
pruned_idx = sorted_idx[:pruned_num]
idxs = []
for name, axis, score in group:
idxs.append((name, axis, pruned_idx))
for name, axis, score, offsets in group:
r_idx = [i + offsets[0] for i in pruned_idx]
idxs.append((name, axis, r_idx))
return idxs
......
......@@ -77,9 +77,10 @@ class PruneWorker(object):
if op.type() in SKIP_OPS:
_logger.warn("Skip operator [{}]".format(op.type()))
return
_logger.warn(
"{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
format(op.type()))
# _logger.warn(
# "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
# format(op.type()))
cls = PRUNE_WORKER.get("default_walker")
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
......@@ -263,6 +264,8 @@ class elementwise_op(PruneWorker):
if name == "Y":
actual_axis = pruned_axis - axis
in_var = self.op.inputs(name)[0]
if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
continue
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, actual_axis, pruned_idx)
......@@ -270,15 +273,17 @@ class elementwise_op(PruneWorker):
else:
if var in self.op.inputs("X"):
in_var = self.op.inputs("Y")[0]
if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
if in_var.is_parameter():
self.pruned_params.append(
(in_var, pruned_axis - axis, pruned_idx))
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
self._prune_op(op, in_var, pruned_axis - axis,
pruned_idx)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
pre_ops = in_var.inputs()
pruned_axis = pruned_axis + axis
for op in pre_ops:
......
......@@ -90,12 +90,14 @@ class Pruner():
visited = {}
pruned_params = []
for param, ratio in zip(params, ratios):
_logger.info("pruning: {}".format(param))
if graph.var(param) is None:
_logger.warn(
"Variable[{}] to be pruned is not in current graph.".
format(param))
continue
group = collect_convs([param], graph, visited)[0] # [(name, axis)]
group = collect_convs([param], graph,
visited)[0] # [(name, axis, pruned_idx)]
if group is None or len(group) == 0:
continue
if only_graph and self.idx_selector.__name__ == "default_idx_selector":
......@@ -103,30 +105,33 @@ class Pruner():
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
for name, axis in group:
for name, axis, _ in group:
pruned_params.append((name, axis, pruned_idx))
else:
assert ((not self.pruned_weights),
"The weights have been pruned once.")
group_values = []
for name, axis in group:
for name, axis, pruned_idx in group:
values = np.array(scope.find_var(name).get_tensor())
group_values.append((name, values, axis))
group_values.append((name, values, axis, pruned_idx))
scores = self.criterion(group_values,
graph) # [(name, axis, score)]
scores = self.criterion(
group_values, graph) # [(name, axis, score, pruned_idx)]
pruned_params.extend(self.idx_selector(scores, ratio))
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
print("{}\t{}\t{}".format(param, pruned_axis, len(pruned_idx)))
if param not in merge_pruned_params:
merge_pruned_params[param] = {}
if pruned_axis not in merge_pruned_params[param]:
merge_pruned_params[param][pruned_axis] = []
merge_pruned_params[param][pruned_axis].append(pruned_idx)
print("param name: stage.0.conv_layer.conv.weights; idx: {}".format(
merge_pruned_params["stage.0.conv_layer.conv.weights"][1]))
for param_name in merge_pruned_params:
for pruned_axis in merge_pruned_params[param_name]:
pruned_idx = np.concatenate(merge_pruned_params[param_name][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册