diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 4f442eebf97a2da266f48ddec12c8223d5264be3..8e65630099413a1825727f4ae4674ea1b2950440 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -81,7 +81,8 @@ class Pruner():
             pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0)
             param = graph.var(param)
             conv_op = param.outputs()[0]
-            walker = conv2d_walker(conv_op, pruned_params=pruned_params, visited=visited)
+            walker = conv2d_walker(
+                conv_op, pruned_params=pruned_params, visited=visited)
             walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)
 
         merge_pruned_params = {}
@@ -94,19 +95,24 @@ class Pruner():
 
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
-                pruned_idx = np.concatenate(merge_pruned_params[param_name][pruned_axis])
+                pruned_idx = np.concatenate(merge_pruned_params[param_name][
+                    pruned_axis])
                 param = graph.var(param_name)
-                _logger.debug("{}\t{}\t{}".format(param.name(), pruned_axis, len(pruned_idx)))
-                if param_shape_backup is not None:
-                    origin_shape = copy.deepcopy(param.shape())
-                    param_shape_backup[param.name()] = origin_shape
-                new_shape = list(param.shape())
-                new_shape[pruned_axis] -= len(pruned_idx)
-                param.set_shape(new_shape)
+                if not lazy:
+                    _logger.debug("{}\t{}\t{}".format(param.name(
+                    ), pruned_axis, len(pruned_idx)))
+                    if param_shape_backup is not None:
+                        origin_shape = copy.deepcopy(param.shape())
+                        param_shape_backup[param.name()] = origin_shape
+                    new_shape = list(param.shape())
+                    new_shape[pruned_axis] -= len(pruned_idx)
+                    param.set_shape(new_shape)
                 if not only_graph:
                     param_t = scope.find_var(param.name()).get_tensor()
-                    if param_backup is not None and (param.name() not in param_backup):
-                        param_backup[param.name()] = copy.deepcopy(np.array(param_t))
+                    if param_backup is not None and (
+                            param.name() not in param_backup):
+                        param_backup[param.name()] = copy.deepcopy(
+                            np.array(param_t))
                     try:
                         pruned_param = self._prune_tensor(
                             np.array(param_t),
@@ -114,11 +120,11 @@ class Pruner():
                             pruned_axis=pruned_axis,
                             lazy=lazy)
                     except IndexError as e:
-                        _logger.error("Pruning {}, but get [{}]".format(param.name(
-                        ), e))
-
+                        _logger.error("Pruning {}, but get [{}]".format(
+                            param.name(), e))
+
                     param_t.set(pruned_param, place)
-
+        graph.update_groups_of_conv()
         return graph.program, param_backup, param_shape_backup
 
     def _cal_pruned_idx(self, param, ratio, axis):
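
Beyond the line-wrapping reflow, the behavioral change in this diff is that the shape bookkeeping (`param_shape_backup` and `param.set_shape`) now runs only when `not lazy`, and `graph.update_groups_of_conv()` is called once before returning. The following is a minimal numpy sketch, not PaddleSlim's actual `_prune_tensor` (the `prune_tensor` helper below is hypothetical), illustrating why lazy pruning needs no shape update: it zeroes the pruned slices while keeping the tensor's shape, whereas eager pruning physically deletes them along the axis, so the new shape must be written back to the graph.

```python
import numpy as np

def prune_tensor(tensor, pruned_idx, pruned_axis, lazy=False):
    """Hypothetical stand-in for Pruner._prune_tensor, for illustration only."""
    if lazy:
        # Lazy mode: keep the original shape and zero out the pruned
        # slices, so no param.set_shape() / shape backup is needed.
        out = tensor.copy()
        index = [slice(None)] * tensor.ndim
        index[pruned_axis] = pruned_idx
        out[tuple(index)] = 0
        return out
    # Eager mode: physically remove the slices, shrinking pruned_axis,
    # which is why the pruned shape must be propagated to the graph.
    return np.delete(tensor, pruned_idx, axis=pruned_axis)

w = np.ones((8, 16, 3, 3))                          # conv weight: [out_c, in_c, kh, kw]
print(prune_tensor(w, [0, 3], 0).shape)             # (6, 16, 3, 3)
print(prune_tensor(w, [0, 3], 0, lazy=True).shape)  # (8, 16, 3, 3), channels 0 and 3 zeroed
```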