Commit ab25d262 authored by wanghaoshuang

Merge branch 'fix_prune' into 'develop'

Fix pruner in only_graph mode.

See merge request !33
@@ -102,29 +102,49 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[0]:
             return
-        param_t = scope.find_var(params[0].name()).get_tensor()
-        pruned_idx = self._cal_pruned_idx(
-            params[0].name(), np.array(param_t), ratio, axis=0)
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
-            if not only_graph:
+
+        if only_graph:
+            pruned_num = int(round(params[0].shape()[0] * ratio))
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[0] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return range(pruned_num)
+        else:
+            param_t = scope.find_var(params[0].name()).get_tensor()
+            pruned_idx = self._cal_pruned_idx(
+                params[0].name(), np.array(param_t), ratio, axis=0)
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
                 param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
-            new_shape = list(param.shape())
-            new_shape[0] = pruned_param.shape[0]
-            param.set_shape(new_shape)
-            _logger.info("prune [{}] from {} to {}".format(param.name(
-            ), ori_shape, new_shape))
-            self.pruned_list[0].append(param.name())
-        return pruned_idx
+                ori_shape = param.shape()
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[0] = pruned_param.shape[0]
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return pruned_idx
 
     def _prune_parameter_by_idx(self,
                                 scope,
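Note on the hunk above: the new `only_graph` branch in `_prune_parameter` never reads parameter tensors from the scope; it only derives a pruned count from the ratio and shrinks axis 0 of every related shape, then hands back placeholder indices via `return range(pruned_num)`. The following is a minimal standalone sketch of that arithmetic; the helper name and the example shapes are illustrative, not part of PaddleSlim.

```python
# Shape-only pruning arithmetic, mirroring the only_graph branch.
# `prune_shape_by_ratio` is a hypothetical helper for illustration.


def prune_shape_by_ratio(shapes, ratio):
    """Shrink axis 0 of every related parameter shape by round(dim0 * ratio)."""
    pruned_num = int(round(shapes[0][0] * ratio))  # derived from the first param
    new_shapes = []
    for ori_shape in shapes:
        new_shape = list(ori_shape)
        new_shape[0] -= pruned_num  # only the graph's shape changes, no data is touched
        new_shapes.append(new_shape)
    # callers get placeholder indices, mirroring `return range(pruned_num)`
    return new_shapes, list(range(pruned_num))


if __name__ == "__main__":
    # e.g. a conv2d weight [64, 3, 3, 3] and its bias [64], pruned by 50%
    print(prune_shape_by_ratio([[64, 3, 3, 3], [64]], 0.5))
    # -> ([[32, 3, 3, 3], [32]], [0, 1, ..., 31])
```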
@@ -151,26 +171,44 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[pruned_axis]:
             return
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
-            if not only_graph:
+
+        if only_graph:
+            pruned_num = len(pruned_idx)
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[pruned_axis] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[pruned_axis].append(param.name())
+        else:
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
                 param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
-            new_shape = list(param.shape())
-            new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
-            param.set_shape(new_shape)
-            _logger.info("prune [{}] from {} to {}".format(param.name(
-            ), ori_shape, new_shape))
-            self.pruned_list[pruned_axis].append(param.name())
+                ori_shape = param.shape()
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[pruned_axis].append(param.name())
 
     def _forward_search_related_op(self, graph, param):
         """
@@ -500,14 +538,16 @@ class Pruner():
         visited.append(op.idx())
         while len(stack) > 0:
             top_op = stack.pop()
-            for parent in graph.pre_ops(top_op):
-                if parent.idx() not in visited and (not parent.is_bwd_op()):
-                    if ((parent.type() == 'conv2d') or
-                        (parent.type() == 'fc')):
-                        brothers.append(parent)
-                    else:
-                        stack.append(parent)
-                    visited.append(parent.idx())
+            if top_op.type().startswith("elementwise_"):
+                for parent in graph.pre_ops(top_op):
+                    if parent.idx() not in visited and (
+                            not parent.is_bwd_op()):
+                        if ((parent.type() == 'conv2d') or
+                            (parent.type() == 'fc')):
+                            brothers.append(parent)
+                        else:
+                            stack.append(parent)
+                        visited.append(parent.idx())
             for child in graph.next_ops(top_op):
                 if (child.type() != 'conv2d') and (child.type() != 'fc') and (
...
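The last hunk narrows the backward half of the brother search: an op's producers are only considered potential brothers when the op is an elementwise one (e.g. an `elementwise_add` joining two conv outputs), since only then must its inputs keep matching channel counts. Below is a toy traversal over plain dicts, just to illustrate the gating; the op/graph structures are stand-ins for the real GraphWrapper API, and the search is simplified relative to the full method.

```python
# Toy graph: ops are dicts with a type and lists of parent/child indices.
ops = {
    0: {"type": "conv2d", "parents": [], "children": [2]},
    1: {"type": "conv2d", "parents": [], "children": [2]},
    2: {"type": "elementwise_add", "parents": [0, 1], "children": [3]},
    3: {"type": "relu", "parents": [2], "children": []},
}


def find_brothers(start_idx):
    """Collect conv2d/fc ops that must be pruned together with `start_idx`."""
    brothers, visited, stack = [], [start_idx], [start_idx]
    while stack:
        top = stack.pop()
        # Only elementwise ops force their inputs to share channel counts,
        # so only then do we walk backwards to the producing ops.
        if ops[top]["type"].startswith("elementwise_"):
            for parent in ops[top]["parents"]:
                if parent not in visited:
                    if ops[parent]["type"] in ("conv2d", "fc"):
                        brothers.append(parent)
                    else:
                        stack.append(parent)
                    visited.append(parent)
        for child in ops[top]["children"]:
            if ops[child]["type"] not in ("conv2d", "fc") and child not in visited:
                stack.append(child)
                visited.append(child)
    return brothers


print(find_brothers(0))  # conv 0 reaches the add, whose other input conv 1 is a brother
```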