diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index cd79f5b286bbb34d1d688ce515691fdfc7e8f730..21b66a2ac9f1bc4c4acbcf668ef09cdba2b2148a 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import numpy as np
 import paddle.fluid as fluid
 import copy
 from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..common import get_logger
 
 __all__ = ["Pruner"]
 
+_logger = get_logger(__name__, level=logging.INFO)
+
 
 class Pruner():
     def __init__(self, criterion="l1_norm"):
@@ -69,6 +73,10 @@ class Pruner():
                 only_graph=only_graph,
                 param_backup=param_backup,
                 param_shape_backup=param_shape_backup)
+        for op in graph.ops():
+            if op.type() == 'depthwise_conv2d' or op.type(
+            ) == 'depthwise_conv2d_grad':
+                op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
         return graph.program
 
     def _prune_filters_by_ratio(self,
@@ -113,6 +121,8 @@ class Pruner():
             new_shape = list(param.shape())
             new_shape[0] = pruned_param.shape[0]
             param.set_shape(new_shape)
+            _logger.info("prune [{}] from {} to {}".format(param.name(
+            ), ori_shape, new_shape))
             self.pruned_list[0].append(param.name())
         return pruned_idx
 
@@ -158,6 +168,8 @@ class Pruner():
             new_shape = list(param.shape())
             new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
             param.set_shape(new_shape)
+            _logger.info("prune [{}] from {} to {}".format(param.name(
+            ), ori_shape, new_shape))
             self.pruned_list[pruned_axis].append(param.name())
 
     def _forward_search_related_op(self, graph, param):
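
Note on the `depthwise_conv2d` fix above: for a depthwise convolution, `groups` equals the number of filters (one filter per input channel), so after filters are pruned a stale `groups` attribute would no longer match the filter tensor and the op would be invalid. The loop added to `prune()` re-derives `groups` from the pruned filter's shape. Below is a minimal standalone sketch of that invariant; the `graph`/`op` accessors mirror the GraphWrapper API used in the diff, while the helper name itself is hypothetical and not part of the patch:

```python
# Hypothetical sketch (not part of the patch) of the invariant the new
# loop in prune() restores. Assumes a GraphWrapper-like `graph` exposing
# ops() whose elements support type()/inputs()/set_attr(), as above.
def refresh_depthwise_groups(graph):
    for op in graph.ops():
        if op.type() in ('depthwise_conv2d', 'depthwise_conv2d_grad'):
            # Filter shape is [out_channels, in_channels // groups, k_h, k_w];
            # for a depthwise conv out_channels == groups, so the pruned
            # filter's leading dimension gives the new group count.
            op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
```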