From b849d60993376fc4aa6e0eedde3fba6a3468db01 Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 27 Apr 2021 16:01:06 +0800 Subject: [PATCH] Fix pruning group conv2d (#720) --- paddleslim/core/graph_wrapper.py | 34 +- paddleslim/dygraph/prune/filter_pruner.py | 128 ++--- paddleslim/dygraph/prune/fpgm_pruner.py | 28 +- paddleslim/dygraph/prune/l1norm_pruner.py | 33 +- paddleslim/dygraph/prune/l2norm_pruner.py | 36 +- paddleslim/dygraph/prune/pruner.py | 3 +- paddleslim/dygraph/prune/pruning_plan.py | 54 +- paddleslim/dygraph/prune/var_group.py | 57 +- paddleslim/prune/__init__.py | 13 +- paddleslim/prune/collections.py | 221 ++++++++ paddleslim/prune/criterion.py | 95 ++-- paddleslim/prune/group_param.py | 104 ---- paddleslim/prune/idx_selector.py | 108 ++-- .../{prune_walker.py => prune_worker.py} | 450 ++++++++++----- paddleslim/prune/pruner.py | 113 ++-- paddleslim/prune/sensitive.py | 3 +- tests/dygraph/test_filter_pruner.py | 65 ++- tests/dygraph/test_prune.py | 2 +- tests/dygraph/test_prune_walker.py | 2 +- tests/test_dygraph_pruning_plan.py | 6 +- tests/test_group_param.py | 18 +- tests/test_prune_walker.py | 531 ++++++++++++++++-- tests/test_sensitivity.py | 1 + 23 files changed, 1423 insertions(+), 682 deletions(-) create mode 100644 paddleslim/prune/collections.py delete mode 100644 paddleslim/prune/group_param.py rename paddleslim/prune/{prune_walker.py => prune_worker.py} (58%) diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py index 785286e3..86fb1736 100644 --- a/paddleslim/core/graph_wrapper.py +++ b/paddleslim/core/graph_wrapper.py @@ -197,7 +197,10 @@ class OpWrapper(object): bool|int|str|float|list: The attribute value. The return value can be any valid attribute type. """ - return self._op.attr(name) + if self._op.has_attr(name): + return self._op.attr(name) + else: + return None class GraphWrapper(object): @@ -365,35 +368,6 @@ class GraphWrapper(object): Update the groups of convolution layer according to current filters. It is used after loading pruned parameters from file. """ - head_op = [] - visited = [] for op in self.ops(): if op.type() != 'conditional_block': - if len(self.pre_ops(op)) == 0: - head_op.append(op) - candidate_op = self.ops() - - def recursive_infer(op, infer=False): - if op in candidate_op: - if op.type() != 'conditional_block': - if infer: - op._op.desc.infer_shape(op._op.block.desc) - else: - visited.append(op) - candidate_op.remove(op) - for next_op in self.next_ops(op): - recursive_infer(next_op) - - # Find ops which not in the DAG, some ops, such as optimizer op, - # should be infered before normal cumputation ops. - for op in head_op: - recursive_infer(op, infer=False) - - # Infer ops which not in the DAG firstly. - candidate_op = self.ops() - for op in candidate_op: - if op not in visited and op.type() != 'conditional_block': op._op.desc.infer_shape(op._op.block.desc) - # Infer the remain ops in topological order. 
- for op in head_op: - recursive_infer(op, infer=True) diff --git a/paddleslim/dygraph/prune/filter_pruner.py b/paddleslim/dygraph/prune/filter_pruner.py index 1a90e251..7ff4316a 100644 --- a/paddleslim/dygraph/prune/filter_pruner.py +++ b/paddleslim/dygraph/prune/filter_pruner.py @@ -9,14 +9,14 @@ from .var_group import * from .pruning_plan import * from .pruner import Pruner from paddleslim.analysis import dygraph_flops as flops -from .var_group import VarGroup +from .var_group import DygraphPruningCollections __all__ = ['Status', 'FilterPruner'] _logger = get_logger(__name__, logging.INFO) CONV_OP_TYPE = paddle.nn.Conv2D -FILTER_DIM = [0] +FILTER_DIM = 0 CONV_WEIGHT_NAME = "weight" SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose) @@ -59,16 +59,17 @@ class FilterPruner(Pruner): def __init__(self, model, inputs, sen_file=None): super(FilterPruner, self).__init__(model, inputs) self._status = Status(sen_file) - # sensitive and var_group are just used in filter pruning - self.var_group = VarGroup(model, inputs) + # sensitive and collections are just used in filter pruning + self.collections = DygraphPruningCollections(model, inputs) # skip vars in: # 1. depthwise conv2d layer self.skip_vars = [] for sub_layer in model.sublayers(): - if isinstance(sub_layer, SKIP_LAYERS) or (isinstance( - sub_layer, paddle.nn.layer.conv.Conv2D) and - sub_layer._groups > 1): + #if isinstance(sub_layer, SKIP_LAYERS) or (isinstance( + # sub_layer, paddle.nn.layer.conv.Conv2D) and + # sub_layer._groups > 1): + if isinstance(sub_layer, SKIP_LAYERS): for param in sub_layer.parameters(): self.skip_vars.append(param.name) @@ -170,11 +171,11 @@ class FilterPruner(Pruner): break return ratios - def _round_to(self, ratios, dims=[0], factor=8): + def _round_to(self, ratios, dims=0, factor=8): ret = {} for name in ratios: ratio = ratios[name] - dim = self._var_shapes[name][dims[0]] + dim = self._var_shapes[name][dims] remained = round((1 - ratio) * dim / factor) * factor if remained == 0: remained = factor @@ -186,14 +187,14 @@ class FilterPruner(Pruner): def get_ratios_by_sensitivity(self, pruned_flops, align=None, - dims=[0], + dims=0, skip_vars=[]): """ Get a group of ratios by sensitivities. Args: pruned_flops(float): The excepted rate of FLOPs to be pruned. It should be in range (0, 1). align(int, optional): Round the size of each pruned dimension to multiple of 'align' if 'align' is not None. Default: None. - dims(list, optional): The dims to be pruned on. [0] means pruning channels of output for convolution. Default: [0]. + dims(int, optional): The dims to be pruned on. 0 means pruning channels of output for convolution. Default: 0. skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None. Returns: @@ -201,7 +202,7 @@ class FilterPruner(Pruner): """ base_flops = flops(self.model, self.inputs) - _logger.debug("Base FLOPs: {}".format(base_flops)) + _logger.info("Base FLOPs: {}".format(base_flops)) low = 0. 
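        # The search below repeatedly picks a point between `low` and `up`
        # (an acceptable sensitivity threshold), converts it into per-layer
        # ratios, prunes, measures FLOPs and then restores the model,
        # narrowing the interval until the measured pruned-FLOPs rate is
        # close enough to the target `pruned_flops`.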
up = 1.0 history = set() @@ -214,7 +215,6 @@ class FilterPruner(Pruner): ratios = self._round_to(ratios, dims=dims, factor=align) plan = self.prune_vars(ratios, axis=dims) c_flops = flops(self.model, self.inputs) - _logger.debug("FLOPs after pruning: {}".format(c_flops)) c_pruned_flops = (base_flops - c_flops) / base_flops plan.restore(self.model) _logger.debug("Seaching ratios, pruned FLOPs: {}".format( @@ -240,10 +240,9 @@ class FilterPruner(Pruner): sensitivities = self._status.sensitivies baseline = None ratios = np.arange(0.1, 1, step=0.1) - for group in self.var_group.groups: - var_name = group[0][0] - dims = group[0][1] - + for _collection in self.collections: + var_name = _collection.master_name + dims = _collection.master_axis if target_vars is not None and var_name not in target_vars: continue if skip_vars is not None and var_name in skip_vars: @@ -282,7 +281,6 @@ class FilterPruner(Pruner): self.restore() ratios, pruned_flops = self.get_ratios_by_sensitivity( pruned_flops, align=align, dims=FILTER_DIM, skip_vars=skip_vars) - _logger.debug("ratios: {}".format(ratios)) self.plan = self.prune_vars(ratios, FILTER_DIM) self.plan._pruned_flops = pruned_flops return self.plan @@ -291,73 +289,60 @@ class FilterPruner(Pruner): if self.plan is not None: self.plan.restore(self.model) - def cal_mask(self, var_name, pruned_ratio, group): - """ - - { - var_name: { - 'layer': sub_layer, - 'var': variable, - 'value': np.array([]), - 'pruned_dims': [1], - } - } - """ + def cal_mask(self, pruned_ratio, collection): raise NotImplemented("cal_mask is not implemented") - def prune_var(self, var_name, pruned_dims, pruned_ratio, apply="impretive"): + def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"): """ Pruning a variable. Parameters: var_name(str): The name of variable. - pruned_dims(list): The axies to be pruned. For convolution with format [out_c, in_c, k, k], - 'axis=[0]' means pruning filters and 'axis=[0, 1]' means pruning kernels. + pruned_axis(int): The axis to be pruned. For convolution with format [out_c, in_c, k, k], + 'axis=0' means pruning filters. pruned_ratio(float): The ratio of pruned values in one variable. + apply(str): How to apply pruning plan to graph. It can be 'impretive', 'lazy' or None. None + means just returning an instance of 'PruningPlan' but not applying it to graph. Returns: plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'. """ + pruned_axis = pruned_axis[0] if isinstance(pruned_axis, + list) else pruned_axis + assert (isinstance(pruned_axis, int)) if var_name in self.skip_vars: _logger.warn( - f"{var_name} is skiped beacause it is not support for pruning derectly." + f"{var_name} is skiped beacause it is not supported for pruning directly." ) return - if isinstance(pruned_dims, int): - pruned_dims = [pruned_dims] - group = self.var_group.find_group(var_name, pruned_dims) - _logger.debug("found group with {}: {}".format(var_name, group)) + collection = self.collections.find_collection_by_master(var_name, + pruned_axis) plan = PruningPlan(self.model.full_name) - group_dict = {} - for sub_layer in self.model.sublayers(): - for param in sub_layer.parameters(include_sublayers=False): - if param.name in group: - group_dict[param.name] = group[param.name] - # Varibales can be pruned on multiple axies. 
- for _item in group_dict[param.name]: - _item.update({ - 'layer': sub_layer, - 'var': param, - 'value': np.array(param.value().get_tensor()) - }) - _logger.debug(f"set value of {param.name} into group") - - mask = self.cal_mask(var_name, pruned_ratio, group_dict) - for _name in group_dict: + if collection is None: + _logger.debug( + f"Can not find collection with master ['name': {var_name}, 'axis': {pruned_axis}]" + ) + return plan + _logger.info( + f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}" + ) + + mask = self.cal_mask(pruned_ratio, collection) + for _detail in collection.all_pruning_details(): # Varibales can be pruned on multiple axies. - for _item in group_dict[_name]: - src_mask = copy.deepcopy(mask) - dims = _item['pruned_dims'] - transforms = _item['transforms'] - var_shape = _item['var'].shape - if isinstance(dims, int): - dims = [dims] - for trans in transforms: - src_mask = self._transform_mask(src_mask, trans) - current_mask = src_mask - assert len(current_mask) == var_shape[dims[ - 0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}" - plan.add(_name, PruningMask(dims, current_mask, pruned_ratio)) + src_mask = copy.deepcopy(mask) + var_shape = _detail.var.shape() + for tran in _detail.transform: + src_mask = self._transform_mask(src_mask, tran) + current_mask = src_mask + groups = _detail.op.attr('groups') + if groups is None or groups == 1: + assert len(current_mask) == var_shape[ + _detail. + axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_name}; len(mask): {len(mask)}" + plan.add(_detail.name, + PruningMask(_detail.axis, current_mask, pruned_ratio, + _detail.op)) if apply == "lazy": plan.apply(self.model, lazy=True) elif apply == "impretive": @@ -371,17 +356,8 @@ class FilterPruner(Pruner): target_start = transform['target_start'] target_end = transform['target_end'] target_len = transform['target_len'] - stride = transform['stride'] mask = mask[src_start:src_end] - - mask = mask.repeat(stride) if stride > 1 else mask - dst_mask = np.ones([target_len]) - # for depthwise conv2d with: - # input shape: (1, 4, 32, 32) - # filter shape: (32, 1, 3, 3) - # groups: 4 - # if we pruning input channels by 50%(from 4 to 2), the output channel should be 50% * 4 * 8. 
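            # Worked example with hypothetical transform values:
            #   src_start=0, src_end=2, target_start=0, target_end=4, target_len=6
            # After slicing, mask = [1, 0]; expand = (4 - 0) / 2 = 2, so
            # dst_mask = [1, 0, 1, 0, 1, 1]: the source mask is tiled over the
            # mapped segment of the target tensor and the rest stays unpruned.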
expand = int((target_end - target_start) / len(mask)) dst_mask[target_start:target_end] = list(mask) * expand elif "stride" in transform: diff --git a/paddleslim/dygraph/prune/fpgm_pruner.py b/paddleslim/dygraph/prune/fpgm_pruner.py index cb825a05..60a37fd2 100644 --- a/paddleslim/dygraph/prune/fpgm_pruner.py +++ b/paddleslim/dygraph/prune/fpgm_pruner.py @@ -15,24 +15,38 @@ class FPGMFilterPruner(FilterPruner): def __init__(self, model, inputs, sen_file=None): super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file) - def cal_mask(self, var_name, pruned_ratio, group): - for _item in group[var_name]: - if _item['pruned_dims'] == [0]: - value = _item['value'] - pruned_dims = _item['pruned_dims'] + def cal_mask(self, pruned_ratio, collection): + var_name = collection.master_name + pruned_axis = collection.master_axis + value = collection.values[var_name] + groups = 1 + for _detail in collection.all_pruning_details(): + assert (isinstance(_detail.axis, int)) + if _detail.axis == 1: + _groups = _detail.op.attr('groups') + if _groups is not None and _groups > 1: + groups = _groups + break + dist_sum_list = [] for out_i in range(value.shape[0]): dist_sum = self.get_distance_sum(value, out_i) dist_sum_list.append(dist_sum) scores = np.array(dist_sum_list) + if groups > 1: + scores = scores.reshape([groups, -1]) + scores = np.mean(scores, axis=1) + sorted_idx = scores.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] - mask_shape = [value.shape[i] for i in pruned_dims] + mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") + if groups > 1: + mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 - return mask + return mask.reshape(mask_shape) def get_distance_sum(self, value, out_idx): w = value.view() diff --git a/paddleslim/dygraph/prune/l1norm_pruner.py b/paddleslim/dygraph/prune/l1norm_pruner.py index 358d5fcf..0d8f1283 100644 --- a/paddleslim/dygraph/prune/l1norm_pruner.py +++ b/paddleslim/dygraph/prune/l1norm_pruner.py @@ -16,19 +16,32 @@ class L1NormFilterPruner(FilterPruner): super(L1NormFilterPruner, self).__init__( model, inputs, sen_file=sen_file) - def cal_mask(self, var_name, pruned_ratio, group): - for _item in group[var_name]: - if _item['pruned_dims'] == [0]: - value = _item['value'] - pruned_dims = _item['pruned_dims'] - reduce_dims = [ - i for i in range(len(value.shape)) if i not in pruned_dims - ] + def cal_mask(self, pruned_ratio, collection): + var_name = collection.master_name + pruned_axis = collection.master_axis + value = collection.values[var_name] + groups = 1 + for _detail in collection.all_pruning_details(): + assert (isinstance(_detail.axis, int)) + if _detail.axis == 1: + _groups = _detail.op.attr('groups') + if _groups is not None and _groups > 1: + groups = _groups + break + + reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis] l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims)) + if groups > 1: + l1norm = l1norm.reshape([groups, -1]) + l1norm = np.mean(l1norm, axis=1) + sorted_idx = l1norm.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] - mask_shape = [value.shape[i] for i in pruned_dims] + + mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") + if groups > 1: + mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 - return mask + return mask.reshape(mask_shape) diff --git a/paddleslim/dygraph/prune/l2norm_pruner.py b/paddleslim/dygraph/prune/l2norm_pruner.py 
index 72453923..da527c05 100644 --- a/paddleslim/dygraph/prune/l2norm_pruner.py +++ b/paddleslim/dygraph/prune/l2norm_pruner.py @@ -16,22 +16,32 @@ class L2NormFilterPruner(FilterPruner): super(L2NormFilterPruner, self).__init__( model, inputs, sen_file=sen_file) - def cal_mask(self, var_name, pruned_ratio, group): - # find information of pruning on output channels - for _item in group[var_name]: - if _item['pruned_dims'] == [0]: - value = _item['value'] - pruned_dims = _item['pruned_dims'] - reduce_dims = [ - i for i in range(len(value.shape)) if i not in pruned_dims - ] - - # scores = np.mean(np.abs(value), axis=tuple(reduce_dims)) + def cal_mask(self, pruned_ratio, collection): + var_name = collection.master_name + pruned_axis = collection.master_axis + value = collection.values[var_name] + groups = 1 + for _detail in collection.all_pruning_details(): + assert (isinstance(_detail.axis, int)) + if _detail.axis == 1: + _groups = _detail.op.attr('groups') + if _groups is not None and _groups > 1: + groups = _groups + break + + reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis] scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims))) + if groups > 1: + scores = scores.reshape([groups, -1]) + scores = np.mean(scores, axis=1) + sorted_idx = scores.argsort() pruned_num = int(round(len(sorted_idx) * pruned_ratio)) pruned_idx = sorted_idx[:pruned_num] - mask_shape = [value.shape[i] for i in pruned_dims] + + mask_shape = [value.shape[pruned_axis]] mask = np.ones(mask_shape, dtype="int32") + if groups > 1: + mask = mask.reshape([groups, -1]) mask[pruned_idx] = 0 - return mask + return mask.reshape(mask_shape) diff --git a/paddleslim/dygraph/prune/pruner.py b/paddleslim/dygraph/prune/pruner.py index 3d5bfe20..88597785 100644 --- a/paddleslim/dygraph/prune/pruner.py +++ b/paddleslim/dygraph/prune/pruner.py @@ -39,11 +39,12 @@ class Pruner(object): Args: ratios(dict): The key is the name of variable to be pruned and the value is the pruned ratio. - axis(list): The dimensions to be pruned on. + axis(int): The dimension to be pruned on. Returns: plan(PruningPlan): The pruning plan. """ + axis = axis[0] if isinstance(axis, list) else axis global_plan = PruningPlan(self.model.full_name) for var, ratio in ratios.items(): if not global_plan.contains(var, axis): diff --git a/paddleslim/dygraph/prune/pruning_plan.py b/paddleslim/dygraph/prune/pruning_plan.py index 9aa40e76..c6e91a65 100644 --- a/paddleslim/dygraph/prune/pruning_plan.py +++ b/paddleslim/dygraph/prune/pruning_plan.py @@ -10,27 +10,17 @@ __all__ = ['PruningPlan', 'PruningMask'] class PruningMask(): - def __init__(self, dims, mask, ratio): + def __init__(self, dims, mask, ratio, op): + assert (isinstance(dims, int)) self._dims = dims self._mask = mask self._pruned_ratio = ratio + self._op = op @property def dims(self): return self._dims - @dims.setter - def dims(self, value): - if not isinstance(value, collections.Iterator): - raise ValueError( - "The dims of PruningMask must be instance of collections.Iterator." - ) - if self._mask is not None: - assert len(self._mask.shape) == len( - value - ), "The length of value must be same with length of mask's shape in current PruningMask instance." - self._dims = list(value) - @property def mask(self): return self._mask @@ -128,8 +118,7 @@ class PruningPlan(): _logger.debug("Backup values of {} into buffers.". 
format(param.name)) expand_mask_shape = [1] * len(value.shape) - for i in dims: - expand_mask_shape[i] = value.shape[i] + expand_mask_shape[dims] = value.shape[dims] _logger.debug("Expanded mask shape: {}".format( expand_mask_shape)) expand_mask = mask.reshape(expand_mask_shape).astype( @@ -158,13 +147,25 @@ class PruningPlan(): if param.name in self._masks: for _mask in self._masks[param.name]: dims = _mask.dims + assert (isinstance(dims, int)) mask = _mask.mask - assert len( - dims - ) == 1, "Imperative mode only support for pruning on one dimension, but get dims {} when pruning parameter {}".format( - dims, param.name) + bool_mask = np.array(mask).astype(bool) t_value = param.value().get_tensor() value = np.array(t_value).astype("float32") + + groups = _mask._op.attr('groups') + if dims == 1 and groups is not None and groups > 1 and len( + value.shape) == 4: + filter_size = value.shape[1] + except_num = np.sum(bool_mask) + assert (except_num % filter_size == 0) + new_groups = int(except_num / filter_size) + sub_layer._origin_groups = sub_layer._groups + sub_layer._groups = new_groups + _logger.info("change groups from {} to {} for {}.". + format(groups, new_groups, param.name)) + continue + # The name of buffer can not contains "." backup_name = param.name.replace(".", "_") + "_backup" if backup_name not in sub_layer._buffers: @@ -172,9 +173,8 @@ class PruningPlan(): paddle.to_tensor(value)) _logger.debug("Backup values of {} into buffers.". format(param.name)) - bool_mask = np.array(mask).astype(bool) pruned_value = np.apply_along_axis( - lambda data: data[bool_mask], dims[0], value) + lambda data: data[bool_mask], dims, value) p = t_value._place() if p.is_cpu_place(): place = paddle.CPUPlace() @@ -186,18 +186,6 @@ class PruningPlan(): place = paddle.CUDAPlace(p.gpu_device_id()) t_value.set(pruned_value, place) - if isinstance( - sub_layer, paddle.nn.layer.conv.Conv2D - ) and sub_layer._groups > 1 and len(param.shape) == 4: - assert param.shape[ - 1] == 1, "It just supports depthwise conv2d when groups > 1." - new_groups = int(bool_mask.sum() * - sub_layer._groups / len(bool_mask)) - _logger.debug( - "Update groups of depthwise conv2d form {} to {}". - format(sub_layer._groups, new_groups)) - sub_layer._origin_groups = sub_layer._groups - sub_layer._groups = new_groups # for training if param.trainable: diff --git a/paddleslim/dygraph/prune/var_group.py b/paddleslim/dygraph/prune/var_group.py index 1f9a01ee..9965d8ca 100644 --- a/paddleslim/dygraph/prune/var_group.py +++ b/paddleslim/dygraph/prune/var_group.py @@ -3,15 +3,15 @@ import logging import paddle from paddle.fluid.dygraph import TracedLayer from paddleslim.core import GraphWrapper, dygraph2program -from paddleslim.prune import collect_convs +from paddleslim.prune import PruningCollections from paddleslim.common import get_logger -__all__ = ["VarGroup"] +__all__ = ["DygraphPruningCollections"] _logger = get_logger(__name__, level=logging.INFO) -class VarGroup(): +class DygraphPruningCollections(PruningCollections): """ A tool used to parse dygraph and store information of variables' relationship. Args: @@ -20,40 +20,29 @@ class VarGroup(): """ def __init__(self, model, inputs): - self.groups = [] - self._parse_model(model, inputs) - - def _to_dict(self, group): - ret = {} - for _name, _axis, _transforms in group: - if isinstance(_axis, int): - _axis = [_axis] - if _name not in ret: - ret[_name] = [] - # Variable can be pruned on multiple axies. 
- ret[_name].append({'pruned_dims': _axis, 'transforms': _transforms}) - return ret - - def find_group(self, var_name, axis): - for group in self.groups: - for _name, _axis, _stride in group: - if isinstance(_axis, int): - _axis = [_axis] - if _name == var_name and _axis == axis: - return self._to_dict(group) - - def _parse_model(self, model, inputs): _logger.debug("Parsing model with input: {}".format(inputs)) # model can be in training mode, because some model contains auxiliary parameters for training. program = dygraph2program(model, inputs=inputs) graph = GraphWrapper(program) - visited = {} - for name, param in model.named_parameters(): - group = collect_convs([param.name], graph, - visited)[0] # [(name, axis, pruned_idx)] - if len(group) > 0: - self.groups.append(group) - _logger.info("Found {} groups.".format(len(self.groups))) + params = [ + _param.name for _param in model.parameters() + if len(_param.shape) == 4 + ] + self._collections = self.create_pruning_collections(params, graph) + _logger.info("Found {} collections.".format(len(self._collections))) + + _name2values = {} + for param in model.parameters(): + _name2values[param.name] = np.array(param.value().get_tensor()) + for collection in self._collections: + collection.values = _name2values + + def find_collection_by_master(self, var_name, axis): + for _collection in self._collections: + if _collection.master['name'] == var_name and _collection.master[ + 'axis'] == axis: + return _collection def __str__(self): - return "\n".join([str(group) for group in self.groups]) + return "\n".join( + [str(_collection) for _collection in self._collections]) diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 7542031c..2a4015a5 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -19,17 +19,16 @@ from .auto_pruner import * from ..prune import auto_pruner from .sensitive import * from ..prune import sensitive -from .prune_walker import * -from ..prune import prune_walker +from .prune_worker import * +from ..prune import prune_worker from .prune_io import * from ..prune import prune_io -from .group_param import * -from ..prune import group_param from .criterion import * from ..prune import criterion +from .collections import * +from ..prune import collections from .unstructured_pruner import * from ..prune import unstructured_pruner - from .idx_selector import * from ..prune import idx_selector __all__ = [] @@ -37,9 +36,9 @@ __all__ = [] __all__ += pruner.__all__ __all__ += auto_pruner.__all__ __all__ += sensitive.__all__ -__all__ += prune_walker.__all__ +__all__ += prune_worker.__all__ __all__ += prune_io.__all__ -__all__ += group_param.__all__ __all__ += criterion.__all__ __all__ += unstructured_pruner.__all__ __all__ += idx_selector.__all__ +__all__ += collections.__all__ diff --git a/paddleslim/prune/collections.py b/paddleslim/prune/collections.py new file mode 100644 index 00000000..4f9d4bb0 --- /dev/null +++ b/paddleslim/prune/collections.py @@ -0,0 +1,221 @@ +"""Define some functions to collect ralated parameters into groups.""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import logging +from ..core import GraphWrapper, VarWrapper +from ..common import get_logger +from .prune_worker import PRUNE_WORKER, UnsupportOpError + +__all__ = [ + 'PruningDetails', 'PruningCollection', 'PruningCollections', + 'StaticPruningCollections' +] + +_logger = get_logger(__name__, level=logging.INFO) + + +class PruningDetails(object): + """ + The description of one pruning operation. + Args: + var(VarWrapper): The variable to be pruned. + axis(int): The axis to be pruned on. + transform(dict): Information used to convert pruned indices of master + tensor to indices of current tensor. + op(OpWrapper): The operator with current tensor as input. + is_parameter(bool): whether the tensor is parameter. Default: True. + """ + + def __init__(self, var, axis, transform, op, is_parameter=True): + assert (isinstance(var, VarWrapper), + "name should be VarWrapper, but get type = ".format(type(var))) + assert (isinstance(axis, int)) + self.name = var.name() + self.var = var + self.axis = axis + self.transform = transform + self.op = op + self.is_parameter = is_parameter + + def __eq__(self, other): + if self.name != other.name: + return False + if self.axis != other.axis: + return False + if self.transform != other.transform: + return False + return True + + +class PruningCollection(object): + """ + A group of pruning operations. + + conv1-->conv2-->batch_norm + + For the network defined above, if weight of conv1 is pruned on 0-axis, + weight of'conv2' should be pruned on 1-axis. The pruning operations on 0-axis of + 'conv1' and those on 1-axis of 'conv2' is a collection. And the {'name': conv1.weight_name, 'axis': 0} + is the master of current collection. + + Args: + master(dict): The master pruning operation. + """ + + def __init__(self, master=None): + self._master = master + self.master_name = master['name'] + self.master_axis = master['axis'] + self._nodes = {} + + def variables(self): + """ + Get all tensors to be pruned in current collection. + Returns: + list: Names of tensor to be pruned. + """ + return list(self._nodes.keys()) + + def add(self, node): + """ + Add a pruning operation into current collention. + Args: + node(PruningDetails): Pruning operation to be added into current collection. + """ + assert (isinstance(node, PruningDetails)) + # the first added pruning operation will be master. + self._master = { + "name": node.name, + "axis": node.aixs + } if self._master is None else self._master + if node.name not in self._nodes: + self._nodes[node.name] = [] + if node not in self._nodes[node.name]: + self._nodes[node.name].append(node) + + @property + def master(self): + return self._master + + def all_pruning_details(self): + """ + Get all pruning operations in current collection. + Returns: + list: Pruning operations. 
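+
+        Example (assuming ``collection`` is a built ``PruningCollection``):
+
+        .. code-block:: python
+
+            for detail in collection.all_pruning_details():
+                print(detail.name, detail.axis, detail.transform)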
+ """ + ret = [] + for _items in self._nodes.values(): + ret.extend(_items) + return ret + + +class PruningCollections(object): + def __init__(self): + self._collections = None + + def __iter__(self): + return iter(self._collections) + + def create_pruning_collections(self, params, graph, skip_stranger=True): + """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. + A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. + + .. code-block:: text + + conv1->conv2->conv3->conv4 + + As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as: + + .. code-block:: python + + [("conv1", 0), ("conv2", 1)] + + If `params` is `["conv1", "conv2"]`, then the returned groups is: + + .. code-block:: python + + [[("conv1", 0), ("conv2", 1)], + [("conv2", 0), ("conv3", 1)]] + + Args: + params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters. + graph(paddle.static.Program | GraphWrapper): The graph used to search the groups. + skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True. + + Returns: + list: The groups. + + """ + if not isinstance(graph, GraphWrapper): + graph = GraphWrapper(graph) + visited = {} + collections = [] + unsupported_warnings = set() + for _param in params: + pruned_params = [] + param = graph.var(_param) + if param is None: + _logger.warning( + f"Couldn't find relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correct mode and contains {_param} if you are using dynamic API of PaddlePaddle." 
+ ) + continue + target_op = param.outputs()[0] + if target_op.type() == 'conditional_block': + for op in param.outputs(): + if op.type() in PRUNE_WORKER._module_dict.keys(): + cls = PRUNE_WORKER.get(op.type()) + worker = cls(op, + pruned_params=pruned_params, + visited=visited, + skip_stranger=skip_stranger) + break + else: + cls = PRUNE_WORKER.get(target_op.type()) + if cls is None: + _logger.warning("No worker for operator: {}".format( + target_op.type())) + continue + worker = cls(target_op, + pruned_params=pruned_params, + visited=visited, + skip_stranger=skip_stranger) + try: + visited_backup = copy.deepcopy(worker.visited) + worker.prune(param, pruned_axis=0, pruned_idx=[]) + except UnsupportOpError as e: + visited.clear() + visited.update(visited_backup) + unsupported_warnings.add(e.args) + else: + if len(pruned_params) != 0: + collection = PruningCollection(master=({ + "name": param.name(), + "axis": 0 + })) + for _param, _axis, _transform, _op in pruned_params: + collection.add( + PruningDetails(_param, _axis, _transform, _op)) + collections.append(collection) + for warn in unsupported_warnings: + _logger.warning(warn) + self._collections = collections + return self._collections + + +class StaticPruningCollections(PruningCollections): + def __init__(self, params, graph, skip_stranger=True): + super(StaticPruningCollections, self).__init__() + self._collections = self.create_pruning_collections( + params, graph, skip_stranger=skip_stranger) diff --git a/paddleslim/prune/criterion.py b/paddleslim/prune/criterion.py index a32ec6c9..aae0b629 100644 --- a/paddleslim/prune/criterion.py +++ b/paddleslim/prune/criterion.py @@ -27,7 +27,7 @@ CRITERION = Registry('criterion') @CRITERION.register -def l1_norm(group, graph): +def l1_norm(group, values, graph): """Compute l1-norm scores of parameter on given axis. This function return a list of parameters' l1-norm scores on given axis. @@ -35,28 +35,44 @@ def l1_norm(group, graph): and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of strucure on `axis`. Args: - group(list): A group of parameters. The first parameter of the group is convolution layer's weight - while the others are parameters affected by pruning the first one. Each parameter in group - is represented as tuple '(name, values, axis)' in which `name` is the parameter's name and - and `values` is the values of parameter and `axis` is the axis reducing on pruning on. + group(Group): A group of pruning operations. + values(dict): The key is the name of tensor in group, and the value of dict is the + values of tensor. + graph(GraphWrapper): The graph stores structure information of network. + Returns: - list: A list of tuple storing l1-norm on given axis. + dict: The key is name of tensor, the value is a dict + with axis as key and l1-norm scores as value. """ - scores = [] - for name, value, axis, pruned_idx in group: - + scores = {} + + for pruning_details in group.all_pruning_details(): + name = pruning_details.name + if name not in values: + _logger.warning( + "The value of tensor '{}' is not found, so it will not be used when evaluating importance of pruned structures.". 
+ format(name)) + continue + value = values[name] + axis = pruning_details.axis reduce_dims = [i for i in range(len(value.shape)) if i != axis] score = np.sum(np.abs(value), axis=tuple(reduce_dims)) - scores.append((name, axis, score, pruned_idx)) - + if name not in scores: + scores[name] = {} + scores[name][axis] = score return scores @CRITERION.register -def geometry_median(group, graph): - scores = [] - name, value, axis, _ = group[0] - assert (len(value.shape) == 4) +def geometry_median(group, values, graph): + name = group.master["name"] + axis = group.master["axis"] + if name not in values: + _logger.warning("The value of tensor '{}' is not found.") + return None + value = values[name] + assert (len(value.shape) == 4, + "geometry_median only support for weight of conv2d.") def get_distance_sum(value, out_idx): w = value.view() @@ -73,31 +89,26 @@ def geometry_median(group, graph): tmp = np.array(dist_sum_list) - for name, value, axis, idx in group: - scores.append((name, axis, tmp, idx)) + scores = {} + for pruning_details in group.all_pruning_details(): + name = pruning_details.name + axis = pruning_details.axis + if name not in scores: + scores[name] = {} + scores[name][axis] = tmp return scores @CRITERION.register -def bn_scale(group, graph): - """Compute l1-norm scores of parameter on given axis. - - This function return a list of parameters' l1-norm scores on given axis. - Each element of list is a tuple with format (name, axis, score) in which 'name' is parameter's name - and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of strucure on `axis`. - - Args: - group(list): A group of parameters. The first parameter of the group is convolution layer's weight - while the others are parameters affected by pruning the first one. Each parameter in group - is represented as tuple '(name, values, axis)' in which `name` is the parameter's name and - and `values` is the values of parameter and `axis` is the axis reducing on pruning on. - Returns: - list: A list of tuple storing l1-norm on given axis. +def bn_scale(group, values, graph): + """Compute scores by scales of batch_norm layer. """ assert (isinstance(graph, GraphWrapper)) # step1: Get first convolution - conv_weight, value, axis, _ = group[0] + conv_weight = group.master["name"] + axis = group.master["axis"] + value = values[conv_weight] param_var = graph.var(conv_weight) conv_op = param_var.outputs()[0] @@ -111,12 +122,16 @@ def bn_scale(group, graph): # steps3: Find scale of bn score = None - for name, value, aixs, _ in group: - if bn_scale_param == name: - score = np.abs(value.reshape([-1])) - - scores = [] - for name, value, axis, idx in group: - scores.append((name, axis, score, idx)) - + if bn_scale_param not in values: + raise SystemExit("Can't find values of scales in BatchNorm.") + value = values[bn_scale_param] + score = np.abs(value.reshape([-1])) + + scores = {} + for pruning_details in group.all_pruning_details(): + name = pruning_details.name + axis = pruning_details.axis + if name not in scores: + scores[name] = {} + scores[name][axis] = score return scores diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py deleted file mode 100644 index 9a406e31..00000000 --- a/paddleslim/prune/group_param.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Define some functions to collect ralated parameters into groups.""" -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from ..core import GraphWrapper -from ..common import get_logger -from .prune_walker import PRUNE_WORKER - -__all__ = ["collect_convs"] - -_logger = get_logger(__name__, level=logging.INFO) - - -def collect_convs(params, graph, visited={}): - """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. - A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. - - .. code-block:: text - - conv1->conv2->conv3->conv4 - - As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as: - - .. code-block:: python - - [("conv1", 0), ("conv2", 1)] - - If `params` is `["conv1", "conv2"]`, then the returned groups is: - - .. code-block:: python - - [[("conv1", 0), ("conv2", 1)], - [("conv2", 0), ("conv3", 1)]] - - Args: - params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters. - graph(paddle.static.Program | GraphWrapper): The graph used to search the groups. - - Returns: - list>: The groups. - - """ - if not isinstance(graph, GraphWrapper): - graph = GraphWrapper(graph) - groups = [] - for _param in params: - pruned_params = [] - param = graph.var(_param) - if param is None: - _logger.warning( - f"Cann't found relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correctly mode and contains {_param} if you are using dynamic API of PaddlePaddle." 
- ) - groups.append([]) - continue - target_op = param.outputs()[0] - if target_op.type() == 'conditional_block': - for op in param.outputs(): - if op.type() in PRUNE_WORKER._module_dict.keys(): - cls = PRUNE_WORKER.get(op.type()) - walker = cls(op, - pruned_params=pruned_params, - visited=visited) - break - else: - cls = PRUNE_WORKER.get(target_op.type()) - if cls is None: - _logger.info("No walker for operator: {}".format(target_op.type( - ))) - groups.append(pruned_params) - continue - walker = cls(target_op, - pruned_params=pruned_params, - visited=visited) - - walker.prune(param, pruned_axis=0, pruned_idx=[]) - groups.append(pruned_params) - visited = set() - uniq_groups = [] - for group in groups: - repeat_group = False - simple_group = [] - for param, axis, pruned_idx in group: - param = param.name() - if axis == 0: - if param in visited: - repeat_group = True - else: - visited.add(param) - simple_group.append((param, axis, pruned_idx)) - if not repeat_group: - uniq_groups.append(simple_group) - return uniq_groups diff --git a/paddleslim/prune/idx_selector.py b/paddleslim/prune/idx_selector.py index 57f21383..d6c19ebf 100644 --- a/paddleslim/prune/idx_selector.py +++ b/paddleslim/prune/idx_selector.py @@ -26,75 +26,80 @@ IDX_SELECTOR = Registry('idx_selector') @IDX_SELECTOR.register -def default_idx_selector(group, ratio): - """Get the pruned indexes by given ratio. +def default_idx_selector(group, scores, ratios): + """Get the pruned indices by scores of master tensor. - This function return a list of parameters' pruned indexes on given axis. - Each element of list is a tuple with format (name, axis, indexes) in which 'name' is parameter's name - and 'axis' is the axis pruning on and `indexes` is indexes to be pruned. + This function return a list of parameters' pruned indices on given axis. + Each element of list is a tuple with format (name, axis, indices) + in which 'name' is parameter's name and 'axis' is the axis pruning on and + `indices` is indices to be pruned. Args: - group(list): A group of parameters. The first parameter of the group is convolution layer's weight - while the others are parameters affected by pruning the first one. Each parameter in group - is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and - `axis` is the axis pruning on and `score` is a np.array storing the importance of strucure - on `axis`. Show as below: - - .. code-block: text - - [("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])] - - The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so - `[0.7, 0.5, 0.6]` are the importance sores of each output channel in "conv1_weights" - while axis is 0. + group(Group): A group of pruning operations. + scores(dict): The key is name of tensor, the value is a dict with axis as key and scores as value. + ratios(dict): The pruned ratio of each tensor. The key is name of tensor and the value is the pruned ratio. Returns: - list: pruned indexes + list: pruned indices with format (name, axis, pruned_indices). 
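+
+        Example with hypothetical scores for a convolution that has
+        ``groups=2`` and 4 output channels, pruned at ratio 0.5 (scores are
+        averaged per channel group, the lowest-scoring group is selected,
+        and its index is expanded back to channel indices):
+
+        .. code-block:: python
+
+            score = np.array([0.9, 0.7, 0.1, 0.2]).reshape([2, -1])  # 2 groups
+            group_score = np.mean(score, axis=1)  # [0.8, 0.15]
+            pruned_group = group_score.argsort()[:1]  # group 1
+            pruned_idx = [1 * 2 + 0, 1 * 2 + 1]  # channels [2, 3]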
""" - name, axis, score, _ = group[ - 0] # sort channels by the first convolution's score + # sort channels by the master convolution's score + name = group.master["name"] + axis = group.master["axis"] + score = scores[name][axis] + + # get max convolution groups attribution + max_groups = 1 + for pruning_details in group.all_pruning_details(): + groups = pruning_details.op.attr("groups") + if groups is not None and groups > max_groups: + max_groups = groups + if max_groups > 1: + score = score.reshape([max_groups, -1]) + group_size = score.shape[1] + # get score for each group of channels + score = np.mean(score, axis=1) sorted_idx = score.argsort() - + ratio = ratios[name] pruned_num = int(round(len(sorted_idx) * ratio)) pruned_idx = sorted_idx[:pruned_num] - idxs = [] - for name, axis, score, transforms in group: - idxs.append((name, axis, pruned_idx, transforms)) - return idxs + # convert indices of channel groups to indices of channels. + if max_groups > 1: + correct_idx = [] + for idx in pruned_idx: + for offset in range(group_size): + correct_idx.append(idx * group_size + offset) + pruned_idx = correct_idx[:] + ret = [] + for _pruning_details in group.all_pruning_details(): + ret.append((_pruning_details.name, _pruning_details.axis, pruned_idx, + _pruning_details.transform)) + return ret @IDX_SELECTOR.register -def optimal_threshold(group, ratio): - """Get the pruned indexes by given ratio. +def optimal_threshold(group, scores, ratios): + """Get the pruned indices by scores of master tensor. - This function return a list of parameters' pruned indexes on given axis. - Each element of list is a tuple with format (name, axis, indexes) in which 'name' is parameter's name - and 'axis' is the axis pruning on and `indexes` is indexes to be pruned. + This function return a list of parameters' pruned indices on given axis. + Each element of list is a tuple with format (name, axis, indices) + in which 'name' is parameter's name and 'axis' is the axis pruning on and + `indices` is indices to be pruned. Args: - group(list): A group of parameters. The first parameter of the group is convolution layer's weight - while the others are parameters affected by pruning the first one. Each parameter in group - is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and - `axis` is the axis pruning on and `score` is a np.array storing the importance of strucure - on `axis`. Show as below: - - .. code-block: text - - [("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])] - - The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so - `[0.7, 0.5, 0.6]` are the importance sores of each output channel in "conv1_weights" - while axis is 0. + group(Group): A group of pruning operations. + scores(dict): The key is name of tensor, the value is a dict with axis as key and scores as value. + ratios(dict): The pruned ratio of each tensor. The key is name of tensor and the value is the pruned ratio. Returns: - - list: pruned indexes - + list: pruned indices with format (name, axis, pruned_indices). 
""" - name, axis, score, _ = group[ - 0] # sort channels by the first convolution's score + # sort channels by the master tensor + name = group.master["name"] + axis = group.master["axis"] + score = scores[name][axis] + ratio = ratios[name] score[score < 1e-18] = 1e-18 score_sorted = np.sort(score) @@ -110,6 +115,7 @@ def optimal_threshold(group, ratio): pruned_idx = np.squeeze(np.argwhere(score < th)) idxs = [] - for name, axis, score, transforms in group: - idxs.append((name, axis, pruned_idx, transforms)) + for _pruning_details in group.all_pruning_details(): + idxs.append((_pruning_details.name, _pruning_details.axis, pruned_idx, + _pruning_details.transform)) return idxs diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_worker.py similarity index 58% rename from paddleslim/prune/prune_walker.py rename to paddleslim/prune/prune_worker.py index 85439a49..8a07d099 100644 --- a/paddleslim/prune/prune_walker.py +++ b/paddleslim/prune/prune_worker.py @@ -12,35 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import logging import numpy as np from ..core import Registry from ..common import get_logger -__all__ = ["PRUNE_WORKER", "conv2d"] +__all__ = ["PRUNE_WORKER", "conv2d", "UnsupportOpError"] _logger = get_logger(__name__, level=logging.INFO) PRUNE_WORKER = Registry('prune_worker') -SKIP_OPS = ["conditional_block"] +SKIPPED_OPS = ['shape', 'reduce_mean'] + +# operators in OPS_UNCHANGE_SHAPE will be visited by default worker +# who keep shape of output same with shape of input. +OPS_UNCHANGE_SHAPE = os.getenv('OPS_UNCHANGE_SHAPE', None) +OPS_UNCHANGE_SHAPE = [] if OPS_UNCHANGE_SHAPE is None else OPS_UNCHANGE_SHAPE.strip( +).split(",") +OPS_UNCHANGE_SHAPE += [ + 'nearest_interp_v2', + 'roi_align', + 'sigmoid', + 'swish', + 'pad3d', + 'bilinear_interp_v2', + 'dropout', + 'cast', + 'hard_swish', + 'hard_sigmoid', +] + + +class UnsupportOpError(Exception): + pass class PruneWorker(object): - def __init__(self, op, pruned_params=[], visited={}): + def __init__(self, op, pruned_params, visited, skip_stranger=True): """ A wrapper of operator used to infer the information of all the related variables. Args: op(Operator): The operator to be pruned. - pruned_params(list): The list to store the information of pruning that infered by walker. + pruned_params(list): The list to store the information of pruning that infered by worker. visited(dict): The auxiliary dict to record the visited operators and variables. The key is a encoded string of operator id and variable name. + skip_stranger(bool): Whether to raise exception when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default waorker. Default: True. - Return: A instance of PruneWalker. + Return: A instance of PruneWorker. """ self.op = op self.pruned_params = pruned_params self.visited = visited + self.skip_stranger = skip_stranger + self.ops_unsupported = os.getenv('OPS_UNSUPPORTED', None) + self.ops_unsupported = [] if self.ops_unsupported is None else self.ops_unsupported.strip( + ).split(",") def prune(self, var, pruned_axis, pruned_idx): """ @@ -49,7 +77,7 @@ class PruneWorker(object): Args: var(Variable): The root variable of searching. It can be the input or output of current operator. pruned_axis(int): The axis to be pruned of root variable. - pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable. 
+ pruned_idx(int): The indices to be pruned in `pruned_axis` of root variable. """ if self._visit(var, pruned_axis): self._prune(var, pruned_axis, pruned_idx) @@ -82,29 +110,36 @@ class PruneWorker(object): return if visited is not None: self.visited = visited + if op.type() in self.ops_unsupported: + raise UnsupportOpError("Unsupported operator named {}".format( + op.type())) + cls = PRUNE_WORKER.get(op.type()) if cls is None: - if op.type() in SKIP_OPS: - _logger.warn("Skip operator [{}]".format(op.type())) + if op.type() in SKIPPED_OPS: return + if op.type() in OPS_UNCHANGE_SHAPE or not self.skip_stranger: + cls = PRUNE_WORKER.get("default_worker") + else: + raise UnsupportOpError("Unsupported operator named {}".format( + op.type())) -# _logger.warn( -# "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.". -# format(op.type())) - cls = PRUNE_WORKER.get("default_walker") - _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format( - self.op, op, pruned_axis, var.name())) + _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}". + format(self.op, op, pruned_axis, var.name(), pruned_idx)) _logger.debug( f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n" ) - walker = cls(op, pruned_params=self.pruned_params, visited=self.visited) - walker.prune(var, pruned_axis, pruned_idx) + worker = cls(op, self.pruned_params, self.visited, self.skip_stranger) + worker.prune(var, pruned_axis, pruned_idx) + + def append_pruned_vars(self, var, axis, transforms): + self.pruned_params.append((var, axis, transforms, self.op)) @PRUNE_WORKER.register class conv2d(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(conv2d, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(conv2d, self).__init__(op, pruned_params, visited, skip_stranger) def _is_depthwise_conv(self, op): data_format = self.op.attr("data_format") @@ -121,15 +156,17 @@ class conv2d(PruneWorker): num_filters % num_channels == 0) def _prune(self, var, pruned_axis, pruned_idx): - if self._is_depthwise_conv(self.op): _logger.debug(f"Meet conv2d who is depthwise conv2d actually.") - walker = depthwise_conv2d( - self.op, self.pruned_params, visited=self.visited) - walker._prune(var, pruned_axis, pruned_idx) - return + worker = depthwise_conv2d( + self.op, + self.pruned_params, + visited=self.visited, + skip_stranger=self.skip_stranger) + return worker._prune(var, pruned_axis, pruned_idx) data_format = self.op.attr("data_format") + groups = self.op.attr("groups") channel_axis = 1 if data_format == "NHWC": channel_axis = 3 @@ -137,56 +174,49 @@ class conv2d(PruneWorker): assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format( pruned_axis, var.name()) filter_var = self.op.inputs("Filter")[0] - self._visit(filter_var, 1) - self.pruned_params.append((filter_var, 1, pruned_idx)) - for op in filter_var.outputs(): - self._prune_op(op, filter_var, 1, pruned_idx) + self.append_pruned_vars(filter_var, 1, pruned_idx) + if groups is None or groups == 1: + self._visit_and_search(filter_var, 1, pruned_idx) elif var in self.op.inputs("Filter"): assert pruned_axis in [0, 1] - self.pruned_params.append((var, pruned_axis, pruned_idx)) + self.append_pruned_vars(var, pruned_axis, pruned_idx) - for op in var.outputs(): - self._prune_op(op, var, pruned_axis, pruned_idx) + if groups 
is None or groups == 1 or pruned_axis == 0: + self._visit_and_search(var, pruned_axis, pruned_idx) if pruned_axis == 0: if len(self.op.inputs("Bias")) > 0: - self.pruned_params.append( - (self.op.inputs("Bias"), channel_axis, pruned_idx)) + self.append_pruned_vars( + self.op.inputs("Bias"), channel_axis, pruned_idx) output_var = self.op.outputs("Output")[0] - self._visit(output_var, channel_axis) - next_ops = output_var.outputs() - for op in next_ops: - self._prune_op(op, output_var, channel_axis, pruned_idx) + self._visit_and_search(output_var, channel_axis, pruned_idx) elif pruned_axis == 1: input_var = self.op.inputs("Input")[0] - self._visit(input_var, channel_axis) - pre_ops = input_var.inputs() - for op in pre_ops: - self._prune_op(op, input_var, channel_axis, pruned_idx) + self._visit_and_search(input_var, channel_axis, pruned_idx) elif var in self.op.outputs("Output"): assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format( pruned_axis, var.name()) filter_var = self.op.inputs("Filter")[0] self._visit(filter_var, 0) - - self.pruned_params.append((filter_var, 0, pruned_idx)) + self.append_pruned_vars(filter_var, 0, pruned_idx) for op in filter_var.outputs(): self._prune_op(op, filter_var, 0, pruned_idx) if len(self.op.inputs("Bias")) > 0: - self.pruned_params.append( - (self.op.inputs("Bias")[0], channel_axis, pruned_idx)) + self.append_pruned_vars( + self.op.inputs("Bias")[0], channel_axis, pruned_idx) @PRUNE_WORKER.register class conv2d_transpose(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(conv2d_transpose, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(conv2d_transpose, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): data_format = self.op.attr("data_format") @@ -198,7 +228,7 @@ class conv2d_transpose(PruneWorker): pruned_axis, var.name()) filter_var = self.op.inputs("Filter")[0] self._visit(filter_var, 0) - self.pruned_params.append((filter_var, 0, pruned_idx)) + self.append_pruned_vars(filter_var, 0, pruned_idx) for op in filter_var.outputs(): self._prune_op(op, filter_var, 0, pruned_idx) @@ -212,14 +242,14 @@ class conv2d_transpose(PruneWorker): filter_var = self.op.inputs("Filter")[0] self._visit(filter_var, 1) - self.pruned_params.append((filter_var, 1, pruned_idx)) + self.append_pruned_vars(filter_var, 1, pruned_idx) for op in filter_var.outputs(): self._prune_op(op, filter_var, 1, pruned_idx) if len(self.op.inputs("Bias")) > 0: - self.pruned_params.append( - (self.op.inputs("Bias")[0], channel_axis, pruned_idx)) + self.append_pruned_vars( + self.op.inputs("Bias")[0], channel_axis, pruned_idx) output_var = self.op.outputs("Output")[0] next_ops = output_var.outputs() @@ -229,8 +259,9 @@ class conv2d_transpose(PruneWorker): @PRUNE_WORKER.register class batch_norm(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(batch_norm, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(batch_norm, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if (var not in self.op.outputs("Y")) and ( @@ -248,7 +279,7 @@ class batch_norm(PruneWorker): param_var = self.op.inputs(param)[0] for op in param_var.outputs(): self._prune_op(op, param_var, 0, pruned_idx) - self.pruned_params.append((param_var, 0, pruned_idx)) + self.append_pruned_vars(param_var, 0, pruned_idx) out_var = 
self.op.outputs("Y")[0] self._visit(out_var, pruned_axis) @@ -259,13 +290,15 @@ class batch_norm(PruneWorker): @PRUNE_WORKER.register class sync_batch_norm(batch_norm): - def __init__(self, op, pruned_params, visited): - super(sync_batch_norm, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(sync_batch_norm, self).__init__(op, pruned_params, visited, + skip_stranger) class elementwise_op(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(elementwise_op, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(elementwise_op, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): axis = self.op.attr("axis") @@ -286,7 +319,7 @@ class elementwise_op(PruneWorker): # for bias if name == "Y" and actual_axis >= 0 and not ( len(in_var.shape()) == 1 and in_var.shape()[0] == 1): - self.pruned_params.append((in_var, actual_axis, pruned_idx)) + self.append_pruned_vars(in_var, actual_axis, pruned_idx) self._visit_and_search(in_var, actual_axis, pruned_idx) else: @@ -301,8 +334,7 @@ class elementwise_op(PruneWorker): if y_pruned_axis >= 0 and not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1): - self.pruned_params.append( - (in_var, y_pruned_axis, pruned_idx)) + self.append_pruned_vars(in_var, y_pruned_axis, pruned_idx) self._visit_and_search(in_var, y_pruned_axis, pruned_idx) elif var in self.op.inputs("Y"): in_var = self.op.inputs("X")[0] @@ -318,26 +350,30 @@ class elementwise_op(PruneWorker): @PRUNE_WORKER.register class elementwise_add(elementwise_op): - def __init__(self, op, pruned_params, visited): - super(elementwise_add, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(elementwise_add, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class elementwise_sub(elementwise_op): - def __init__(self, op, pruned_params, visited): - super(elementwise_sub, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(elementwise_sub, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class elementwise_mul(elementwise_op): - def __init__(self, op, pruned_params, visited): - super(elementwise_mul, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(elementwise_mul, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class activation(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(activation, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(activation, self).__init__(op, pruned_params, visited, + skip_stranger) self.input_name = "X" self.output_name = "Out" @@ -351,9 +387,10 @@ class activation(PruneWorker): @PRUNE_WORKER.register -class default_walker(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(default_walker, self).__init__(op, pruned_params, visited) +class default_worker(PruneWorker): + def __init__(self, op, pruned_params, visited, skip_stranger): + super(default_worker, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.all_outputs(): @@ -367,59 +404,62 @@ class default_walker(PruneWorker): @PRUNE_WORKER.register class 
uniform_random_batch_size_like(activation): - def __init__(self, op, pruned_params, visited): - super(uniform_random_batch_size_like, self).__init__(op, pruned_params, - visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(uniform_random_batch_size_like, self).__init__( + op, pruned_params, visited, skip_stranger) self.input_name = "Input" self.output_name = "Out" @PRUNE_WORKER.register class bilinear_interp(activation): - def __init__(self, op, pruned_params, visited): - super(bilinear_interp, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(bilinear_interp, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class nearest_interp(activation): - def __init__(self, op, pruned_params, visited): - super(nearest_interp, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(nearest_interp, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class relu(activation): - def __init__(self, op, pruned_params, visited): - super(relu, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(relu, self).__init__(op, pruned_params, visited, skip_stranger) @PRUNE_WORKER.register class leaky_relu(activation): - def __init__(self, op, pruned_params, visited): - super(leaky_relu, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(leaky_relu, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class floor(activation): - def __init__(self, op, pruned_params, visited): - super(floor, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(floor, self).__init__(op, pruned_params, visited, skip_stranger) @PRUNE_WORKER.register class relu6(activation): - def __init__(self, op, pruned_params, visited): - super(relu6, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(relu6, self).__init__(op, pruned_params, visited, skip_stranger) @PRUNE_WORKER.register class pool2d(activation): - def __init__(self, op, pruned_params, visited): - super(pool2d, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(pool2d, self).__init__(op, pruned_params, visited, skip_stranger) @PRUNE_WORKER.register class sum(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(sum, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(sum, self).__init__(op, pruned_params, visited, skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.outputs("Out"): @@ -440,10 +480,46 @@ class sum(PruneWorker): self._prune_op(op, out_var, pruned_axis, pruned_idx) +@PRUNE_WORKER.register +class split(PruneWorker): + def __init__(self, op, pruned_params, visited, skip_stranger): + super(split, self).__init__(op, pruned_params, visited, skip_stranger) + self.in_var = op.inputs("X")[0] + self.out_vars = op.outputs("Out") + self.axis = op.attr("axis") + self.num = op.attr("num") + + def _prune(self, var, pruned_axis, transforms): + if var == self.in_var: + if pruned_axis != self.axis: + for out_var in self.out_vars: + self._visit_and_search(out_var, pruned_axis, transforms) + else: + raise UnsupportOpError( + 
"Unsupport pruning input of split operator directly.") + elif var in self.out_vars: + if pruned_axis != self.axis: + self._visit_and_search(self.in_var, pruned_axis, transforms) + else: + trans = { + "src_start": 0, + "src_end": var.shape()[pruned_axis], + "target_start": 0, + "target_end": self.in_var.shape()[pruned_axis], + "target_len": self.in_var.shape()[pruned_axis] + } + self._visit_and_search(self.in_var, pruned_axis, + transforms + [trans]) + + for out_var in self.out_vars: + if var != out_var: + self._visit_and_search(out_var, pruned_axis, transforms) + + @PRUNE_WORKER.register class concat(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(concat, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(concat, self).__init__(op, pruned_params, visited, skip_stranger) def _prune(self, var, pruned_axis, transforms): axis = self.op.attr("axis") @@ -513,52 +589,56 @@ class concat(PruneWorker): @PRUNE_WORKER.register class depthwise_conv2d(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(depthwise_conv2d, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(depthwise_conv2d, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, transforms): - assert var not in self.op.inputs( - "Filter"), "Unsupport for pruning depthwise conv2d directly." - assert var not in self.op.outputs( - "Output" - ), "Unsupport for pruning output of depthwise conv2d directly." + + _filter = self.op.inputs("Filter")[0] + _out = self.op.outputs("Output")[0] + _in_var = self.op.inputs("Input")[0] + data_format = self.op.attr("data_format") - groups = self.op.attr("groups") channel_axis = 1 if data_format == "NHWC": channel_axis = 3 - if var in self.op.inputs("Input"): + + if var == _in_var: assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( pruned_axis) - - groups = var.shape()[channel_axis] - filter_var = self.op.inputs("Filter")[0] - transform = { - 'src_start': 0, - 'src_end': var.shape()[pruned_axis], - 'target_start': 0, - 'target_end': filter_var.shape()[0], - 'target_len': filter_var.shape()[0], - 'stride': 1 - } - - self.pruned_params.append((filter_var, 0, transforms + [transform])) - self._visit(filter_var, 0) - - for op in filter_var.outputs(): - self._prune_op(op, filter_var, 0, transforms + [transform]) - - output_var = self.op.outputs("Output")[0] - next_ops = output_var.outputs() - for op in next_ops: - self._prune_op(op, output_var, channel_axis, - transforms + [transform]) + # pruning number of filters + self.append_pruned_vars(_filter, 0, transforms) + # kernel_number * groups will be pruned by reducing groups + self.append_pruned_vars(_filter, 1, transforms) + self._visit_and_search(_filter, 0, transforms) + # It will not pruning number of kernels in depthwise conv2d, + # so it is not neccesary to search succeed operators. + # self._visit_and_search(_filter, 1, transforms) + self._visit(_filter, 1) + self._visit_and_search(_out, channel_axis, transforms) + elif var == _filter: + assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0." 
+ self.append_pruned_vars(_filter, 1, transforms) + self._visit_and_search(_in_var, channel_axis, transforms) + self._visit_and_search(_out, channel_axis, transforms) + elif var == _out: + assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( + pruned_axis) + self.append_pruned_vars(_filter, 0, transforms) + self.append_pruned_vars(_filter, 1, transforms) + self._visit_and_search(_filter, 0, transforms) + # It will not pruning number of kernels in depthwise conv2d, + # so it is not neccesary to search succeed operators. + # self._visit_and_search(_filter, 1, transforms) + self._visit(_filter, 1) + self._visit_and_search(_in_var, channel_axis, transforms) @PRUNE_WORKER.register class mul(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(mul, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(mul, self).__init__(op, pruned_params, visited, skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.inputs("X"): @@ -570,7 +650,7 @@ class mul(PruneWorker): for i in pruned_idx: idx += list(range_idx + i * feature_map_size) param_var = self.op.inputs("Y")[0] - self.pruned_params.append((param_var, 0, idx)) + self.append_pruned_vars(param_var, 0, idx) for op in param_var.outputs(): self._prune_op(op, param_var, 0, pruned_idx) @@ -578,22 +658,36 @@ class mul(PruneWorker): @PRUNE_WORKER.register class matmul(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(matmul, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(matmul, self).__init__(op, pruned_params, visited, skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): - if var in self.op.inputs("X") and pruned_axis == 1: - param_var = self.op.inputs("Y")[0] - self.pruned_params.append((param_var, 0, pruned_idx)) + x = self.op.inputs("X")[0] + y = self.op.inputs("Y")[0] + out = self.op.outputs("Out")[0] + if var == x and pruned_axis == 1: + self.append_pruned_vars(y, 0, pruned_idx) + self._visit_and_search(y, 0, pruned_idx) + if var == out: + if pruned_axis == 0: + self.append_pruned_vars(x, 0, pruned_idx) + self._visit_and_search(x, 0, pruned_idx) + elif pruned_axis == 1: + self.append_pruned_vars(y, 1, pruned_idx) + self._visit_and_search(y, 1, pruned_idx) - for op in param_var.outputs(): - self._prune_op(op, param_var, 0, pruned_idx) + +@PRUNE_WORKER.register +class matmul_v2(matmul): + def __init__(self, op, pruned_params, visited, skip_stranger): + super(matmul_v2, self).__init__(op, pruned_params, visited, + skip_stranger) @PRUNE_WORKER.register class scale(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(scale, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(scale, self).__init__(op, pruned_params, visited, skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.inputs("X"): @@ -608,34 +702,34 @@ class scale(PruneWorker): @PRUNE_WORKER.register class momentum(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(momentum, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(momentum, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.inputs("Param"): - _logger.debug("pruning momentum, 
var:{}".format(var.name())) velocity_var = self.op.inputs("Velocity")[0] - self.pruned_params.append((velocity_var, pruned_axis, pruned_idx)) + self.append_pruned_vars(velocity_var, pruned_axis, pruned_idx) @PRUNE_WORKER.register class adam(PruneWorker): - def __init__(self, op, pruned_params, visited={}): - super(adam, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(adam, self).__init__(op, pruned_params, visited, skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.inputs("Param"): - _logger.debug("pruning momentum, var:{}".format(var.name())) moment1_var = self.op.inputs("Moment1")[0] - self.pruned_params.append((moment1_var, pruned_axis, pruned_idx)) + self.append_pruned_vars(moment1_var, pruned_axis, pruned_idx) moment2_var = self.op.inputs("Moment2")[0] - self.pruned_params.append((moment2_var, pruned_axis, pruned_idx)) + self.append_pruned_vars(moment2_var, pruned_axis, pruned_idx) @PRUNE_WORKER.register class affine_channel(PruneWorker): - def __init__(self, op, pruned_params, visited): - super(affine_channel, self).__init__(op, pruned_params, visited) + def __init__(self, op, pruned_params, visited, skip_stranger): + super(affine_channel, self).__init__(op, pruned_params, visited, + skip_stranger) def _prune(self, var, pruned_axis, pruned_idx): if (var not in self.op.outputs("Out")) and ( @@ -653,7 +747,7 @@ class affine_channel(PruneWorker): param_var = self.op.inputs(param)[0] for op in param_var.outputs(): self._prune_op(op, param_var, 0, pruned_idx) - self.pruned_params.append((param_var, 0, pruned_idx)) + self.append_pruned_vars(param_var, 0, pruned_idx) out_var = self.op.outputs("Out")[0] self._visit(out_var, pruned_axis) @@ -664,11 +758,12 @@ class affine_channel(PruneWorker): @PRUNE_WORKER.register class flatten_contiguous_range(PruneWorker): - def __init__(self, op, pruned_params, visited): + def __init__(self, op, pruned_params, visited, skip_stranger): super(flatten_contiguous_range, self).__init__(op, pruned_params, - visited) + visited, skip_stranger) def _prune(self, var, pruned_axis, transforms): + start_axis = self.op.attr("start_axis") stop_axis = self.op.attr("stop_axis") if var in self.op.inputs("X"): @@ -690,3 +785,58 @@ class flatten_contiguous_range(PruneWorker): for op in next_ops: self._prune_op(op, out_var, out_pruned_axis, transforms + [transform]) + + +@PRUNE_WORKER.register +class squeeze2(PruneWorker): + def __init__(self, op, pruned_params, visited, skip_stranger): + super(squeeze2, self).__init__(op, pruned_params, visited, + skip_stranger) + + def _prune(self, var, pruned_axis, transforms): + + axes = self.op.attr("axes") + in_var = self.op.inputs("X")[0] + out_var = self.op.outputs("Out")[0] + if axes is None or len(axes) == 0: + axes = [i for i, axis in enumerate(in_var.shape()) if axis == 1] + squeeze_num = 0 + if in_var == var: + for axis in axes: + assert axis != pruned_axis, "Can not pruning axis that will be squeezed." 
+ if axis < pruned_axis: + squeeze_num += 1 + pruned_axis -= squeeze_num + self._visit_and_search(out_var, pruned_axis, transforms) + elif out_var == var: + for axis in axes: + if axis <= pruned_axis: + pruned_axis += 1 + self._visit_and_search(in_var, pruned_axis, transforms) + + +@PRUNE_WORKER.register +class unsqueeze2(PruneWorker): + def __init__(self, op, pruned_params, visited, skip_stranger): + super(unsqueeze2, self).__init__(op, pruned_params, visited, + skip_stranger) + + def _prune(self, var, pruned_axis, transforms): + + axes = self.op.attr("axes") + in_var = self.op.inputs("X")[0] + out_var = self.op.outputs("Out")[0] + assert (axes is not None) + + squeeze_num = 0 + if in_var == var: + for axis in axes: + if axis <= pruned_axis: + pruned_axis += 1 + self._visit_and_search(out_var, pruned_axis, transforms) + elif out_var == var: + for axis in axes: + if axis < pruned_axis: + squeeze_num += 1 + pruned_axis -= squeeze_num + self._visit_and_search(in_var, pruned_axis, transforms) diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index 60fd885e..f8947490 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -18,7 +18,7 @@ import copy import numpy as np from functools import reduce from ..core import VarWrapper, OpWrapper, GraphWrapper -from .group_param import collect_convs +from .collections import StaticPruningCollections from .criterion import CRITERION from .idx_selector import IDX_SELECTOR from ..common import get_logger @@ -79,38 +79,28 @@ class Pruner(): Returns: tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters. """ - self.pruned_list = [] graph = GraphWrapper(program.clone()) param_backup = {} if param_backup else None param_shape_backup = {} if param_shape_backup else None pruned_params = [] - visited = {} - for param, ratio in zip(params, ratios): - _logger.info("pruning: {}".format(param)) - if graph.var(param) is None: - _logger.warn( - "Variable[{}] to be pruned is not in current graph.".format( - param)) - continue - group = collect_convs([param], graph, - visited)[0] # [(name, axis, pruned_idx)] - if group is None or len(group) == 0: - continue - assert ( - not self.pruned_weights), "The weights have been pruned once." 
- group_values = [] - for name, axis, pruned_idx in group: - var = scope.find_var(name) + collections = StaticPruningCollections(params, graph) + ratios = dict(zip(params, ratios)) + values = {} + for _collection in collections: + for _var_name in _collection.variables(): + var = scope.find_var(_var_name) if var is not None: - values = np.array(var.get_tensor()) - group_values.append((name, values, axis, pruned_idx)) + value = np.array(var.get_tensor()) + values[_var_name] = value - scores = self.criterion(group_values, - graph) # [(name, axis, score, pruned_idx)] - g = self._transform(self.idx_selector(scores, ratio)) - pruned_params.extend(g) + for _collection in collections: + scores = self.criterion(_collection, values, graph) + idx = self.idx_selector(_collection, scores, + ratios) # name, axis, idx, transform + idx = self._transform(idx) + pruned_params.extend(idx) merge_pruned_params = {} for param, pruned_axis, pruned_idx in pruned_params: @@ -124,32 +114,35 @@ class Pruner(): pruned_idx = np.concatenate(merge_pruned_params[param_name][ pruned_axis]) param = graph.var(param_name) + _groups = 1 if not lazy: - _logger.debug("{}\t{}\t{}\t{}".format( - param.name(), pruned_axis, - param.shape()[pruned_axis], len(pruned_idx))) - origin_shape = copy.deepcopy(param.shape()) - if param_shape_backup is not None: - param_shape_backup[param.name()] = origin_shape - new_shape = list(param.shape()) - new_shape[pruned_axis] -= len(pruned_idx) - param.set_shape(new_shape) - # update groups of depthwise conv2d - for op in param.outputs(): - if op.type() in ["conv2d", "depthwise_conv2d" - ] and op.attr("groups") > 1: - assert origin_shape[ - 1] == 1, "Only support for depthwise when groups > 1." - new_groups = int( - op.attr("groups") * new_shape[pruned_axis] / - origin_shape[pruned_axis]) - _logger.debug( - f"change groups of conv({param.name()}) from {op.attr('groups')} to {new_groups}; origin_shape: {origin_shape}; new_shape: {new_shape}" - ) - op.set_attr("groups", new_groups) - - if not only_graph: - param_t = scope.find_var(param.name()).get_tensor() + # update groups of conv2d + if pruned_axis == 1: + for op in param.outputs(): + if op.type() in ["conv2d", "depthwise_conv2d" + ] and op.attr("groups") > 1: + _groups = op.attr("groups") + _filter_num = param.shape()[1] + new_groups = int( + (_groups * _filter_num - len(pruned_idx)) / + _filter_num) + _logger.info( + f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};" + ) + op.set_attr("groups", new_groups) + if _groups == 1: + origin_shape = copy.deepcopy(param.shape()) + if param_shape_backup is not None: + param_shape_backup[param.name()] = origin_shape + new_shape = list(param.shape()) + new_shape[pruned_axis] -= len(pruned_idx) + param.set_shape(new_shape) + + if not only_graph and (_groups == 1 or pruned_axis != 1): + _var = scope.find_var(param.name()) + if _var is None: + continue + param_t = _var.get_tensor() if param_backup is not None and ( param.name() not in param_backup): param_backup[param.name()] = copy.deepcopy( @@ -162,40 +155,42 @@ class Pruner(): lazy=lazy) param_t.set(pruned_param, place) except IndexError as e: - _logger.error("Pruning {}, but get [{}]".format( - param.name(), e)) + _logger.error( + "Pruning {} with shape {} on axis {}, but get [{}]; ". 
+ format(param.name(), + param_t.shape(), pruned_axis, e)) graph.infer_shape() self.pruned_weights = (not only_graph) return graph.program, param_backup, param_shape_backup - def _transform(self, group): + def _transform(self, items): ret = [] - for name, axis, pruned_idx, transforms in group: + for name, axis, pruned_idx, transforms in items: src = pruned_idx for trans in transforms: src_start = trans['src_start'] src_end = trans['src_end'] + src_len = src_end - src_start target_start = trans['target_start'] target_end = trans['target_end'] + starts = np.array(range(target_start, target_end, src_len)) target = [] for idx in src: if idx >= src_start and idx < src_end: idx -= src_start - idx += target_start - if idx < target_end: - target.append(idx) + target.extend(list(idx + starts)) src = target ret.append((name, axis, src)) return ret def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): """ - Pruning a array by indexes on given axis. + Pruning a array by indices on given axis. Args: tensor(numpy.array): The target array to be pruned. - pruned_idx(list): The indexes to be pruned. + pruned_idx(list): The indices to be pruned. pruned_axis(int): The axis of given array to be pruned on. lazy(bool): True means setting the pruned elements to zero. False means remove the pruned elements from memory. diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py index ce220e9a..33a9934b 100644 --- a/paddleslim/prune/sensitive.py +++ b/paddleslim/prune/sensitive.py @@ -98,7 +98,7 @@ def sensitivity(program, params=[name], ratios=[ratio], place=place, - lazy=True, + lazy=False, only_graph=False, param_backup=True) if eval_args is None: @@ -108,7 +108,6 @@ def sensitivity(program, loss = (baseline - pruned_metric) / baseline _logger.info("pruned param: {}; {}; loss={}".format(name, ratio, loss)) - sensitivities[name][ratio] = loss _save_sensitivities(sensitivities, sensitivities_file) diff --git a/tests/dygraph/test_filter_pruner.py b/tests/dygraph/test_filter_pruner.py index a6b123ad..f909e382 100644 --- a/tests/dygraph/test_filter_pruner.py +++ b/tests/dygraph/test_filter_pruner.py @@ -99,13 +99,74 @@ class TestFilterPruner(unittest.TestCase): plan = pruner.sensitive_prune(0.01, align=4) for param in net.parameters(): if param.name in self._param_names: + print(f"name: {param.name}; shape: {param.shape}") self.assertTrue(param.shape[0] % 4 == 0) pruner.restore() +class TestPruningGroupConv2d(unittest.TestCase): + def __init__(self, methodName='runTest'): + super(TestPruningGroupConv2d, self).__init__(methodName) + + def runTest(self): + with fluid.unique_name.guard(): + net = paddle.vision.models.mobilenet_v1() + ratios = {} + for param in net.parameters(): + if len(param.shape) == 4: + ratios[param.name] = 0.5 + pruners = [] + pruner = L1NormFilterPruner(net, [1, 3, 128, 128]) + pruners.append(pruner) + pruner = FPGMFilterPruner(net, [1, 3, 128, 128]) + pruners.append(pruner) + pruner = L2NormFilterPruner(net, [1, 3, 128, 128]) + pruners.append(pruner) + + shapes = {} + for pruner in pruners: + plan = pruner.prune_vars(ratios, 0) + for param in net.parameters(): + if param.name not in shapes: + shapes[param.name] = param.shape + assert (shapes[param.name] == param.shape) + pruner.restore() + + +#class TestStrideTransform(unittest.TestCase): +# def __init__(self, methodName='runTest'): +# super(TestStrideTransform, self).__init__(methodName) +# +# def runTest(self): +# with fluid.unique_name.guard(): +# +# net = paddle.vision.models.mobilenet_v1() +# ratios = {} +# 
for param in net.parameters(): +# if len(param.shape) == 4: +# ratios[param.name] = 0.5 +# pruners = [] +# pruner = L1NormFilterPruner(net, [1, 3, 128, 128]) +# pruners.append(pruner) +# pruner = FPGMFilterPruner(net, [1, 3, 128, 128]) +# pruners.append(pruner) +# pruner = L2NormFilterPruner(net, [1, 3, 128, 128]) +# pruners.append(pruner) +# +# shapes = {} +# for pruner in pruners: +# plan = pruner.prune_vars(ratios, 0) +# for param in net.parameters(): +# if param.name not in shapes: +# shapes[param.name] = param.shape +# assert(shapes[param.name] == param.shape) +# pruner.restore() + + def add_cases(suite): - suite.addTest(TestStatus()) - suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"])) + # suite.addTest(TestStatus()) + # suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"])) + suite.addTest(TestPruningGroupConv2d()) def load_tests(loader, standard_tests, pattern): diff --git a/tests/dygraph/test_prune.py b/tests/dygraph/test_prune.py index 6f562751..ade2ce71 100644 --- a/tests/dygraph/test_prune.py +++ b/tests/dygraph/test_prune.py @@ -43,7 +43,7 @@ class TestPrune(unittest.TestCase): paddle.disable_static() model = net(pretrained=False) pruner = L1NormFilterPruner(model, [1, 3, 16, 16]) - pruner.prune_vars(ratios, [0]) + pruner.prune_vars(ratios, 0) shapes = {} for param in model.parameters(): shapes[param.name] = param.shape diff --git a/tests/dygraph/test_prune_walker.py b/tests/dygraph/test_prune_walker.py index 0c18ffc0..2d28b04c 100644 --- a/tests/dygraph/test_prune_walker.py +++ b/tests/dygraph/test_prune_walker.py @@ -25,7 +25,7 @@ class TestWalker(unittest.TestCase): net = Net() x = np.random.uniform(-1, 1, x_shape).astype('float32') pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)]) - pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0]) + pruner.prune_vars({"conv2d_0.w_0": 0.2}, 0) self.assertTrue(net.linear.weight.shape == [5400, 5]) diff --git a/tests/test_dygraph_pruning_plan.py b/tests/test_dygraph_pruning_plan.py index fda40b7d..51d2f75c 100644 --- a/tests/test_dygraph_pruning_plan.py +++ b/tests/test_dygraph_pruning_plan.py @@ -8,14 +8,14 @@ from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask class TestPruningPlan(unittest.TestCase): def testAdd(self): plan = PruningPlan() - mask = PruningMask([0], [0, 1, 1], 0.33) + mask = PruningMask(0, [0, 1, 1], 0.33, None) plan.add("a", mask) - mask = PruningMask([0], [0, 1, 0], 0.33) + mask = PruningMask(0, [0, 1, 0], 0.33, None) plan.add("a", mask) a_mask = plan.masks["a"] self.assertTrue(len(a_mask) == 1) self.assertTrue(a_mask[0].mask == [0, 1, 0]) - self.assertTrue(a_mask[0].dims == [0]) + self.assertTrue(a_mask[0].dims == 0) if __name__ == '__main__': diff --git a/tests/test_group_param.py b/tests/test_group_param.py index dcdf6eab..71fe237a 100644 --- a/tests/test_group_param.py +++ b/tests/test_group_param.py @@ -16,7 +16,7 @@ sys.path.append("../") import unittest import paddle.fluid as fluid from layers import conv_bn_layer -from paddleslim.prune import collect_convs +from paddleslim.prune import StaticPruningCollections from static_case import StaticCase @@ -41,12 +41,9 @@ class TestPrune(StaticCase): sum2 = conv4 + sum1 conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv6 = conv_bn_layer(conv5, 8, 3, "conv6") - collected_groups = collect_convs( + collections = StaticPruningCollections( ["conv1_weights", "conv2_weights", "conv3_weights", "dummy"], main_program) - while [] in collected_groups: - collected_groups.remove([]) - print(collected_groups) params = set([ param.name for param 
in main_program.all_parameters() @@ -58,14 +55,13 @@ class TestPrune(StaticCase): ('conv4_weights', 0), ('conv5_weights', 1)], [('conv3_weights', 0), ('conv4_weights', 1)]] - self.assertTrue(len(collected_groups) == len(expected_groups)) - for _collected, _expected in zip(collected_groups, expected_groups): - for _name, _axis, _ in _collected: + self.assertTrue(len(collections._collections) == len(expected_groups)) + for _collected, _expected in zip(collections, expected_groups): + for _info in _collected.all_pruning_details(): + _name = _info.name + _axis = _info.axis if _name in params: self.assertTrue((_name, _axis) in _expected) - for _name, _axis in _expected: - if _name in params: - self.assertTrue((_name, _axis, []) in _collected) if __name__ == '__main__': diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py index 8184892b..51b84ecc 100644 --- a/tests/test_prune_walker.py +++ b/tests/test_prune_walker.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys +import os sys.path.append("../") import unittest import numpy as np @@ -22,6 +23,7 @@ from static_case import StaticCase from layers import conv_bn_layer import random from paddleslim.core import GraphWrapper +from paddleslim.prune.prune_worker import * class TestPrune(StaticCase): @@ -35,53 +37,54 @@ class TestPrune(StaticCase): # # X: prune output channels # O: prune input channels - with fluid.program_guard(main_program, startup_program): - input = fluid.data(name="image", shape=[None, 3, 16, 16]) - label = fluid.data(name='label', shape=[None, 1], dtype='int64') - conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu') - conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu') - sum1 = conv1 + conv2 - conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6') - conv4 = conv_bn_layer(conv3, 8, 3, "conv4") - sum2 = conv4 + sum1 - conv5 = conv_bn_layer(sum2, 8, 3, "conv5") - - flag = fluid.layers.fill_constant([1], value=1, dtype='int32') - rand_flag = paddle.randint(2, dtype='int32') - cond = fluid.layers.less_than(x=flag, y=rand_flag) - cond_output = fluid.layers.create_global_var( - shape=[1], - value=0.0, - dtype='float32', - persistable=False, - name='cond_output') - - def cond_block1(): - cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1") - fluid.layers.assign(input=cond_conv, output=cond_output) - - def cond_block2(): - cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1") - cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2") - fluid.layers.assign(input=cond_conv2, output=cond_output) - - fluid.layers.cond(cond, cond_block1, cond_block2) - sum3 = fluid.layers.sum([sum2, cond_output]) - - conv6 = conv_bn_layer(sum3, 8, 3, "conv6") - sub1 = conv6 - sum3 - mult = sub1 * sub1 - conv7 = conv_bn_layer( - mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False) - floored = fluid.layers.floor(conv7) - scaled = fluid.layers.scale(floored) - concated = fluid.layers.concat([scaled, mult], axis=1) - conv8 = conv_bn_layer(concated, 8, 3, "conv8") - predict = fluid.layers.fc(input=conv8, size=10, act='softmax') - cost = fluid.layers.cross_entropy(input=predict, label=label) - adam_optimizer = fluid.optimizer.AdamOptimizer(0.01) - avg_cost = fluid.layers.mean(cost) - adam_optimizer.minimize(avg_cost) + with fluid.unique_name.guard(): + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 3, 16, 16]) + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + conv1 
= conv_bn_layer(input, 8, 3, "conv1", act='relu') + conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu') + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6') + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + + flag = fluid.layers.fill_constant([1], value=1, dtype='int32') + rand_flag = paddle.randint(2, dtype='int32') + cond = fluid.layers.less_than(x=flag, y=rand_flag) + cond_output = fluid.layers.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=False, + name='cond_output') + + def cond_block1(): + cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1") + fluid.layers.assign(input=cond_conv, output=cond_output) + + def cond_block2(): + cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1") + cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2") + fluid.layers.assign(input=cond_conv2, output=cond_output) + + fluid.layers.cond(cond, cond_block1, cond_block2) + sum3 = fluid.layers.sum([sum2, cond_output]) + + conv6 = conv_bn_layer(sum3, 8, 3, "conv6") + sub1 = conv6 - sum3 + mult = sub1 * sub1 + conv7 = conv_bn_layer( + mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False) + floored = fluid.layers.floor(conv7) + scaled = fluid.layers.scale(floored) + concated = fluid.layers.concat([scaled, mult], axis=1) + conv8 = conv_bn_layer(concated, 8, 3, "conv8") + predict = fluid.layers.fc(input=conv8, size=10, act='softmax') + cost = fluid.layers.cross_entropy(input=predict, label=label) + adam_optimizer = fluid.optimizer.AdamOptimizer(0.01) + avg_cost = fluid.layers.mean(cost) + adam_optimizer.minimize(avg_cost) params = [] for param in main_program.all_parameters(): @@ -117,5 +120,439 @@ class TestPrune(StaticCase): fetch_list=[cost.name]) +class TestUnsqueeze2(StaticCase): + def test_prune(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 3, 16, 16]) + conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu') + out = paddle.unsqueeze(conv1, axis=[0]) + + graph = GraphWrapper(main_program) + cls = PRUNE_WORKER.get("unsqueeze2") + out_var = graph.var(out.name) + in_var = graph.var(conv1.name) + op = out_var.inputs()[0] + # pruning out + pruned_params = [] + ret = {} + worker = cls(op, pruned_params, {}, True) + worker.prune(out_var, 2, []) + for var, axis, _, _ in pruned_params: + ret[var.name()] = axis + self.assertTrue(ret == { + 'conv1_weights': 0, + 'conv1_bn_scale': 0, + 'conv1_bn_offset': 0, + 'conv1_bn_mean': 0, + 'conv1_bn_variance': 0 + }) + + # pruning in + pruned_params = [] + ret = {} + worker = cls(op, pruned_params, {}, True) + worker.prune(in_var, 1, []) + for var, axis, _, _ in pruned_params: + ret[var.name()] = axis + self.assertTrue(ret == {}) + + +class TestSqueeze2(StaticCase): + def test_prune(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[1, 3, 16, 16]) + conv1 = conv_bn_layer( + input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1] + out = paddle.squeeze(conv1) + + graph = GraphWrapper(main_program) + cls = PRUNE_WORKER.get("squeeze2") + out_var = graph.var(out.name) + in_var = graph.var(conv1.name) + op = out_var.inputs()[0] + # pruning out + pruned_params = [] + ret = {} + worker = cls(op, pruned_params, {}, True) + 
worker.prune(out_var, 0, []) + for var, axis, _, _ in pruned_params: + ret[var.name()] = axis + self.assertTrue(ret == { + 'conv1_weights': 0, + 'conv1_bn_scale': 0, + 'conv1_bn_offset': 0, + 'conv1_bn_mean': 0, + 'conv1_bn_variance': 0 + }) + + # pruning in + pruned_params = [] + ret = {} + worker = cls(op, pruned_params, {}, True) + worker.prune(in_var, 1, []) + for var, axis, _, _ in pruned_params: + ret[var.name()] = axis + self.assertTrue(ret == {}) + + +class TestUnsupportAndDefault(StaticCase): + def test_prune(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[1, 3, 16, 16]) + conv1 = conv_bn_layer( + input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1] + # hit default pruning worker + cast1 = paddle.cast(conv1, dtype="int32") + # hit unsupported pruning worker + out = paddle.reshape(cast1, shape=[1, -1]) + + graph = GraphWrapper(main_program) + cls = PRUNE_WORKER.get("conv2d") + in_var = graph.var("conv1_weights") + op = in_var.outputs()[0] + # pruning input of conv op + pruned_params = [] + ret = {} + os.environ['OPS_UNSUPPORTED'] = "reshape2" + worker = cls(op, pruned_params, {}, True) + hit_unsupported_op = False + try: + worker.prune(in_var, 0, []) + except UnsupportOpError as e: + hit_unsupported_op = True + print(e) + self.assertTrue(hit_unsupported_op) + + +class TestConv2d(StaticCase): + def test_prune(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[1, 3, 16, 16]) + + conv1 = conv_bn_layer( + input, 6, 3, "conv1", groups=1, bias=True, act='relu') + + graph = GraphWrapper(main_program) + cls = PRUNE_WORKER.get("conv2d") + weight_var = graph.var("conv1_weights") + in_var = graph.var("image") + op = in_var.outputs()[0] + out_var = op.outputs("Output")[0] + # pruning weights of conv op + pruned_params = [] + ret = {} + worker = cls(op, pruned_params, {}, True) + worker.prune(weight_var, 0, []) + worker.prune(weight_var, 1, []) + for var, axis, _, _ in pruned_params: + if var.name() not in ret: + ret[var.name()] = [] + ret[var.name()].append(axis) + self.assertTrue(ret == { + 'conv1_weights': [0, 1], + 'conv1_out.b_0': [0], + 'conv1_bn_scale': [0], + 'conv1_bn_offset': [0], + 'conv1_bn_mean': [0], + 'conv1_bn_variance': [0] + }) + # pruning out of conv op + pruned_params = [] + ret = {} + worker = cls(op, pruned_params, visited={}, skip_stranger=True) + worker.prune(out_var, 1, []) + for var, axis, _, _ in pruned_params: + if var.name() not in ret: + ret[var.name()] = [] + ret[var.name()].append(axis) + self.assertTrue(ret == {'conv1_weights': [0]}) + + +class TestPruneWorker(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.create_graph() + self.cases = [] + self.set_cases() + + def define_layer(self, input): + pass + + def set_cases(self): + pass + + def create_graph(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with paddle.fluid.unique_name.guard(): + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[8, 8, 16, 16]) + self.define_layer(input) + self.graph = GraphWrapper(main_program) + self.in_var = self.graph.var(self.input.name) + self.out_var = self.graph.var(self.output.name) + self.op = self.in_var.outputs()[0] + + def check_in_out(self): + cls = 
PRUNE_WORKER.get(self.op.type()) + if cls is None: + cls = PRUNE_WORKER.get("default_worker") + + # pruning input of conv op + for _var, _axis, _ret in self.cases: + pruned_params = [] + ret = {} + worker = cls(self.op, pruned_params, visited={}, skip_stranger=True) + try: + worker.prune(_var, _axis, []) + except UnsupportOpError as e: + print(e) + continue + for var, axis, _, _ in pruned_params: + if var.name() not in ret: + ret[var.name()] = [] + ret[var.name()].append(axis) + self.assertTrue(ret == _ret) + + +class TestConv2dTranspose(TestPruneWorker): + def define_layer(self, input): + self.input = input + conv1 = paddle.static.nn.conv2d_transpose( + input, 6, 16, 3, name="conv1", bias_attr=False) + self.output = conv1 + return conv1 + + def set_cases(self): + self.cases.append((self.in_var, 1, {'conv1.w_0': [0]})) + self.cases.append((self.out_var, 1, {'conv1.w_0': [1]})) + + def test_prune(self): + self.check_in_out() + + +class TestElementwiseMul(TestPruneWorker): + def define_layer(self, input): + conv1 = paddle.static.nn.conv2d( + input, 3, 3, name="conv1", bias_attr=False) + conv2 = paddle.static.nn.conv2d( + input, 3, 3, name="conv2", bias_attr=False) + self.input = conv1 + out = conv1 * conv2 + conv3 = paddle.static.nn.conv2d( + out, 3, 3, name="conv3", bias_attr=False) + self.output = out + + def set_cases(self): + self.cases.append((self.in_var, 1, { + 'conv2.tmp_0': [1], + 'conv2.w_0': [0], + 'conv3.w_0': [1] + })) + self.cases.append((self.out_var, 1, { + 'conv1.w_0': [0], + 'conv2.tmp_0': [1], + 'conv2.w_0': [0] + })) + + def test_prune(self): + self.check_in_out() + + +class TestActivation(TestPruneWorker): + def __init__(self, methodName="test_prune", + op=paddle.nn.functional.sigmoid): + super(TestActivation, self).__init__(methodName) + self.act = op + + def define_layer(self, input): + conv1 = paddle.static.nn.conv2d( + input, 3, 3, name="conv1", bias_attr=False) + self.input = conv1 + tmp = self.act(conv1) + self.output = tmp + conv2 = paddle.static.nn.conv2d( + tmp, 3, 3, name="conv2", bias_attr=False) + + def set_cases(self): + self.cases.append((self.in_var, 1, {'conv2.w_0': [1]})) + self.cases.append((self.out_var, 1, { + 'conv1.w_0': [0], + 'conv2.w_0': [1] + })) + + def test_prune(self): + self.check_in_out() + + +suite = unittest.TestSuite() +suite.addTest(TestActivation(op=paddle.fluid.layers.resize_bilinear)) +suite.addTest(TestActivation(op=paddle.fluid.layers.resize_nearest)) +suite.addTest(TestActivation(op=paddle.floor)) +suite.addTest(TestActivation(op=paddle.scale)) +suite.addTest( + TestActivation(op=paddle.fluid.layers.nn.uniform_random_batch_size_like)) + + +class TestDepthwiseConv2d(TestPruneWorker): + def __init__(self, methodName="test_prune"): + super(TestDepthwiseConv2d, self).__init__(methodName) + + def define_layer(self, input): + self.input = input + conv1 = paddle.static.nn.conv2d( + input, + input.shape[1], + 3, + groups=input.shape[1], + name="conv1", + bias_attr=False) + self.output = conv1 + + def set_cases(self): + weight_var = self.graph.var('conv1.w_0') + self.cases.append((self.in_var, 1, {'conv1.w_0': [0, 1]})) + self.cases.append((self.out_var, 1, {'conv1.w_0': [0, 1]})) + self.cases.append((weight_var, 0, {'conv1.w_0': [1]})) + + def test_prune(self): + self.check_in_out() + + +class TestMul(TestPruneWorker): + def __init__(self, methodName="test_prune"): + super(TestMul, self).__init__(methodName) + + def define_layer(self, input): + x = fluid.data(name="x", shape=[1, 4, 3, 3]) + y = fluid.data(name="y", shape=[36, 7]) + 
self.input = x + out = paddle.fluid.layers.mul(x, y) + self.output = out + + def set_cases(self): + self.cases.append((self.in_var, 1, {'y': [0]})) + + def test_prune(self): + self.check_in_out() + + +class TestMatmul(TestPruneWorker): + def __init__(self, methodName="test_prune"): + super(TestMatmul, self).__init__(methodName) + + def define_layer(self, input): + x = fluid.data(name="x", shape=[6, 8]) + y = fluid.data(name="y", shape=[8, 7]) + self.input = x + out = paddle.matmul(x, y) + self.output = out + + def set_cases(self): + self.cases.append((self.in_var, 1, {'y': [0]})) + self.cases.append((self.out_var, 0, {'x': [0]})) + self.cases.append((self.out_var, 1, {'y': [1]})) + + def test_prune(self): + self.check_in_out() + + +class TestSplit(TestPruneWorker): + def define_layer(self, input): + self.input = input + split1 = paddle.split(input, num_or_sections=2, axis=1, name=None) + self.output = split1[0] + + def set_cases(self): + self.cases.append((self.in_var, 1, {})) + self.cases.append((self.in_var, 0, {})) + self.cases.append((self.out_var, 1, {})) + self.cases.append((self.out_var, 0, {})) + + def test_prune(self): + self.check_in_out() + + +class TestMomentum(TestPruneWorker): + def define_layer(self, input): + self.input = input + conv1 = paddle.static.nn.conv2d( + input, 3, 8, name="conv1", bias_attr=False) + self.output = conv1 + out = paddle.mean(conv1) + opt = paddle.optimizer.Momentum() + opt.minimize(out) + + def set_cases(self): + weight_var = self.graph.var('conv1.w_0') + self.cases.append((weight_var, 0, { + 'conv1.w_0': [0], + 'conv1.w_0_velocity_0': [0] + })) + + def test_prune(self): + self.check_in_out() + + +class TestAdam(TestPruneWorker): + def define_layer(self, input): + self.input = input + conv1 = paddle.static.nn.conv2d( + input, 3, 8, name="conv1", bias_attr=False) + self.output = conv1 + out = paddle.mean(conv1) + opt = paddle.optimizer.Adam() + opt.minimize(out) + + def set_cases(self): + weight_var = self.graph.var('conv1.w_0') + self.cases.append((weight_var, 0, { + 'conv1.w_0': [0], + 'conv1.w_0_moment1_0': [0], + 'conv1.w_0_moment2_0': [0] + })) + + def test_prune(self): + self.check_in_out() + + +class TestAffineChannel(TestPruneWorker): + def __init__(self, methodName="test_prune"): + super(TestAffineChannel, self).__init__(methodName) + + def define_layer(self, input): + conv1 = paddle.static.nn.conv2d( + input, 3, 8, name="conv1", bias_attr=False) + + self.input = conv1 + scale = fluid.data(name="scale", shape=[conv1.shape[1]]) + bias = fluid.data(name="bias", shape=[conv1.shape[1]]) + out = paddle.fluid.layers.affine_channel(conv1, scale=scale, bias=bias) + self.output = out + + def set_cases(self): + self.cases.append((self.in_var, 1, {'scale': [0], 'bias': [0]})) + self.cases.append((self.out_var, 1, { + 'conv1.w_0': [0], + 'scale': [0], + 'bias': [0] + })) + + def test_prune(self): + self.check_in_out() + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py index 539ce67e..c3acf022 100644 --- a/tests/test_sensitivity.py +++ b/tests/test_sensitivity.py @@ -107,6 +107,7 @@ class TestSensitivity(StaticCase): sensitivities_file="./sensitivities_file_2", pruned_ratios=[0.1, 0.2, 0.3, 0.4]) self.assertTrue(params_sens == origin_sens) + self.assertTrue(sens == origin_sens) -- GitLab
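For orientation, a minimal usage sketch of the pruning path exercised by TestPruningGroupConv2d above; the model, input shape, ratios, and the integer `dims` argument follow that test, and the import path is assumed to match the repository's dygraph tests rather than being prescribed by this patch:

    import paddle
    from paddleslim.dygraph import L1NormFilterPruner

    net = paddle.vision.models.mobilenet_v1()
    # prune half of the filters of every 4-D (conv) parameter
    ratios = {p.name: 0.5 for p in net.parameters() if len(p.shape) == 4}
    pruner = L1NormFilterPruner(net, [1, 3, 128, 128])
    plan = pruner.prune_vars(ratios, 0)  # dims is an int (0) after this change
    # ... evaluate or fine-tune the pruned model here ...
    pruner.restore()  # put the original shapes and weights back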