Commit ab25d262 authored by wanghaoshuang

Merge branch 'fix_prune' into 'develop'

Fix pruner in only_graph mode.

See merge request !33
@@ -102,22 +102,42 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[0]:
             return
-        param_t = scope.find_var(params[0].name()).get_tensor()
-        pruned_idx = self._cal_pruned_idx(
-            params[0].name(), np.array(param_t), ratio, axis=0)
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
-            if not only_graph:
-                param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
-            new_shape = list(param.shape())
-            new_shape[0] = pruned_param.shape[0]
-            param.set_shape(new_shape)
+        if only_graph:
+            pruned_num = int(round(params[0].shape()[0] * ratio))
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[0] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return range(pruned_num)
+        else:
+            param_t = scope.find_var(params[0].name()).get_tensor()
+            pruned_idx = self._cal_pruned_idx(
+                params[0].name(), np.array(param_t), ratio, axis=0)
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
+                param_t.set(pruned_param, place)
+                ori_shape = param.shape()
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[0] = pruned_param.shape[0]
+                param.set_shape(new_shape)
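
The new only_graph branch never reads or writes parameter tensors: it derives the number of filters to drop as round(shape[0] * ratio), shrinks axis 0 of every related parameter's shape in the graph, and returns the pruned indices. A minimal standalone sketch of that arithmetic, assuming a FakeParam stand-in for PaddleSlim's VarWrapper (the class and function names here are illustrative, not the real API):

```python
# Sketch only: FakeParam mimics the small part of VarWrapper used above.
class FakeParam:
    def __init__(self, name, shape):
        self._name, self._shape = name, list(shape)

    def name(self):
        return self._name

    def shape(self):
        return self._shape

    def set_shape(self, shape):
        self._shape = list(shape)


def prune_shapes_by_ratio(params, ratio):
    # Number of filters to drop, derived from the leading parameter's axis 0.
    pruned_num = int(round(params[0].shape()[0] * ratio))
    for param in params:
        ori_shape = param.shape()
        new_shape = list(ori_shape)
        new_shape[0] -= pruned_num  # only the shape changes, never the data
        param.set_shape(new_shape)
        print("prune [{}] from {} to {}".format(param.name(), ori_shape, new_shape))
    return range(pruned_num)  # indices of the dropped filters


conv_w = FakeParam("conv1_weights", [64, 3, 3, 3])
conv_b = FakeParam("conv1_bias", [64])
idx = prune_shapes_by_ratio([conv_w, conv_b], ratio=0.5)
print(conv_w.shape(), conv_b.shape(), len(idx))  # [32, 3, 3, 3] [32] 32
```
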
@@ -151,20 +171,38 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[pruned_axis]:
             return
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
-            if not only_graph:
-                param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
-            new_shape = list(param.shape())
-            new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
-            param.set_shape(new_shape)
+        if only_graph:
+            pruned_num = len(pruned_idx)
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[pruned_axis] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[pruned_axis].append(param.name())
+        else:
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
+                param_t.set(pruned_param, place)
+                ori_shape = param.shape()
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
+                param.set_shape(new_shape)
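
This second method follows the same pattern but prunes an explicit index list along an arbitrary pruned_axis, so in only_graph mode the axis simply shrinks by len(pruned_idx). On the non-graph path the net effect on the array is index-based deletion along that axis; the sketch below illustrates the idea with plain NumPy. prune_along_axis is a hypothetical helper, not PaddleSlim's _prune_tensor, and its lazy flag reflects our reading that lazy pruning keeps the shape and only zeroes the pruned slices:

```python
import numpy as np


def prune_along_axis(array, pruned_idx, pruned_axis, lazy=False):
    """Drop (or, if lazy, zero out) the given indices along one axis."""
    if lazy:
        out = array.copy()
        index = [slice(None)] * array.ndim
        index[pruned_axis] = pruned_idx
        out[tuple(index)] = 0  # shape unchanged, pruned slices zeroed
        return out
    # actually remove the slices, shrinking that axis by len(pruned_idx)
    return np.delete(array, pruned_idx, axis=pruned_axis)


w = np.arange(24).reshape(4, 6)
print(prune_along_axis(w, [1, 3], pruned_axis=0).shape)          # (2, 6)
print(prune_along_axis(w, [0, 2, 5], pruned_axis=1).shape)       # (4, 3)
print(prune_along_axis(w, [1], pruned_axis=0, lazy=True).shape)  # (4, 6), row 1 zeroed
```
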
@@ -500,8 +538,10 @@ class Pruner():
                 visited.append(op.idx())
         while len(stack) > 0:
             top_op = stack.pop()
-            for parent in graph.pre_ops(top_op):
-                if parent.idx() not in visited and (not parent.is_bwd_op()):
-                    if ((parent.type() == 'conv2d') or
-                        (parent.type() == 'fc')):
-                        brothers.append(parent)
+            if top_op.type().startswith("elementwise_"):
+                for parent in graph.pre_ops(top_op):
+                    if parent.idx() not in visited and (
+                            not parent.is_bwd_op()):
+                        if ((parent.type() == 'conv2d') or
+                            (parent.type() == 'fc')):
+                            brothers.append(parent)
......
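
The last hunk narrows the backward search for "brother" layers: parents are now inspected only when the popped op is an elementwise_* op, so conv2d/fc layers are grouped for joint pruning only when they feed the same elementwise op. A toy sketch of that guard, using a hypothetical Op class instead of the real OpWrapper/GraphWrapper API and omitting the is_bwd_op check:

```python
# Toy illustration only; not PaddleSlim's graph API.
class Op:
    def __init__(self, idx, type_, inputs=()):
        self.idx, self.type, self.inputs = idx, type_, list(inputs)


def search_brothers(start_ops):
    """Collect conv2d/fc ops that feed an elementwise op reachable from start_ops."""
    stack = list(start_ops)
    visited = {op.idx for op in start_ops}
    brothers = []
    while stack:
        top_op = stack.pop()
        # only hop backwards across elementwise_* ops (e.g. elementwise_add)
        if top_op.type.startswith("elementwise_"):
            for parent in top_op.inputs:
                if parent.idx not in visited:
                    visited.add(parent.idx)
                    if parent.type in ("conv2d", "fc"):
                        brothers.append(parent)
    return brothers


conv_a = Op(0, "conv2d")
conv_b = Op(1, "conv2d")
add = Op(2, "elementwise_add", inputs=[conv_a, conv_b])
relu = Op(3, "relu", inputs=[conv_a])
print([op.idx for op in search_brothers([add, relu])])  # [0, 1]
```
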