From 00ba112c3d0a288d3888023d102e3121dd816541 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 19 Nov 2019 22:37:35 +0800
Subject: [PATCH] Fix pruner when in only_graph mode.

---
 paddleslim/prune/pruner.py | 138 ++++++++++++++++++++++++-------------
 1 file changed, 89 insertions(+), 49 deletions(-)

diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 21b66a2a..c7cc9c9e 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -102,29 +102,49 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[0]:
             return
-        param_t = scope.find_var(params[0].name()).get_tensor()
-        pruned_idx = self._cal_pruned_idx(
-            params[0].name(), np.array(param_t), ratio, axis=0)
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
-            if not only_graph:
+
+        if only_graph:
+            pruned_num = int(round(params[0].shape()[0] * ratio))
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[0] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return range(pruned_num)
+
+        else:
+
+            param_t = scope.find_var(params[0].name()).get_tensor()
+            pruned_idx = self._cal_pruned_idx(
+                params[0].name(), np.array(param_t), ratio, axis=0)
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
                 param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
-            new_shape = list(param.shape())
-            new_shape[0] = pruned_param.shape[0]
-            param.set_shape(new_shape)
-            _logger.info("prune [{}] from {} to {}".format(param.name(
-            ), ori_shape, new_shape))
-            self.pruned_list[0].append(param.name())
-        return pruned_idx
+                ori_shape = param.shape()
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[0] = pruned_param.shape[0]
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return pruned_idx
 
     def _prune_parameter_by_idx(self,
                                 scope,
@@ -151,26 +171,44 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[pruned_axis]:
             return
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
-            if not only_graph:
+
+        if only_graph:
+            pruned_num = len(pruned_idx)
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[pruned_axis] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[pruned_axis].append(param.name())
+
+        else:
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
                 param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] = copy.deepcopy(param.shape())
-            new_shape = list(param.shape())
-            new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
-            param.set_shape(new_shape)
-            _logger.info("prune [{}] from {} to {}".format(param.name(
-            ), ori_shape, new_shape))
-            self.pruned_list[pruned_axis].append(param.name())
+                ori_shape = param.shape()
+
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[pruned_axis].append(param.name())
 
     def _forward_search_related_op(self, graph, param):
         """
@@ -500,14 +538,16 @@ class Pruner():
                 visited.append(op.idx())
         while len(stack) > 0:
             top_op = stack.pop()
-            for parent in graph.pre_ops(top_op):
-                if parent.idx() not in visited and (not parent.is_bwd_op()):
-                    if ((parent.type() == 'conv2d') or
-                        (parent.type() == 'fc')):
-                        brothers.append(parent)
-                    else:
-                        stack.append(parent)
-                    visited.append(parent.idx())
+            if top_op.type().startswith("elementwise_"):
+                for parent in graph.pre_ops(top_op):
+                    if parent.idx() not in visited and (
+                            not parent.is_bwd_op()):
+                        if ((parent.type() == 'conv2d') or
+                            (parent.type() == 'fc')):
+                            brothers.append(parent)
+                        else:
+                            stack.append(parent)
+                        visited.append(parent.idx())
 
             for child in graph.next_ops(top_op):
                 if (child.type() != 'conv2d') and (child.type() != 'fc') and (
--
GitLab
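For reference, a minimal standalone sketch of the shape-only arithmetic the new only_graph branches rely on: the number of pruned channels comes either from the ratio (as in _prune_filters_by_ratio) or from the length of the pruned index list (as in _prune_parameter_by_idx), and only the recorded variable shape changes, while tensor data in the scope is left untouched. It is not part of the patch; the helper name shrink_shape and the demo shapes are hypothetical.

# Illustrative sketch only; mirrors the only_graph logic of the patch
# without depending on Paddle. `shrink_shape` and the demo values below
# are hypothetical names used for illustration.


def shrink_shape(shape, pruned_axis, ratio=None, pruned_idx=None):
    """Compute the post-pruning shape without touching any tensor data.

    `ratio` corresponds to the ratio-based path
    (pruned_num = int(round(shape[0] * ratio))), while `pruned_idx`
    corresponds to the index-based path (pruned_num = len(pruned_idx)).
    """
    if ratio is not None:
        pruned_num = int(round(shape[pruned_axis] * ratio))
    else:
        pruned_num = len(pruned_idx)
    new_shape = list(shape)
    new_shape[pruned_axis] -= pruned_num
    # The ratio-based path in the patch returns range(pruned_num) so that
    # dependent parameters can be pruned consistently; we mimic that here.
    return new_shape, list(range(pruned_num))


if __name__ == "__main__":
    # 64 output filters pruned by 25% on axis 0: only the shape changes.
    print(shrink_shape([64, 3, 3, 3], pruned_axis=0, ratio=0.25))
    # Pruning 10 given channels on axis 1 of an FC weight.
    print(shrink_shape([128, 100], pruned_axis=1, pruned_idx=list(range(10))))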