Commit 00ba112c authored by: W wanghaoshuang

Fix pruner when in only_graph mode.

Parent afdbc1ff
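For context, `only_graph=True` asks the pruner to update only the variable shapes recorded in the graph. The old code still called `scope.find_var(...)` on the parameter before checking the flag, which presumably fails when no executor scope has been populated; the fix moves all scope/tensor access into the `else` branch and, in graph-only mode, derives the pruned filter count purely from the recorded shape. A minimal sketch of that arithmetic (the helper name is hypothetical, not part of PaddleSlim):

```python
def remaining_dim(dim_size, ratio):
    # Same rounding as the only_graph branch in the diff below.
    pruned_num = int(round(dim_size * ratio))
    return dim_size - pruned_num

assert remaining_dim(64, 0.25) == 48  # a conv layer keeps 48 of its 64 filters
```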
@@ -102,29 +102,49 @@ class Pruner():
"""
if params[0].name() in self.pruned_list[0]:
return
param_t = scope.find_var(params[0].name()).get_tensor()
pruned_idx = self._cal_pruned_idx(
params[0].name(), np.array(param_t), ratio, axis=0)
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
if not only_graph:
if only_graph:
pruned_num = int(round(params[0].shape()[0] * ratio))
for param in params:
ori_shape = param.shape()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[0] -= pruned_num
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return range(pruned_num)
else:
param_t = scope.find_var(params[0].name()).get_tensor()
pruned_idx = self._cal_pruned_idx(
params[0].name(), np.array(param_t), ratio, axis=0)
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(param.shape())
new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return pruned_idx
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return pruned_idx
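The new `only_graph` branch above never touches the scope: it computes `pruned_num` from `params[0].shape()[0]`, backs up the original shape (rather than the tensor values) into `param_backup`, shrinks axis 0 of every related parameter, and returns `range(pruned_num)` as the pruned indices. A hedged, self-contained sketch of that path; `FakeParam` and `prune_shapes_only` are illustrative stand-ins, not PaddleSlim APIs:

```python
import copy

class FakeParam:
    """Toy stand-in for the graph's VarWrapper (not PaddleSlim API)."""
    def __init__(self, name, shape):
        self._name, self._shape = name, list(shape)
    def name(self):
        return self._name
    def shape(self):
        return list(self._shape)
    def set_shape(self, shape):
        self._shape = list(shape)

def prune_shapes_only(params, ratio, param_backup=None):
    # Mirrors the only_graph branch: no scope lookup, no tensor writes.
    pruned_num = int(round(params[0].shape()[0] * ratio))
    for param in params:
        ori_shape = param.shape()
        if param_backup is not None and param.name() not in param_backup:
            # In this mode the backup holds the original shape, not the weights.
            param_backup[param.name()] = copy.deepcopy(ori_shape)
        new_shape = ori_shape
        new_shape[0] -= pruned_num
        param.set_shape(new_shape)
    return range(pruned_num)

w = FakeParam("conv1_weights", [64, 3, 3, 3])
backup = {}
idx = prune_shapes_only([w], ratio=0.5, param_backup=backup)
assert w.shape() == [32, 3, 3, 3]
assert backup["conv1_weights"] == [64, 3, 3, 3]
assert list(idx) == list(range(32))
```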
def _prune_parameter_by_idx(self,
scope,
@@ -151,26 +171,44 @@ class Pruner():
"""
if params[0].name() in self.pruned_list[pruned_axis]:
return
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
if not only_graph:
if only_graph:
pruned_num = len(pruned_idx)
for param in params:
ori_shape = param.shape()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[pruned_axis] -= pruned_num
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
else:
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(param.shape())
new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
ori_shape = param.shape()
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape)
_logger.info("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
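`_prune_parameter_by_idx` gets the same treatment: in graph-only mode the pruned count is simply `len(pruned_idx)`, only the shape along `pruned_axis` is reduced, and shapes rather than tensors are backed up. A small sketch of that shape arithmetic (the helper is hypothetical):

```python
def shrink_shape(ori_shape, pruned_idx, pruned_axis):
    # Drop len(pruned_idx) entries along pruned_axis, as in the only_graph branch.
    new_shape = list(ori_shape)
    new_shape[pruned_axis] -= len(pruned_idx)
    return new_shape

# e.g. pruning three input channels (axis 1) of a downstream conv weight
assert shrink_shape([128, 64, 3, 3], [0, 7, 9], pruned_axis=1) == [128, 61, 3, 3]
```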
def _forward_search_related_op(self, graph, param):
"""
@@ -500,14 +538,16 @@ class Pruner():
visited.append(op.idx())
while len(stack) > 0:
top_op = stack.pop()
for parent in graph.pre_ops(top_op):
if parent.idx() not in visited and (not parent.is_bwd_op()):
if ((parent.type() == 'conv2d') or
(parent.type() == 'fc')):
brothers.append(parent)
else:
stack.append(parent)
visited.append(parent.idx())
if top_op.type().startswith("elementwise_"):
for parent in graph.pre_ops(top_op):
if parent.idx() not in visited and (
not parent.is_bwd_op()):
if ((parent.type() == 'conv2d') or
(parent.type() == 'fc')):
brothers.append(parent)
else:
stack.append(parent)
visited.append(parent.idx())
for child in graph.next_ops(top_op):
if (child.type() != 'conv2d') and (child.type() != 'fc') and (
......
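The last hunk narrows the backward search for "brother" layers: parents reached through `graph.pre_ops()` are now only collected when the op being visited is an `elementwise_*` op. The rationale, sketched below with NumPy (shapes are made up for illustration), is that only element-wise merges such as the add in a residual branch force their `conv2d`/`fc` producers to keep identical channel counts and therefore to share one pruning mask:

```python
import numpy as np

a = np.zeros((1, 48, 8, 8))  # branch 1: conv pruned from 64 to 48 output channels
b = np.zeros((1, 48, 8, 8))  # branch 2: its conv must be pruned to 48 as well
_ = a + b                    # elementwise_add only works if the shapes match
```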