From e3226b49ea01b943aef967ba5668b7b0454c436b Mon Sep 17 00:00:00 2001
From: whs
Date: Mon, 18 Jan 2021 09:11:15 +0800
Subject: [PATCH] Refine pruning code for detection (#554) (#600)

---
 paddleslim/analysis/flops.py        |  28 ++++-
 paddleslim/core/__init__.py         |  13 +-
 paddleslim/core/dygraph.py          | 132 +++++++++++++++++++
 paddleslim/core/graph_wrapper.py    |   6 -
 paddleslim/dygraph/filter_pruner.py |  93 +++++++-----
 paddleslim/dygraph/fpgm_pruner.py   |  12 +-
 paddleslim/dygraph/l1norm_pruner.py |  10 +-
 paddleslim/dygraph/l2norm_pruner.py |  11 +-
 paddleslim/dygraph/pruner.py        |   6 +-
 paddleslim/dygraph/pruning_plan.py  |  53 ++++----
 paddleslim/dygraph/var_group.py     |  36 ++++--
 paddleslim/prune/group_param.py     |   3 +-
 paddleslim/prune/idx_selector.py    |  10 +-
 paddleslim/prune/prune_io.py        |  25 +++-
 paddleslim/prune/prune_walker.py    | 188 ++++++++++++++--------------
 paddleslim/prune/pruner.py          |  77 ++++++++----
 tests/dygraph/test_flops.py         |  58 +++++++++
 tests/dygraph/test_prune.py         |   1 +
 tests/layers.py                     |  28 +++--
 tests/test_dygraph_pruning_plan.py  |  22 ++++
 tests/test_group_param.py           |  30 +++--
 tests/test_prune_op.py              | 107 ++++++++++++++++
 22 files changed, 703 insertions(+), 246 deletions(-)
 create mode 100644 paddleslim/core/dygraph.py
 create mode 100644 tests/test_dygraph_pruning_plan.py
 create mode 100644 tests/test_prune_op.py

diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
index e95a43b7..7e01d12c 100644
--- a/paddleslim/analysis/flops.py
+++ b/paddleslim/analysis/flops.py
@@ -14,7 +14,7 @@
 import paddle
 import numpy as np
 import paddle.jit as jit
-from ..core import GraphWrapper
+from ..core import GraphWrapper, dygraph2program

 __all__ = ["flops", "dygraph_flops"]

@@ -83,11 +83,27 @@ def _graph_flops(graph, only_conv=True, detail=False):
     return flops


-def dygraph_flops(model, input_shape, only_conv=False, detail=False):
+def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
+    """
+    Compute the FLOPs of an nn.Layer.
+    Args:
+        model(nn.Layer): The target model.
+        inputs(list): The dummy inputs used for 'model.forward'. It can be:
+                      1. list<int>|tuple<int>: means 'model.forward' accepts
+                         only one variable as argument and the shape of the
+                         variable is 'inputs'.
+                      2. list<list<int>>: means 'model.forward' accepts multiple
+                         variables as arguments and the shapes of the variables are 'inputs'.
+                      3. others: 'inputs' will be used as the argument list by calling
+                         'model.forward(*inputs)'.
+        dtypes(str|list): It is only used when 'inputs' is a shape or a list of shapes, and it
+                      gives the data type of each input. None means all the inputs are 'float32'.
+                      Default: None.
+        only_conv(bool): Just return the number of mul-adds in convolution and FC layers if `only_conv` is True.
+                      Default: False.
+        detail(bool): Whether to return the detail of each convolution layer.
+    """
-    data = np.ones(tuple(input_shape)).astype("float32")
-    in_var = paddle.to_tensor(data)
-    _, traced = paddle.jit.TracedLayer.trace(model, [in_var])
-    program = traced.program
+    program = dygraph2program(model, inputs)
     graph = GraphWrapper(program)
     return _graph_flops(graph, only_conv=only_conv, detail=detail)
diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py
index 1f9fab5c..e3076a7a 100644
--- a/paddleslim/core/__init__.py
+++ b/paddleslim/core/__init__.py
@@ -12,7 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper -from .registry import Registry +from ..core import graph_wrapper +from .graph_wrapper import * +from ..core import registry +from .registry import * +from ..core import dygraph +from .dygraph import * -__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry'] +__all__ = [] +__all__ += graph_wrapper.__all__ +__all__ += registry.__all__ +__all__ += dygraph.__all__ diff --git a/paddleslim/core/dygraph.py b/paddleslim/core/dygraph.py new file mode 100644 index 00000000..900bfddd --- /dev/null +++ b/paddleslim/core/dygraph.py @@ -0,0 +1,132 @@ +import paddle +import collections +import logging +import numpy as np +from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard +from paddle.fluid.dygraph.base import program_desc_tracing_guard +from paddle.fluid.dygraph.layers import Layer +from paddle.fluid.framework import Block, ParamBase, Program, Variable +from ..common import get_logger + +__all__ = ["dygraph2program"] + +_logger = get_logger(__name__, level=logging.INFO) + + +def _is_shape(values): + if not isinstance(values, (list, tuple)): + return False + for v in values: + if not isinstance(v, int): + return False + return True + + +def _is_shapes(values): + if not isinstance(values, (list, tuple)): + return False + for v in values: + if not _is_shape(v): + return False + return True + + +def _create_tensors(shapes, dtypes=None): + if dtypes is not None: + assert len(shapes) == len( + dtypes + ), "Length of shapes and dtypes must be same. But get len(shapes): {}; len(dtypes): {}; shapes: {}; dtypes: {}".format( + len(shapes), len(dtypes), shapes, dtypes) + else: + dtypes = len(shapes) * ['float32'] + tensors = [] + for shape, dtype in zip(shapes, dtypes): + data = np.ones(tuple(shape)).astype(dtype) + tensors.append(paddle.to_tensor(data)) + return tensors + + +def extract_vars(inputs): + """ + Extract a list of variables from inputs. + Args: + inputs(Variable | list | dict): + """ + vars = [] + if isinstance(inputs, Variable): + vars = [inputs] + elif isinstance(inputs, dict): + for _key, _value in inputs.items(): + if isinstance(_value, Variable): + vars.append(_value) + else: + _logger.warn( + f"Variable is excepted, but get an element with type({type(_value)}) from inputs whose type is dict. And the key of element is {_key}." + ) + elif isinstance(inputs, (tuple, list)): + for _value in inputs: + vars.extend(extract_vars(_value)) + if len(vars) == 0: + _logger.warn(f"Extract none variables from inputs.") + return vars + + +def to_variables(inputs): + """ + Find and rename variables. Find np.ndarray and convert it to variable. 
+    """
+    if isinstance(inputs, Variable) or isinstance(inputs, np.ndarray):
+        return paddle.fluid.dygraph.to_variable(inputs)
+    elif isinstance(inputs, dict):
+        ret = {}
+        for _key in inputs:
+            ret[_key] = to_variables(inputs[_key])
+        return ret
+    elif isinstance(inputs, list):
+        ret = []
+        for _value in inputs:
+            ret.append(to_variables(_value))
+        return ret
+
+
+@dygraph_only
+def dygraph2program(layer,
+                    inputs,
+                    feed_prefix='feed_',
+                    fetch_prefix='fetch_',
+                    tmp_prefix='t_',
+                    extract_inputs_fn=None,
+                    extract_outputs_fn=None,
+                    dtypes=None):
+    assert isinstance(layer, Layer)
+
+    extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
+    extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
+    tracer = _dygraph_tracer()._get_program_desc_tracer()
+
+    with program_desc_tracing_guard(True):
+
+        if _is_shape(inputs):
+            shapes = [inputs]
+            inputs = _create_tensors(shapes, dtypes=dtypes)
+            input_var_list = inputs
+        elif _is_shapes(inputs):
+            inputs = _create_tensors(inputs, dtypes=dtypes)
+            input_var_list = inputs
+        else:
+            inputs = to_variables(inputs)
+            input_var_list = extract_inputs_fn(inputs)
+        original_outputs = layer(*inputs)
+        # 'original_outputs' may be a dict, so we should convert it to a list of variables.
+        # And we should not create new variables in 'extract_vars'.
+        out_var_list = extract_outputs_fn(original_outputs)
+        program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
+            input_var_list, feed_prefix, out_var_list, fetch_prefix, tmp_prefix)
+        tracer.reset()
+
+    with _dygraph_guard(None):
+        program = Program()
+        program.desc = program_desc
+        program.blocks = [Block(program, 0)]
+        program._sync_with_cpp()
+    return program
diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py
index b7a9eebd..785286e3 100644
--- a/paddleslim/core/graph_wrapper.py
+++ b/paddleslim/core/graph_wrapper.py
@@ -397,9 +397,3 @@ class GraphWrapper(object):
         # Infer the remain ops in topological order.
         for op in head_op:
             recursive_infer(op, infer=True)
-
-    def update_groups_of_conv(self):
-        for op in self.ops():
-            if 'conv2d' in op.type() and op.attr('groups') >= op.inputs(
-                    'Filter')[0].shape()[0]:
-                op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
diff --git a/paddleslim/dygraph/filter_pruner.py b/paddleslim/dygraph/filter_pruner.py
index 061f6c84..0cb22d0a 100644
--- a/paddleslim/dygraph/filter_pruner.py
+++ b/paddleslim/dygraph/filter_pruner.py
@@ -47,24 +47,34 @@ class FilterPruner(Pruner):

     Args:
         model(paddle.nn.Layer): The target model to be pruned.
-        input_shape(list): The input shape of model. It is used to trace the graph of the model.
+        inputs(list): The inputs of the model. It will be used in calling 'model.forward(inputs)'.
         sen_file(str, optional): The absolute path of file that stores computed sensitivities. If it is
                              set rightly, 'FilterPruner::sensitive' function can not be called anymore
                              in next step. Default: None.

     """

-    def __init__(self, model, input_shape, sen_file=None):
-        super(FilterPruner, self).__init__(model, input_shape)
+    def __init__(self, model, inputs, sen_file=None):
+        super(FilterPruner, self).__init__(model, inputs)
         self._status = Status(sen_file)
         # sensitive and var_group are just used in filter pruning
-        self.var_group = VarGroup(model, input_shape)
+        self.var_group = VarGroup(model, inputs)
+
+        # skip vars in:
+        # 1. depthwise conv2d layer
+        self.skip_vars = []
+        for sub_layer in model.sublayers():
+            if isinstance(
+                    sub_layer,
+                    paddle.nn.layer.conv.Conv2D) and sub_layer._groups > 1:
+                for param in sub_layer.parameters():
+                    self.skip_vars.append(param.name)

     def sensitive(self,
                   eval_func=None,
                   sen_file=None,
                   target_vars=None,
-                  skip_vars=None):
+                  skip_vars=[]):
         """
         Compute or get sensitivities of model in current pruner.
         It will return a cached sensitivities when all the arguments are "None".
@@ -88,7 +98,7 @@ class FilterPruner(Pruner):
             eval_func(function, optional): The function to evaluate the model in current pruner. This function should have an empy arguments list and return a score with type "float32". Default: None.
             sen_file(str, optional): The absolute path of file to save sensitivities into local filesystem. Default: None.
             target_vars(list, optional): The names of tensors whose sensitivity will be computed. "None" means all weights in convolution layer will be computed. Default: None.
-            skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None.
+            skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. Default: [].

         Returns:
             dict: A dict storing sensitivities.
@@ -102,6 +112,7 @@ class FilterPruner(Pruner):
         if not self._status.is_ckp:
             return self._status

+        skip_vars.extend(self.skip_vars)
         self._cal_sensitive(
             self.model,
             eval_func,
@@ -186,9 +197,9 @@ class FilterPruner(Pruner):
         Returns:
             tuple: A tuple with format ``(ratios, pruned_flops)`` . "ratios" is a dict whose key is name of tensor and value is ratio to be pruned. "pruned_flops" is the ratio of total pruned FLOPs in the model.
         """
-        base_flops = flops(self.model, self.input_shape)
+        base_flops = flops(self.model, self.inputs)

-        _logger.info("Base FLOPs: {}".format(base_flops))
+        _logger.debug("Base FLOPs: {}".format(base_flops))
         low = 0.
         up = 1.0
         history = set()
@@ -200,8 +211,7 @@ class FilterPruner(Pruner):
             if align is not None:
                 ratios = self._round_to(ratios, dims=dims, factor=align)
             plan = self.prune_vars(ratios, axis=dims)
-            _logger.debug("pruning plan: {}".format(plan))
-            c_flops = flops(self.model, self.input_shape)
+            c_flops = flops(self.model, self.inputs)
             _logger.debug("FLOPs after pruning: {}".format(c_flops))
             c_pruned_flops = (base_flops - c_flops) / base_flops
             plan.restore(self.model)
@@ -304,7 +314,11 @@ class FilterPruner(Pruner):
             plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'.

         """
-
+        if var_name in self.skip_vars:
+            _logger.warn(
+                f"{var_name} is skipped because pruning it directly is not supported."
+            )
+            return
         if isinstance(pruned_dims, int):
             pruned_dims = [pruned_dims]
         group = self.var_group.find_group(var_name, pruned_dims)
@@ -315,29 +329,52 @@ class FilterPruner(Pruner):
         for param in sub_layer.parameters(include_sublayers=False):
             if param.name in group:
                 group_dict[param.name] = group[param.name]
-                group_dict[param.name].update({
-                    'layer': sub_layer,
-                    'var': param,
-                    'value': np.array(param.value().get_tensor())
-                })
+                # Variables can be pruned on multiple axes.
+                for _item in group_dict[param.name]:
+                    _item.update({
+                        'layer': sub_layer,
+                        'var': param,
+                        'value': np.array(param.value().get_tensor())
+                    })
                 _logger.debug(f"set value of {param.name} into group")
         mask = self.cal_mask(var_name, pruned_ratio, group_dict)
         for _name in group_dict:
-            dims = group_dict[_name]['pruned_dims']
-            stride = group_dict[_name]['stride']
-            var_shape = group_dict[_name]['var'].shape
-            if isinstance(dims, int):
-                dims = [dims]
-
-            current_mask = mask.repeat(stride[0]) if stride[0] > 1 else mask
-
-            assert len(current_mask) == var_shape[dims[
-                0]], "The length of current_mask must be equal to the size of dimension to be pruned on."
-
-            plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
+            # Variables can be pruned on multiple axes.
+            for _item in group_dict[_name]:
+                dims = _item['pruned_dims']
+                transforms = _item['transforms']
+                var_shape = _item['var'].shape
+                if isinstance(dims, int):
+                    dims = [dims]
+                for trans in transforms:
+                    mask = self._transform_mask(mask, trans)
+                current_mask = mask
+                assert len(current_mask) == var_shape[dims[
+                    0]], f"The length of current_mask must be equal to the size of the dimension to be pruned on. But got: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
+                plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
         if apply == "lazy":
             plan.apply(self.model, lazy=True)
         elif apply == "impretive":
             plan.apply(self.model, lazy=False)
         return plan
+
+    def _transform_mask(self, mask, transform):
+        src_start = transform['src_start']
+        src_end = transform['src_end']
+        target_start = transform['target_start']
+        target_end = transform['target_end']
+        target_len = transform['target_len']
+        stride = transform['stride']
+        mask = mask[src_start:src_end]
+        mask = mask.repeat(stride) if stride > 1 else mask
+
+        dst_mask = np.ones([target_len])
+        # for depthwise conv2d with:
+        # input shape: (1, 4, 32, 32)
+        # filter shape: (32, 1, 3, 3)
+        # groups: 4
+        # if we prune the input channels by 50% (from 4 to 2), the output channels of the filter should be 50% * 4 * 8 = 16.
+ expand = int((target_end - target_start) / len(mask)) + dst_mask[target_start:target_end] = list(mask) * expand + return dst_mask diff --git a/paddleslim/dygraph/fpgm_pruner.py b/paddleslim/dygraph/fpgm_pruner.py index 8ac97b37..45bc48f9 100644 --- a/paddleslim/dygraph/fpgm_pruner.py +++ b/paddleslim/dygraph/fpgm_pruner.py @@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO) class FPGMFilterPruner(FilterPruner): - def __init__(self, model, input_shape, sen_file=None): - super(FPGMFilterPruner, self).__init__( - model, input_shape, sen_file=sen_file) + def __init__(self, model, inputs, sen_file=None): + super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file) def cal_mask(self, var_name, pruned_ratio, group): - value = group[var_name]['value'] - pruned_dims = group[var_name]['pruned_dims'] + for _item in group[var_name]: + if _item['pruned_dims'] == [0]: + value = _item['value'] + pruned_dims = _item['pruned_dims'] + assert (pruned_dims == [0]) dist_sum_list = [] diff --git a/paddleslim/dygraph/l1norm_pruner.py b/paddleslim/dygraph/l1norm_pruner.py index d7329c98..9fb2bbb8 100644 --- a/paddleslim/dygraph/l1norm_pruner.py +++ b/paddleslim/dygraph/l1norm_pruner.py @@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO) class L1NormFilterPruner(FilterPruner): - def __init__(self, model, input_shape, sen_file=None): + def __init__(self, model, inputs, sen_file=None): super(L1NormFilterPruner, self).__init__( - model, input_shape, sen_file=sen_file) + model, inputs, sen_file=sen_file) def cal_mask(self, var_name, pruned_ratio, group): - value = group[var_name]['value'] - pruned_dims = group[var_name]['pruned_dims'] + for _item in group[var_name]: + if _item['pruned_dims'] == [0]: + value = _item['value'] + pruned_dims = _item['pruned_dims'] reduce_dims = [ i for i in range(len(value.shape)) if i not in pruned_dims ] diff --git a/paddleslim/dygraph/l2norm_pruner.py b/paddleslim/dygraph/l2norm_pruner.py index 694aa935..bffdf3a2 100644 --- a/paddleslim/dygraph/l2norm_pruner.py +++ b/paddleslim/dygraph/l2norm_pruner.py @@ -12,13 +12,16 @@ _logger = get_logger(__name__, logging.INFO) class L2NormFilterPruner(FilterPruner): - def __init__(self, model, input_shape, sen_file=None): + def __init__(self, model, inputs, sen_file=None): super(L2NormFilterPruner, self).__init__( - model, input_shape, sen_file=sen_file) + model, inputs, sen_file=sen_file) def cal_mask(self, var_name, pruned_ratio, group): - value = group[var_name]['value'] - pruned_dims = group[var_name]['pruned_dims'] + # find information of pruning on output channels + for _item in group[var_name]: + if _item['pruned_dims'] == [0]: + value = _item['value'] + pruned_dims = _item['pruned_dims'] reduce_dims = [ i for i in range(len(value.shape)) if i not in pruned_dims ] diff --git a/paddleslim/dygraph/pruner.py b/paddleslim/dygraph/pruner.py index 506cbc9e..fe107e1d 100644 --- a/paddleslim/dygraph/pruner.py +++ b/paddleslim/dygraph/pruner.py @@ -19,9 +19,9 @@ class Pruner(object): """ - def __init__(self, model, input_shape): + def __init__(self, model, inputs): self.model = model - self.input_shape = input_shape + self.inputs = inputs self._var_shapes = {} for var in model.parameters(): self._var_shapes[var.name] = var.shape @@ -53,5 +53,5 @@ class Pruner(object): global_plan.apply(self.model, lazy=True) elif apply == "impretive": global_plan.apply(self.model, lazy=False) - + self.plan = global_plan return global_plan diff --git a/paddleslim/dygraph/pruning_plan.py b/paddleslim/dygraph/pruning_plan.py 
index 0230c286..2e35437b 100644 --- a/paddleslim/dygraph/pruning_plan.py +++ b/paddleslim/dygraph/pruning_plan.py @@ -28,7 +28,7 @@ class PruningMask(): if self._mask is not None: assert len(self._mask.shape) == len( value - ), "The length of value must be same with shape of mask in current PruningMask instance." + ), "The length of value must be same with length of mask's shape in current PruningMask instance." self._dims = list(value) @property @@ -37,11 +37,6 @@ class PruningMask(): @mask.setter def mask(self, value): - assert (isinstance(value, PruningMask)) - if self._dims is not None: - assert len(self._mask.shape) == len( - value - ), "The length of value must be same with shape of mask in current PruningMask instance." self._mask = value def __str__(self): @@ -71,13 +66,21 @@ class PruningPlan(): self._pruned_flops = value def add(self, var_name, pruning_mask): + assert (isinstance(pruning_mask, PruningMask)) if var_name not in self._masks: self._masks[var_name] = [] - self._masks[var_name].append(pruning_mask) if var_name not in self._dims: self._dims[var_name] = [] - self._dims[var_name].append(pruning_mask.dims) + + if pruning_mask.dims in self._dims[var_name]: + for _mask in self._masks[var_name]: + if pruning_mask.dims == _mask.dims: + _mask.mask = list( + np.array(_mask.mask) | np.array(pruning_mask.mask)) + else: + self._masks[var_name].append(pruning_mask) + self._dims[var_name].append(pruning_mask.dims) @property def masks(self): @@ -87,8 +90,7 @@ class PruningPlan(): assert (isinstance(plan, PruningPlan)) for var_name in plan.masks: for mask in plan.masks[var_name]: - if not self.contains(var_name, mask.dims): - self.add(var_name, mask) + self.add(var_name, mask) def contains(self, var_name, dims=None): return (var_name in self._dims) and (dims is None or @@ -172,7 +174,6 @@ class PruningPlan(): bool_mask = mask.astype(bool) pruned_value = np.apply_along_axis( lambda data: data[bool_mask], dims[0], value) - p = t_value._place() if p.is_cpu_place(): place = paddle.CPUPlace() @@ -184,14 +185,19 @@ class PruningPlan(): place = paddle.CUDAPlace(p.gpu_device_id()) t_value.set(pruned_value, place) - if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D): - if sub_layer._groups > 1 and pruned_value.shape[ - 1] == 1: # depthwise conv2d - _logger.debug( - "Update groups of depthwise conv2d form {} to {}". - format(sub_layer._groups, - pruned_value.shape[0])) - sub_layer._groups = pruned_value.shape[0] + if isinstance( + sub_layer, paddle.nn.layer.conv.Conv2D + ) and sub_layer._groups > 1 and len(param.shape) == 4: + assert param.shape[ + 1] == 1, "It just supports depthwise conv2d when groups > 1." + new_groups = int(bool_mask.sum() * + sub_layer._groups / len(bool_mask)) + _logger.debug( + "Update groups of depthwise conv2d form {} to {}". 
+ format(sub_layer._groups, new_groups)) + sub_layer._origin_groups = sub_layer._groups + sub_layer._groups = new_groups + # for training if param.trainable: param.clear_gradient() @@ -218,11 +224,6 @@ class PruningPlan(): place = paddle.CUDAPlace(p.gpu_device_id()) t_value.set(np.array(t_backup).astype("float32"), place) - - if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D): - if sub_layer._groups > 1: - _logger.debug( - "Update groups of conv form {} to {}".format( - sub_layer._groups, t_value.shape()[0])) - sub_layer._groups = t_value.shape()[0] + if "_origin_groups" in sub_layer.__dict__: + sub_layer._groups = sub_layer._origin_groups del sub_layer._buffers[backup_name] diff --git a/paddleslim/dygraph/var_group.py b/paddleslim/dygraph/var_group.py index 6e86f847..6ee02d7d 100644 --- a/paddleslim/dygraph/var_group.py +++ b/paddleslim/dygraph/var_group.py @@ -2,7 +2,7 @@ import numpy as np import logging import paddle from paddle.fluid.dygraph import TracedLayer -from ..core import GraphWrapper +from ..core import GraphWrapper, dygraph2program from ..prune import collect_convs from ..common import get_logger @@ -12,33 +12,43 @@ _logger = get_logger(__name__, level=logging.INFO) class VarGroup(): - def __init__(self, model, input_shape): + """ + A tool used to parse dygraph and store information of variables' relationship. + Args: + - model(nn.Layer): The dygraph to be parsed. + - inputs(Variable|list|dict): The dummy inputs of target model. It will be used in calling `model.forward(inputs)`. + """ + + def __init__(self, model, inputs): self.groups = [] - self._parse_model(model, input_shape) + self._parse_model(model, inputs) def _to_dict(self, group): ret = {} - for _name, _axis, _stride in group: + for _name, _axis, _transforms in group: if isinstance(_axis, int): - _axis = [_axis] # TODO: fix - ret[_name] = {'pruned_dims': _axis, 'stride': _stride} + _axis = [_axis] + if _name not in ret: + ret[_name] = [] + # Variable can be pruned on multiple axies. 
+ ret[_name].append({'pruned_dims': _axis, 'transforms': _transforms}) return ret def find_group(self, var_name, axis): for group in self.groups: for _name, _axis, _stride in group: if isinstance(_axis, int): - _axis = [_axis] # TODO: fix + _axis = [_axis] if _name == var_name and _axis == axis: return self._to_dict(group) - def _parse_model(self, model, input_shape): - _logger.debug("Parsing model with input: {}".format(input_shape)) - data = np.ones(tuple(input_shape)).astype("float32") - in_var = paddle.to_tensor(data) + def _parse_model(self, model, inputs): + _logger.debug("Parsing model with input: {}".format(inputs)) + model.eval() - out_dygraph, static_layer = TracedLayer.trace(model, inputs=[in_var]) - graph = GraphWrapper(static_layer.program) + program = dygraph2program(model, inputs=inputs) + + graph = GraphWrapper(program) visited = {} for name, param in model.named_parameters(): diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py index 1e94df5d..f8f8db2a 100644 --- a/paddleslim/prune/group_param.py +++ b/paddleslim/prune/group_param.py @@ -79,7 +79,7 @@ def collect_convs(params, graph, visited={}): pruned_params=pruned_params, visited=visited) - walker.prune(param, pruned_axis=0, pruned_idx=[0]) + walker.prune(param, pruned_axis=0, pruned_idx=[]) groups.append(pruned_params) visited = set() uniq_groups = [] @@ -96,5 +96,4 @@ def collect_convs(params, graph, visited={}): simple_group.append((param, axis, pruned_idx)) if not repeat_group: uniq_groups.append(simple_group) - return uniq_groups diff --git a/paddleslim/prune/idx_selector.py b/paddleslim/prune/idx_selector.py index 73caf276..57f21383 100644 --- a/paddleslim/prune/idx_selector.py +++ b/paddleslim/prune/idx_selector.py @@ -59,11 +59,9 @@ def default_idx_selector(group, ratio): pruned_num = int(round(len(sorted_idx) * ratio)) pruned_idx = sorted_idx[:pruned_num] - idxs = [] - for name, axis, score, offsets in group: - r_idx = [i + offsets[0] for i in pruned_idx] - idxs.append((name, axis, r_idx)) + for name, axis, score, transforms in group: + idxs.append((name, axis, pruned_idx, transforms)) return idxs @@ -112,6 +110,6 @@ def optimal_threshold(group, ratio): pruned_idx = np.squeeze(np.argwhere(score < th)) idxs = [] - for name, axis, score, _ in group: - idxs.append((name, axis, pruned_idx)) + for name, axis, score, transforms in group: + idxs.append((name, axis, pruned_idx, transforms)) return idxs diff --git a/paddleslim/prune/prune_io.py b/paddleslim/prune/prune_io.py index af7303e0..e9e12365 100644 --- a/paddleslim/prune/prune_io.py +++ b/paddleslim/prune/prune_io.py @@ -10,6 +10,7 @@ __all__ = ["save_model", "load_model"] _logger = get_logger(__name__, level=logging.INFO) _SHAPES_FILE = "__shapes__" +_GROUPS_FILE = "__groups__" def save_model(exe, graph, dirname): @@ -39,6 +40,17 @@ def save_model(exe, graph, dirname): json.dump(shapes, f) _logger.info("Save shapes of weights into {}".format(SHAPES_FILE)) + groups = {} + for op in graph.ops(): + if 'conv2d' in op.type(): + filter_name = op.inputs('Filter')[0].name() + groups[filter_name] = op.attr('groups') + + GROUPS_FILE = os.path.join(dirname, _GROUPS_FILE) + with open(GROUPS_FILE, "w") as f: + json.dump(groups, f) + _logger.info("Save groups of cnov2d into {}".format(GROUPS_FILE)) + def load_model(exe, graph, dirname): """ @@ -53,7 +65,6 @@ def load_model(exe, graph, dirname): paddle.static.Program) else graph SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE) - _logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) with 
open(SHAPES_FILE, "r") as f: shapes = json.load(f) for param_name, shape in shapes.items(): @@ -62,9 +73,17 @@ def load_model(exe, graph, dirname): param.set_shape(shape) else: _logger.info('{} is not loaded'.format(param_name)) - _logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) + + GROUPS_FILE = os.path.join(dirname, _GROUPS_FILE) + with open(GROUPS_FILE, "r") as f: + groups = json.load(f) + for op in graph.ops(): + if 'conv2d' in op.type(): + filter_name = op.inputs('Filter')[0].name() + op.set_attr('groups', groups[filter_name]) + _logger.info("Load groups of conv2d from {}".format(GROUPS_FILE)) + paddle.static.load(program=graph.program, model_path=dirname, executor=exe) - graph.update_groups_of_conv() graph.infer_shape() _logger.info("Load weights from {}".format(dirname)) diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py index 158c8aa6..59b8b5ac 100644 --- a/paddleslim/prune/prune_walker.py +++ b/paddleslim/prune/prune_walker.py @@ -65,6 +65,15 @@ class PruneWorker(object): self.visited[pruned_axis][key] = True return True + def _visit_and_search(self, var, axis, transforms): + self._visit(var, axis) + pre_ops = var.inputs() + for op in pre_ops: + self._prune_op(op, var, axis, transforms) + next_ops = var.outputs() + for op in next_ops: + self._prune_op(op, var, axis, transforms) + def _prune(self, var, pruned_axis, pruned_idx): raise NotImplementedError('Abstract method.') @@ -85,6 +94,9 @@ class PruneWorker(object): cls = PRUNE_WORKER.get("default_walker") _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format( self.op, op, pruned_axis, var.name())) + _logger.debug( + f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n" + ) walker = cls(op, pruned_params=self.pruned_params, visited=self.visited) walker.prune(var, pruned_axis, pruned_idx) @@ -170,11 +182,6 @@ class conv2d(PruneWorker): self.pruned_params.append( (self.op.inputs("Bias")[0], channel_axis, pruned_idx)) - output_var = self.op.outputs("Output")[0] - next_ops = output_var.outputs() - for op in next_ops: - self._prune_op(op, output_var, channel_axis, pruned_idx) - @PRUNE_WORKER.register class conv2d_transpose(PruneWorker): @@ -250,6 +257,12 @@ class batch_norm(PruneWorker): self._prune_op(op, out_var, pruned_axis, pruned_idx) +@PRUNE_WORKER.register +class sync_batch_norm(batch_norm): + def __init__(self, op, pruned_params, visited): + super(sync_batch_norm, self).__init__(op, pruned_params, visited) + + class elementwise_op(PruneWorker): def __init__(self, op, pruned_params, visited): super(elementwise_op, self).__init__(op, pruned_params, visited) @@ -269,9 +282,12 @@ class elementwise_op(PruneWorker): in_var = self.op.inputs(name)[0] if len(in_var.shape()) == 1 and in_var.shape()[0] == 1: continue - pre_ops = in_var.inputs() - for op in pre_ops: - self._prune_op(op, in_var, actual_axis, pruned_idx) + + # for bias + if name == "Y" and actual_axis >= 0 and not ( + len(in_var.shape()) == 1 and in_var.shape()[0] == 1): + self.pruned_params.append((in_var, actual_axis, pruned_idx)) + self._visit_and_search(in_var, actual_axis, pruned_idx) else: if var in self.op.inputs("X"): @@ -287,24 +303,17 @@ class elementwise_op(PruneWorker): in_var.shape()[0] == 1): self.pruned_params.append( (in_var, y_pruned_axis, pruned_idx)) - pre_ops = in_var.inputs() - for op in pre_ops: - self._prune_op(op, in_var, y_pruned_axis, pruned_idx) + self._visit_and_search(in_var, y_pruned_axis, pruned_idx) elif var in self.op.inputs("Y"): in_var = 
self.op.inputs("X")[0] if len(in_var.shape()) != len(var.shape()): assert (len(var.shape()) < len(in_var.shape())) pruned_axis = pruned_axis + axis if pruned_axis <= len(in_var.shape()): - pre_ops = in_var.inputs() - for op in pre_ops: - self._prune_op(op, in_var, pruned_axis, pruned_idx) + self._visit_and_search(in_var, pruned_axis, pruned_idx) - out_var = self.op.outputs("Out")[0] - self._visit(out_var, pruned_axis) - next_ops = out_var.outputs() - for op in next_ops: - self._prune_op(op, out_var, pruned_axis, pruned_idx) + out_var = self.op.outputs("Out")[0] + self._visit_and_search(out_var, pruned_axis, pruned_idx) @PRUNE_WORKER.register @@ -447,53 +456,70 @@ class concat(PruneWorker): def __init__(self, op, pruned_params, visited): super(concat, self).__init__(op, pruned_params, visited) - def _prune(self, var, pruned_axis, pruned_idx): - idx = [] + def _prune(self, var, pruned_axis, transforms): axis = self.op.attr("axis") if var in self.op.outputs("Out"): + self._visit(var, pruned_axis) start = 0 if axis == pruned_axis: for _, in_var in enumerate(self.op.inputs("X")): idx = [] - for i in pruned_idx: - r_idx = i - start - if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0: - idx.append(r_idx) + transoform = { + 'src_start': start, + 'src_end': start + in_var.shape()[pruned_axis], + 'target_start': 0, + 'target_end': in_var.shape()[pruned_axis], + 'target_len': in_var.shape()[pruned_axis], + 'stride': 1 + } start += in_var.shape()[pruned_axis] + self._visit(in_var, pruned_axis) pre_ops = in_var.inputs() for op in pre_ops: - self._prune_op(op, in_var, pruned_axis, idx) - idx = pruned_idx[:] + self._prune_op(op, in_var, pruned_axis, + transforms + [transoform]) else: for _, in_var in enumerate(self.op.inputs("X")): + self._visit(in_var, pruned_axis) pre_ops = in_var.inputs() for op in pre_ops: - self._prune_op(op, in_var, pruned_axis, pruned_idx) + self._prune_op(op, in_var, pruned_axis, transforms) elif var in self.op.inputs("X"): + self._visit(var, pruned_axis) if axis == pruned_axis: idx = [] - start = 0 + target_start = 0 for v in self.op.inputs("X"): - if v.name() == var.name(): - idx = [i + start for i in pruned_idx] + if v.name() != var.name(): + target_start += v.shape()[pruned_axis] else: - start += v.shape()[pruned_axis] - + break + target_end = target_start + v.shape()[pruned_axis] out_var = self.op.outputs("Out")[0] - self._visit(out_var, pruned_axis) next_ops = out_var.outputs() - for op in next_ops: - self._prune_op(op, out_var, pruned_axis, idx, visited={}) - else: - for v in self.op.inputs("X"): - for op in v.inputs(): - self._prune_op(op, v, pruned_axis, pruned_idx) - out_var = self.op.outputs("Out")[0] + + transform = { + 'src_start': 0, + 'src_end': var.shape()[pruned_axis], + 'target_start': target_start, + 'target_end': target_end, + 'target_len': out_var.shape()[pruned_axis], + 'stride': 1 + } + self._visit(out_var, pruned_axis) - next_ops = out_var.outputs() for op in next_ops: - self._prune_op(op, out_var, pruned_axis, pruned_idx) + # The output of concat can be visited repeatedly + c_visited = {} + self._prune_op( + op, + out_var, + pruned_axis, + transforms + [transform], + visited=c_visited) + # Add nodes searched from concat into global visited array. 
+ self.visited.update(c_visited) @PRUNE_WORKER.register @@ -501,8 +527,14 @@ class depthwise_conv2d(PruneWorker): def __init__(self, op, pruned_params, visited={}): super(depthwise_conv2d, self).__init__(op, pruned_params, visited) - def _prune(self, var, pruned_axis, pruned_idx): + def _prune(self, var, pruned_axis, transforms): + assert var not in self.op.inputs( + "Filter"), "Unsupport for pruning depthwise conv2d directly." + assert var not in self.op.outputs( + "Output" + ), "Unsupport for pruning output of depthwise conv2d directly." data_format = self.op.attr("data_format") + groups = self.op.attr("groups") channel_axis = 1 if data_format == "NHWC": channel_axis = 3 @@ -510,60 +542,28 @@ class depthwise_conv2d(PruneWorker): assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( pruned_axis) + groups = var.shape()[channel_axis] filter_var = self.op.inputs("Filter")[0] - self.pruned_params.append((filter_var, 0, pruned_idx)) + transform = { + 'src_start': 0, + 'src_end': var.shape()[pruned_axis], + 'target_start': 0, + 'target_end': filter_var.shape()[0], + 'target_len': filter_var.shape()[0], + 'stride': 1 + } + + self.pruned_params.append((filter_var, 0, transforms + [transform])) self._visit(filter_var, 0) for op in filter_var.outputs(): - self._prune_op(op, filter_var, 0, pruned_idx) + self._prune_op(op, filter_var, 0, transforms + [transform]) output_var = self.op.outputs("Output")[0] next_ops = output_var.outputs() for op in next_ops: - self._prune_op(op, output_var, channel_axis, pruned_idx) - - elif var in self.op.inputs("Filter"): - assert pruned_axis in [0] - if pruned_axis == 0: - if len(self.op.inputs("Bias")) > 0: - self.pruned_params.append( - (self.op.inputs("Bias"), channel_axis, pruned_idx)) - - self.pruned_params.append((var, 0, pruned_idx)) - - for op in var.outputs(): - self._prune_op(op, var, 0, pruned_idx) - - output_var = self.op.outputs("Output")[0] - self._visit(output_var, channel_axis) - next_ops = output_var.outputs() - for op in next_ops: - self._prune_op(op, output_var, channel_axis, pruned_idx) - for op in var.outputs(): - self._prune_op(op, var, pruned_axis, pruned_idx) - elif var in self.op.outputs("Output"): - assert pruned_axis == channel_axis - filter_var = self.op.inputs("Filter")[0] - self.pruned_params.append((filter_var, 0, pruned_idx)) - self._visit(filter_var, 0) - - for op in filter_var.outputs(): - self._prune_op(op, filter_var, 0, pruned_idx) - - if len(self.op.inputs("Bias")) > 0: - self.pruned_params.append( - (self.op.inputs("Bias")[0], channel_axis, pruned_idx)) - - in_var = self.op.inputs("Input")[0] - self._visit(in_var, channel_axis) - pre_ops = in_var.inputs() - for op in pre_ops: - self._prune_op(op, in_var, channel_axis, pruned_idx) - - output_var = self.op.outputs("Output")[0] - next_ops = output_var.outputs() - for op in next_ops: - self._prune_op(op, output_var, channel_axis, pruned_idx) + self._prune_op(op, output_var, channel_axis, + transforms + [transform]) @PRUNE_WORKER.register @@ -679,7 +679,7 @@ class flatten_contiguous_range(PruneWorker): super(flatten_contiguous_range, self).__init__(op, pruned_params, visited) - def _prune(self, var, pruned_axis, pruned_idx): + def _prune(self, var, pruned_axis, transforms): start_axis = self.op.attr("start_axis") stop_axis = self.op.attr("stop_axis") if var in self.op.inputs("X"): @@ -687,7 +687,6 @@ class flatten_contiguous_range(PruneWorker): in_var = self.op.inputs("X")[0] stride = 1 out_pruned_axis = pruned_axis - 
out_pruned_idx = pruned_idx if pruned_axis >= start_axis and pruned_axis <= stop_axis: out_pruned_axis = start_axis for i in range(pruned_axis + 1, stop_axis + 1): @@ -697,7 +696,8 @@ class flatten_contiguous_range(PruneWorker): self._visit(in_var, pruned_axis) self._visit(out_var, out_pruned_axis) - + transform = {'stride': stride} next_ops = out_var.outputs() for op in next_ops: - self._prune_op(op, out_var, out_pruned_axis, [stride]) + self._prune_op(op, out_var, out_pruned_axis, + transforms + [transform]) diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index d3b87c86..9ebcc5ee 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -85,8 +85,8 @@ class Pruner(): param_backup = {} if param_backup else None param_shape_backup = {} if param_shape_backup else None - visited = {} pruned_params = [] + visited = {} for param, ratio in zip(params, ratios): _logger.info("pruning: {}".format(param)) if graph.var(param) is None: @@ -98,28 +98,19 @@ class Pruner(): visited)[0] # [(name, axis, pruned_idx)] if group is None or len(group) == 0: continue - if only_graph and self.idx_selector.__name__ == "default_idx_selector": - - param_v = graph.var(param) - pruned_num = int(round(param_v.shape()[0] * ratio)) - pruned_idx = [0] * pruned_num - for name, axis, _ in group: - pruned_params.append((name, axis, pruned_idx)) - - else: - assert ((not self.pruned_weights), - "The weights have been pruned once.") - group_values = [] - for name, axis, pruned_idx in group: - var = scope.find_var(name) - if var is not None: - values = np.array(var.get_tensor()) - group_values.append((name, values, axis, pruned_idx)) - - scores = self.criterion( - group_values, graph) # [(name, axis, score, pruned_idx)] - - pruned_params.extend(self.idx_selector(scores, ratio)) + assert ((not self.pruned_weights), + "The weights have been pruned once.") + group_values = [] + for name, axis, pruned_idx in group: + var = scope.find_var(name) + if var is not None: + values = np.array(var.get_tensor()) + group_values.append((name, values, axis, pruned_idx)) + + scores = self.criterion(group_values, + graph) # [(name, axis, score, pruned_idx)] + g = self._transform(self.idx_selector(scores, ratio)) + pruned_params.extend(g) merge_pruned_params = {} for param, pruned_axis, pruned_idx in pruned_params: @@ -128,7 +119,6 @@ class Pruner(): if pruned_axis not in merge_pruned_params[param]: merge_pruned_params[param][pruned_axis] = [] merge_pruned_params[param][pruned_axis].append(pruned_idx) - for param_name in merge_pruned_params: for pruned_axis in merge_pruned_params[param_name]: pruned_idx = np.concatenate(merge_pruned_params[param_name][ @@ -138,12 +128,26 @@ class Pruner(): _logger.debug("{}\t{}\t{}\t{}".format( param.name(), pruned_axis, param.shape()[pruned_axis], len(pruned_idx))) + origin_shape = copy.deepcopy(param.shape()) if param_shape_backup is not None: - origin_shape = copy.deepcopy(param.shape()) param_shape_backup[param.name()] = origin_shape new_shape = list(param.shape()) new_shape[pruned_axis] -= len(pruned_idx) param.set_shape(new_shape) + # update groups of depthwise conv2d + for op in param.outputs(): + if op.type() in ["conv2d", "depthwise_conv2d" + ] and op.attr("groups") > 1: + assert origin_shape[ + 1] == 1, "Only support for depthwise when groups > 1." 
+ new_groups = int( + op.attr("groups") * new_shape[pruned_axis] / + origin_shape[pruned_axis]) + _logger.debug( + f"change groups of conv({param.name()}) from {op.attr('groups')} to {new_groups}; origin_shape: {origin_shape}; new_shape: {new_shape}" + ) + op.set_attr("groups", new_groups) + if not only_graph: param_t = scope.find_var(param.name()).get_tensor() if param_backup is not None and ( @@ -156,16 +160,35 @@ class Pruner(): pruned_idx, pruned_axis=pruned_axis, lazy=lazy) + param_t.set(pruned_param, place) except IndexError as e: _logger.error("Pruning {}, but get [{}]".format( param.name(), e)) - param_t.set(pruned_param, place) - graph.update_groups_of_conv() graph.infer_shape() self.pruned_weights = (not only_graph) return graph.program, param_backup, param_shape_backup + def _transform(self, group): + ret = [] + for name, axis, pruned_idx, transforms in group: + src = pruned_idx + for trans in transforms: + src_start = trans['src_start'] + src_end = trans['src_end'] + target_start = trans['target_start'] + target_end = trans['target_end'] + target = [] + for idx in src: + if idx >= src_start and idx < src_end: + idx -= src_start + idx += target_start + if idx < target_end: + target.append(idx) + src = target + ret.append((name, axis, src)) + return ret + def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): """ Pruning a array by indexes on given axis. diff --git a/tests/dygraph/test_flops.py b/tests/dygraph/test_flops.py index fbddc2fd..01ffc451 100644 --- a/tests/dygraph/test_flops.py +++ b/tests/dygraph/test_flops.py @@ -1,8 +1,11 @@ import sys sys.path.append("../../") import unittest +import numpy as np +import paddle from paddleslim.analysis import dygraph_flops as flops from paddle.vision.models import mobilenet_v1, resnet50 +from paddle.nn import Conv2D, Layer class TestFlops(unittest.TestCase): @@ -17,9 +20,64 @@ class TestFlops(unittest.TestCase): self.assertTrue(FLOPs == self._gt) +class Net1(Layer): + def __init__(self): + super(Net1, self).__init__() + self.conv1 = Conv2D(3, 2, 3) + self.conv2 = Conv2D(3, 2, 3) + + def forward(self, inputs): + assert isinstance(inputs, dict) + x = inputs["x"] + y = inputs["y"] + return {"x": self.conv1(x), "y": self.conv2(y), "dummy": "dummy"} + + +class Net2(Net1): + def __init__(self): + super(Net2, self).__init__() + + def forward(self, x, y): + return [self.conv1(x), self.conv2(y), "dummy"] + + +class TestFLOPsCase1(unittest.TestCase): + def runTest(self): + x_shape = (1, 3, 32, 32) + y_shape = (1, 3, 16, 16) + net = Net1() + x = np.random.uniform(-1, 1, x_shape).astype('float32') + y = np.random.uniform(-1, 1, y_shape).astype('float32') + + inputs = { + "x": paddle.to_tensor(x), + "y": paddle.to_tensor(y), + "z": "test" + } + FLOPs = flops(net, [inputs]) + self.assertTrue(FLOPs == 59184) + + +class TestFLOPsCase2(unittest.TestCase): + def runTest(self): + x_shape = (1, 3, 32, 32) + y_shape = (1, 3, 16, 16) + net = Net2() + x = np.random.uniform(-1, 1, x_shape).astype('float32') + y = np.random.uniform(-1, 1, y_shape).astype('float32') + + inputs = [paddle.to_tensor(x), paddle.to_tensor(y)] + FLOPs1 = flops(net, inputs) + shapes = [x_shape, y_shape] + FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"]) + self.assertTrue(FLOPs1 == FLOPs2) + + def add_cases(suite): suite.addTest(TestFlops(net=mobilenet_v1, gt=11792896.0)) suite.addTest(TestFlops(net=resnet50, gt=83872768.0)) + suite.addTest(TestFLOPsCase1()) + suite.addTest(TestFLOPsCase2()) def load_tests(loader, standard_tests, pattern): diff --git 
a/tests/dygraph/test_prune.py b/tests/dygraph/test_prune.py index acd691f2..64a5b788 100644 --- a/tests/dygraph/test_prune.py +++ b/tests/dygraph/test_prune.py @@ -47,6 +47,7 @@ class TestPrune(unittest.TestCase): shapes = {} for param in model.parameters(): shapes[param.name] = param.shape + pruner.restore() return shapes def static_prune(self, net, ratios): diff --git a/tests/layers.py b/tests/layers.py index a5f0b37d..e7d4f5f2 100644 --- a/tests/layers.py +++ b/tests/layers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr @@ -23,7 +24,8 @@ def conv_bn_layer(input, groups=1, act=None, bias=False, - use_cudnn=True): + use_cudnn=True, + sync_bn=False): conv = fluid.layers.conv2d( input=input, num_filters=num_filters, @@ -37,11 +39,19 @@ def conv_bn_layer(input, name=name + "_out", use_cudnn=use_cudnn) bn_name = name + "_bn" - return fluid.layers.batch_norm( - input=conv, - act=act, - name=bn_name + '_output', - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance', ) + if sync_bn: + bn = paddle.nn.SyncBatchNorm( + num_filters, + weight_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(name=bn_name + '_offset'), + name=bn_name) + return bn(conv) + else: + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '_output', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) diff --git a/tests/test_dygraph_pruning_plan.py b/tests/test_dygraph_pruning_plan.py new file mode 100644 index 00000000..b4c97afe --- /dev/null +++ b/tests/test_dygraph_pruning_plan.py @@ -0,0 +1,22 @@ +import sys +sys.path.append("../") +import unittest +import numpy as np +from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask + + +class TestPruningPlan(unittest.TestCase): + def testAdd(self): + plan = PruningPlan() + mask = PruningMask([0], [0, 0, 1], 0.33) + plan.add("a", mask) + mask = PruningMask([0], [0, 1, 0], 0.33) + plan.add("a", mask) + a_mask = plan.masks["a"] + self.assertTrue(len(a_mask) == 1) + self.assertTrue(a_mask[0].mask == [0, 1, 1]) + self.assertTrue(a_mask[0].dims == [0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_group_param.py b/tests/test_group_param.py index 0ffefde6..9e1c9f6f 100644 --- a/tests/test_group_param.py +++ b/tests/test_group_param.py @@ -41,14 +41,30 @@ class TestPrune(StaticCase): sum2 = conv4 + sum1 conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv6 = conv_bn_layer(conv5, 8, 3, "conv6") - groups = collect_convs( + collected_groups = collect_convs( ["conv1_weights", "conv2_weights", "conv3_weights"], main_program) - while [] in groups: - groups.remove([]) - print(groups) - self.assertTrue(len(groups) == 2) - self.assertTrue(len(groups[0]) == 20) - self.assertTrue(len(groups[1]) == 6) + while [] in collected_groups: + collected_groups.remove([]) + print(collected_groups) + + params = set([ + param.name for param in main_program.all_parameters() + if "weights" in param.name + ]) + + expected_groups = [[('conv1_weights', 0), ('conv2_weights', 1), + ('conv2_weights', 0), ('conv3_weights', 1), + ('conv4_weights', 0), ('conv5_weights', 
1)], + [('conv3_weights', 0), ('conv4_weights', 1)]] + + self.assertTrue(len(collected_groups) == len(expected_groups)) + for _collected, _expected in zip(collected_groups, expected_groups): + for _name, _axis, _ in _collected: + if _name in params: + self.assertTrue((_name, _axis) in _expected) + for _name, _axis in _expected: + if _name in params: + self.assertTrue((_name, _axis, []) in _collected) if __name__ == '__main__': diff --git a/tests/test_prune_op.py b/tests/test_prune_op.py new file mode 100644 index 00000000..a385704a --- /dev/null +++ b/tests/test_prune_op.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../") +import unittest +from static_case import StaticCase +import paddle.fluid as fluid +from paddleslim.prune import Pruner +from static_case import StaticCase +from layers import conv_bn_layer + + +class TestPrune(StaticCase): + def test_concat(self): + main_program = fluid.Program() + startup_program = fluid.Program() + # X + # conv1 conv2-->concat conv3-->sum-->out + # | ^ | ^ + # |____________| |____________________| + # + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 3, 16, 16]) + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(input, 8, 3, "conv2", sync_bn=True) + tmp = fluid.layers.concat([conv1, conv2], axis=1) + conv3 = conv_bn_layer(input, 16, 3, "conv3", bias=None) + out = conv3 + tmp + + shapes = {} + for param in main_program.global_block().all_parameters(): + shapes[param.name] = param.shape + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.Scope() + exe.run(startup_program, scope=scope) + pruner = Pruner() + # test backward search of concat + pruned_program, _, _ = pruner.prune( + main_program, + scope, + params=["conv3_weights"], + ratios=[0.5], + place=place, + lazy=False, + only_graph=True, + param_backup=None, + param_shape_backup=None) + shapes = { + "conv3_weights": (8, 3, 3, 3), + "conv2_weights": (4, 3, 3, 3), + "conv1_weights": (4, 3, 3, 3) + } + for param in pruned_program.global_block().all_parameters(): + if "weights" in param.name and "conv2d" in param.name: + self.assertTrue(shapes[param.name] == param.shape) + + # test forward search of concat + pruned_program, _, _ = pruner.prune( + main_program, + scope, + params=["conv1_weights", "conv2_weights"], + ratios=[0.5, 0.5], + place=place, + lazy=False, + only_graph=False, + param_backup=None, + param_shape_backup=None) + + shapes = { + "conv1_weights": (4, 3, 3, 3), + "conv1_bn_scale": (4, ), + "conv1_bn_variance": (4, ), + "conv1_bn_mean": (4, ), + "conv1_bn_offset": (4, ), + "conv2_weights": (4, 3, 3, 3), + "sync_batch_norm_0.w_0": (4, ), + "sync_batch_norm_0.w_1": (4, ), + "conv2_bn_scale": (4, ), + "conv2_bn_offset": (4, ), + "conv3_weights": (8, 3, 3, 3), + "conv3_bn_mean": (8, ), + "conv3_bn_offset": (8, ), + "conv3_bn_scale": (8, ), + "conv3_bn_variance": (8, ), + "conv3_out.b_0": 
(8, ), + } + + for param in pruned_program.global_block().all_parameters(): + if "weights" in param.name and "conv2d" in param.name: + self.assertTrue(shapes[param.name] == param.shape) + + +if __name__ == '__main__': + unittest.main() -- GitLab
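
A minimal, illustrative usage sketch of the dygraph API as refactored by this patch. It relies only on the signatures shown above (dygraph_flops accepting shapes plus dtypes, FilterPruner subclasses taking 'inputs', prune_var returning a PruningPlan); the parameter name "conv2d_0.w_0" and the pruning ratio are placeholders, not part of the patch:

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.analysis import dygraph_flops
from paddleslim.dygraph.l1norm_pruner import L1NormFilterPruner

net = mobilenet_v1()
# 'inputs' may now be a single shape, a list of shapes (optionally with dtypes),
# or real tensors/dicts that are forwarded to 'model.forward'.
base_flops = dygraph_flops(net, [1, 3, 224, 224], dtypes=["float32"])

pruner = L1NormFilterPruner(net, [1, 3, 224, 224])
# Prune 30% of the filters of one (hypothetical) conv weight on axis 0 and
# apply the plan to the model in place; "impretive" is the literal string
# the code checks for the imperative apply mode.
plan = pruner.prune_var(
    var_name="conv2d_0.w_0", pruned_dims=[0], pruned_ratio=0.3, apply="impretive")
print("FLOPs:", base_flops, "->", dygraph_flops(net, [1, 3, 224, 224]))
plan.restore(net)  # undo the in-place pruning if needed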