未验证 提交 e3226b49 编写于 作者: W whs 提交者: GitHub

Refine pruning code for detection (#554) (#600)

上级 dbb981d3
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import paddle import paddle
import numpy as np import numpy as np
import paddle.jit as jit import paddle.jit as jit
from ..core import GraphWrapper from ..core import GraphWrapper, dygraph2program
__all__ = ["flops", "dygraph_flops"] __all__ = ["flops", "dygraph_flops"]
...@@ -83,11 +83,27 @@ def _graph_flops(graph, only_conv=True, detail=False): ...@@ -83,11 +83,27 @@ def _graph_flops(graph, only_conv=True, detail=False):
return flops return flops
def dygraph_flops(model, input_shape, only_conv=False, detail=False): def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
"""
Compute the FLOPs of nn.Layer.
Args:
model(nn.Layer): The target model.
inputs(list): The dummy inputs used for 'model.forward'. It can be:
1. list<int>|tuple<int>: means 'model.forward' accepts
only one variable as argument and the shape of
variable is 'inputs'.
2. list<list<list>>: means 'model.forward' accepts multiple
variables as arguments and the shapes of variables is 'inputs'.
3. others: 'inputs' will be used as argument list by calling
'model.forward(*inputs)'.
dtypes(str|list<str>): It only used when 'inputs' is shape or shapes that means
data type of each input. None means all the inputs is 'float32'.
Default: None.
only_conv(bool): Just return number of mul-adds in convolution and FC layer if `only_conv` is true.
default: True.
detail(bool): Whether to return detail of each convolution layer.
"""
data = np.ones(tuple(input_shape)).astype("float32") program = dygraph2program(model, inputs)
in_var = paddle.to_tensor(data)
_, traced = paddle.jit.TracedLayer.trace(model, [in_var])
program = traced.program
graph = GraphWrapper(program) graph = GraphWrapper(program)
return _graph_flops(graph, only_conv=only_conv, detail=detail) return _graph_flops(graph, only_conv=only_conv, detail=detail)
...@@ -12,7 +12,14 @@ ...@@ -12,7 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper from ..core import graph_wrapper
from .registry import Registry from .graph_wrapper import *
from ..core import registry
from .registry import *
from ..core import dygraph
from .dygraph import *
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry'] __all__ = []
__all__ += graph_wrapper.__all__
__all__ += registry.__all__
__all__ += dygraph.__all__
import paddle
import collections
import logging
import numpy as np
from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard
from paddle.fluid.dygraph.base import program_desc_tracing_guard
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from ..common import get_logger
__all__ = ["dygraph2program"]
_logger = get_logger(__name__, level=logging.INFO)
def _is_shape(values):
if not isinstance(values, (list, tuple)):
return False
for v in values:
if not isinstance(v, int):
return False
return True
def _is_shapes(values):
if not isinstance(values, (list, tuple)):
return False
for v in values:
if not _is_shape(v):
return False
return True
def _create_tensors(shapes, dtypes=None):
if dtypes is not None:
assert len(shapes) == len(
dtypes
), "Length of shapes and dtypes must be same. But get len(shapes): {}; len(dtypes): {}; shapes: {}; dtypes: {}".format(
len(shapes), len(dtypes), shapes, dtypes)
else:
dtypes = len(shapes) * ['float32']
tensors = []
for shape, dtype in zip(shapes, dtypes):
data = np.ones(tuple(shape)).astype(dtype)
tensors.append(paddle.to_tensor(data))
return tensors
def extract_vars(inputs):
"""
Extract a list of variables from inputs.
Args:
inputs(Variable | list<Object> | dict):
"""
vars = []
if isinstance(inputs, Variable):
vars = [inputs]
elif isinstance(inputs, dict):
for _key, _value in inputs.items():
if isinstance(_value, Variable):
vars.append(_value)
else:
_logger.warn(
f"Variable is excepted, but get an element with type({type(_value)}) from inputs whose type is dict. And the key of element is {_key}."
)
elif isinstance(inputs, (tuple, list)):
for _value in inputs:
vars.extend(extract_vars(_value))
if len(vars) == 0:
_logger.warn(f"Extract none variables from inputs.")
return vars
def to_variables(inputs):
"""
Find and rename variables. Find np.ndarray and convert it to variable.
"""
if isinstance(inputs, Variable) or isinstance(inputs, np.ndarray):
return paddle.fluid.dygraph.to_variable(inputs)
elif isinstance(inputs, dict):
ret = {}
for _key in inputs:
ret[_key] = to_variables(inputs[_key])
return inputs
elif isinstance(inputs, list):
ret = []
for _value in inputs:
ret.append(to_variables(_value))
return ret
@dygraph_only
def dygraph2program(layer,
inputs,
feed_prefix='feed_',
fetch_prefix='fetch_',
tmp_prefix='t_',
extract_inputs_fn=None,
extract_outputs_fn=None,
dtypes=None):
assert isinstance(layer, Layer)
extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
tracer = _dygraph_tracer()._get_program_desc_tracer()
with program_desc_tracing_guard(True):
if _is_shape(inputs):
shapes = [inputs]
inputs = _create_tensors(shapes, dtypes=dtypes)
input_var_list = inputs
elif _is_shapes(inputs):
inputs = _create_tensors(inputs, dtypes=dtypes)
input_var_list = inputs
else:
inputs = to_variables(inputs)
input_var_list = extract_inputs_fn(inputs)
original_outputs = layer(*inputs)
# 'original_outputs' may be dict, so we should convert it to list of varibles.
# And should not create new varibles in 'extract_vars'.
out_var_list = extract_outputs_fn(original_outputs)
program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
input_var_list, feed_prefix, out_var_list, fetch_prefix, tmp_prefix)
tracer.reset()
with _dygraph_guard(None):
program = Program()
program.desc = program_desc
program.blocks = [Block(program, 0)]
program._sync_with_cpp()
return program
...@@ -397,9 +397,3 @@ class GraphWrapper(object): ...@@ -397,9 +397,3 @@ class GraphWrapper(object):
# Infer the remain ops in topological order. # Infer the remain ops in topological order.
for op in head_op: for op in head_op:
recursive_infer(op, infer=True) recursive_infer(op, infer=True)
def update_groups_of_conv(self):
for op in self.ops():
if 'conv2d' in op.type() and op.attr('groups') >= op.inputs(
'Filter')[0].shape()[0]:
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
...@@ -47,24 +47,34 @@ class FilterPruner(Pruner): ...@@ -47,24 +47,34 @@ class FilterPruner(Pruner):
Args: Args:
model(paddle.nn.Layer): The target model to be pruned. model(paddle.nn.Layer): The target model to be pruned.
input_shape(list<int>): The input shape of model. It is used to trace the graph of the model. inputs(list<int>): The inputs of model. It will be use in calling 'model.forward(inputs)'.
sen_file(str, optional): The absolute path of file that stores computed sensitivities. If it is sen_file(str, optional): The absolute path of file that stores computed sensitivities. If it is
set rightly, 'FilterPruner::sensitive' function can not be called anymore set rightly, 'FilterPruner::sensitive' function can not be called anymore
in next step. Default: None. in next step. Default: None.
""" """
def __init__(self, model, input_shape, sen_file=None): def __init__(self, model, inputs, sen_file=None):
super(FilterPruner, self).__init__(model, input_shape) super(FilterPruner, self).__init__(model, inputs)
self._status = Status(sen_file) self._status = Status(sen_file)
# sensitive and var_group are just used in filter pruning # sensitive and var_group are just used in filter pruning
self.var_group = VarGroup(model, input_shape) self.var_group = VarGroup(model, inputs)
# skip vars in:
# 1. depthwise conv2d layer
self.skip_vars = []
for sub_layer in model.sublayers():
if isinstance(
sub_layer,
paddle.nn.layer.conv.Conv2D) and sub_layer._groups > 1:
for param in sub_layer.parameters():
self.skip_vars.append(param.name)
def sensitive(self, def sensitive(self,
eval_func=None, eval_func=None,
sen_file=None, sen_file=None,
target_vars=None, target_vars=None,
skip_vars=None): skip_vars=[]):
""" """
Compute or get sensitivities of model in current pruner. It will return a cached sensitivities when all the arguments are "None". Compute or get sensitivities of model in current pruner. It will return a cached sensitivities when all the arguments are "None".
...@@ -88,7 +98,7 @@ class FilterPruner(Pruner): ...@@ -88,7 +98,7 @@ class FilterPruner(Pruner):
eval_func(function, optional): The function to evaluate the model in current pruner. This function should have an empy arguments list and return a score with type "float32". Default: None. eval_func(function, optional): The function to evaluate the model in current pruner. This function should have an empy arguments list and return a score with type "float32". Default: None.
sen_file(str, optional): The absolute path of file to save sensitivities into local filesystem. Default: None. sen_file(str, optional): The absolute path of file to save sensitivities into local filesystem. Default: None.
target_vars(list, optional): The names of tensors whose sensitivity will be computed. "None" means all weights in convolution layer will be computed. Default: None. target_vars(list, optional): The names of tensors whose sensitivity will be computed. "None" means all weights in convolution layer will be computed. Default: None.
skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None. skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. Default: [].
Returns: Returns:
dict: A dict storing sensitivities. dict: A dict storing sensitivities.
...@@ -102,6 +112,7 @@ class FilterPruner(Pruner): ...@@ -102,6 +112,7 @@ class FilterPruner(Pruner):
if not self._status.is_ckp: if not self._status.is_ckp:
return self._status return self._status
skip_vars.extend(self.skip_vars)
self._cal_sensitive( self._cal_sensitive(
self.model, self.model,
eval_func, eval_func,
...@@ -186,9 +197,9 @@ class FilterPruner(Pruner): ...@@ -186,9 +197,9 @@ class FilterPruner(Pruner):
Returns: Returns:
tuple: A tuple with format ``(ratios, pruned_flops)`` . "ratios" is a dict whose key is name of tensor and value is ratio to be pruned. "pruned_flops" is the ratio of total pruned FLOPs in the model. tuple: A tuple with format ``(ratios, pruned_flops)`` . "ratios" is a dict whose key is name of tensor and value is ratio to be pruned. "pruned_flops" is the ratio of total pruned FLOPs in the model.
""" """
base_flops = flops(self.model, self.input_shape) base_flops = flops(self.model, self.inputs)
_logger.info("Base FLOPs: {}".format(base_flops)) _logger.debug("Base FLOPs: {}".format(base_flops))
low = 0. low = 0.
up = 1.0 up = 1.0
history = set() history = set()
...@@ -200,8 +211,7 @@ class FilterPruner(Pruner): ...@@ -200,8 +211,7 @@ class FilterPruner(Pruner):
if align is not None: if align is not None:
ratios = self._round_to(ratios, dims=dims, factor=align) ratios = self._round_to(ratios, dims=dims, factor=align)
plan = self.prune_vars(ratios, axis=dims) plan = self.prune_vars(ratios, axis=dims)
_logger.debug("pruning plan: {}".format(plan)) c_flops = flops(self.model, self.inputs)
c_flops = flops(self.model, self.input_shape)
_logger.debug("FLOPs after pruning: {}".format(c_flops)) _logger.debug("FLOPs after pruning: {}".format(c_flops))
c_pruned_flops = (base_flops - c_flops) / base_flops c_pruned_flops = (base_flops - c_flops) / base_flops
plan.restore(self.model) plan.restore(self.model)
...@@ -304,7 +314,11 @@ class FilterPruner(Pruner): ...@@ -304,7 +314,11 @@ class FilterPruner(Pruner):
plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'. plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'.
""" """
if var_name in self.skip_vars:
_logger.warn(
f"{var_name} is skiped beacause it is not support for pruning derectly."
)
return
if isinstance(pruned_dims, int): if isinstance(pruned_dims, int):
pruned_dims = [pruned_dims] pruned_dims = [pruned_dims]
group = self.var_group.find_group(var_name, pruned_dims) group = self.var_group.find_group(var_name, pruned_dims)
...@@ -315,7 +329,9 @@ class FilterPruner(Pruner): ...@@ -315,7 +329,9 @@ class FilterPruner(Pruner):
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
if param.name in group: if param.name in group:
group_dict[param.name] = group[param.name] group_dict[param.name] = group[param.name]
group_dict[param.name].update({ # Varibales can be pruned on multiple axies.
for _item in group_dict[param.name]:
_item.update({
'layer': sub_layer, 'layer': sub_layer,
'var': param, 'var': param,
'value': np.array(param.value().get_tensor()) 'value': np.array(param.value().get_tensor())
...@@ -324,20 +340,41 @@ class FilterPruner(Pruner): ...@@ -324,20 +340,41 @@ class FilterPruner(Pruner):
mask = self.cal_mask(var_name, pruned_ratio, group_dict) mask = self.cal_mask(var_name, pruned_ratio, group_dict)
for _name in group_dict: for _name in group_dict:
dims = group_dict[_name]['pruned_dims'] # Varibales can be pruned on multiple axies.
stride = group_dict[_name]['stride'] for _item in group_dict[_name]:
var_shape = group_dict[_name]['var'].shape dims = _item['pruned_dims']
transforms = _item['transforms']
var_shape = _item['var'].shape
if isinstance(dims, int): if isinstance(dims, int):
dims = [dims] dims = [dims]
for trans in transforms:
current_mask = mask.repeat(stride[0]) if stride[0] > 1 else mask mask = self._transform_mask(mask, trans)
current_mask = mask
assert len(current_mask) == var_shape[dims[ assert len(current_mask) == var_shape[dims[
0]], "The length of current_mask must be equal to the size of dimension to be pruned on." 0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
plan.add(_name, PruningMask(dims, current_mask, pruned_ratio)) plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
if apply == "lazy": if apply == "lazy":
plan.apply(self.model, lazy=True) plan.apply(self.model, lazy=True)
elif apply == "impretive": elif apply == "impretive":
plan.apply(self.model, lazy=False) plan.apply(self.model, lazy=False)
return plan return plan
def _transform_mask(self, mask, transform):
src_start = transform['src_start']
src_end = transform['src_end']
target_start = transform['target_start']
target_end = transform['target_end']
target_len = transform['target_len']
stride = transform['stride']
mask = mask[src_start:src_end]
mask = mask.repeat(stride) if stride > 1 else mask
dst_mask = np.ones([target_len])
# for depthwise conv2d with:
# input shape: (1, 4, 32, 32)
# filter shape: (32, 1, 3, 3)
# groups: 4
# if we pruning input channels by 50%(from 4 to 2), the output channel should be 50% * 4 * 8.
expand = int((target_end - target_start) / len(mask))
dst_mask[target_start:target_end] = list(mask) * expand
return dst_mask
...@@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO)
class FPGMFilterPruner(FilterPruner): class FPGMFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None): def __init__(self, model, inputs, sen_file=None):
super(FPGMFilterPruner, self).__init__( super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file)
model, input_shape, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group): def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value'] for _item in group[var_name]:
pruned_dims = group[var_name]['pruned_dims'] if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
assert (pruned_dims == [0]) assert (pruned_dims == [0])
dist_sum_list = [] dist_sum_list = []
......
...@@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO)
class L1NormFilterPruner(FilterPruner): class L1NormFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None): def __init__(self, model, inputs, sen_file=None):
super(L1NormFilterPruner, self).__init__( super(L1NormFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file) model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group): def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value'] for _item in group[var_name]:
pruned_dims = group[var_name]['pruned_dims'] if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
reduce_dims = [ reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims i for i in range(len(value.shape)) if i not in pruned_dims
] ]
......
...@@ -12,13 +12,16 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,13 +12,16 @@ _logger = get_logger(__name__, logging.INFO)
class L2NormFilterPruner(FilterPruner): class L2NormFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None): def __init__(self, model, inputs, sen_file=None):
super(L2NormFilterPruner, self).__init__( super(L2NormFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file) model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group): def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value'] # find information of pruning on output channels
pruned_dims = group[var_name]['pruned_dims'] for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
reduce_dims = [ reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims i for i in range(len(value.shape)) if i not in pruned_dims
] ]
......
...@@ -19,9 +19,9 @@ class Pruner(object): ...@@ -19,9 +19,9 @@ class Pruner(object):
""" """
def __init__(self, model, input_shape): def __init__(self, model, inputs):
self.model = model self.model = model
self.input_shape = input_shape self.inputs = inputs
self._var_shapes = {} self._var_shapes = {}
for var in model.parameters(): for var in model.parameters():
self._var_shapes[var.name] = var.shape self._var_shapes[var.name] = var.shape
...@@ -53,5 +53,5 @@ class Pruner(object): ...@@ -53,5 +53,5 @@ class Pruner(object):
global_plan.apply(self.model, lazy=True) global_plan.apply(self.model, lazy=True)
elif apply == "impretive": elif apply == "impretive":
global_plan.apply(self.model, lazy=False) global_plan.apply(self.model, lazy=False)
self.plan = global_plan
return global_plan return global_plan
...@@ -28,7 +28,7 @@ class PruningMask(): ...@@ -28,7 +28,7 @@ class PruningMask():
if self._mask is not None: if self._mask is not None:
assert len(self._mask.shape) == len( assert len(self._mask.shape) == len(
value value
), "The length of value must be same with shape of mask in current PruningMask instance." ), "The length of value must be same with length of mask's shape in current PruningMask instance."
self._dims = list(value) self._dims = list(value)
@property @property
...@@ -37,11 +37,6 @@ class PruningMask(): ...@@ -37,11 +37,6 @@ class PruningMask():
@mask.setter @mask.setter
def mask(self, value): def mask(self, value):
assert (isinstance(value, PruningMask))
if self._dims is not None:
assert len(self._mask.shape) == len(
value
), "The length of value must be same with shape of mask in current PruningMask instance."
self._mask = value self._mask = value
def __str__(self): def __str__(self):
...@@ -71,12 +66,20 @@ class PruningPlan(): ...@@ -71,12 +66,20 @@ class PruningPlan():
self._pruned_flops = value self._pruned_flops = value
def add(self, var_name, pruning_mask): def add(self, var_name, pruning_mask):
assert (isinstance(pruning_mask, PruningMask)) assert (isinstance(pruning_mask, PruningMask))
if var_name not in self._masks: if var_name not in self._masks:
self._masks[var_name] = [] self._masks[var_name] = []
self._masks[var_name].append(pruning_mask)
if var_name not in self._dims: if var_name not in self._dims:
self._dims[var_name] = [] self._dims[var_name] = []
if pruning_mask.dims in self._dims[var_name]:
for _mask in self._masks[var_name]:
if pruning_mask.dims == _mask.dims:
_mask.mask = list(
np.array(_mask.mask) | np.array(pruning_mask.mask))
else:
self._masks[var_name].append(pruning_mask)
self._dims[var_name].append(pruning_mask.dims) self._dims[var_name].append(pruning_mask.dims)
@property @property
...@@ -87,7 +90,6 @@ class PruningPlan(): ...@@ -87,7 +90,6 @@ class PruningPlan():
assert (isinstance(plan, PruningPlan)) assert (isinstance(plan, PruningPlan))
for var_name in plan.masks: for var_name in plan.masks:
for mask in plan.masks[var_name]: for mask in plan.masks[var_name]:
if not self.contains(var_name, mask.dims):
self.add(var_name, mask) self.add(var_name, mask)
def contains(self, var_name, dims=None): def contains(self, var_name, dims=None):
...@@ -172,7 +174,6 @@ class PruningPlan(): ...@@ -172,7 +174,6 @@ class PruningPlan():
bool_mask = mask.astype(bool) bool_mask = mask.astype(bool)
pruned_value = np.apply_along_axis( pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims[0], value) lambda data: data[bool_mask], dims[0], value)
p = t_value._place() p = t_value._place()
if p.is_cpu_place(): if p.is_cpu_place():
place = paddle.CPUPlace() place = paddle.CPUPlace()
...@@ -184,14 +185,19 @@ class PruningPlan(): ...@@ -184,14 +185,19 @@ class PruningPlan():
place = paddle.CUDAPlace(p.gpu_device_id()) place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_value, place) t_value.set(pruned_value, place)
if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D): if isinstance(
if sub_layer._groups > 1 and pruned_value.shape[ sub_layer, paddle.nn.layer.conv.Conv2D
1] == 1: # depthwise conv2d ) and sub_layer._groups > 1 and len(param.shape) == 4:
assert param.shape[
1] == 1, "It just supports depthwise conv2d when groups > 1."
new_groups = int(bool_mask.sum() *
sub_layer._groups / len(bool_mask))
_logger.debug( _logger.debug(
"Update groups of depthwise conv2d form {} to {}". "Update groups of depthwise conv2d form {} to {}".
format(sub_layer._groups, format(sub_layer._groups, new_groups))
pruned_value.shape[0])) sub_layer._origin_groups = sub_layer._groups
sub_layer._groups = pruned_value.shape[0] sub_layer._groups = new_groups
# for training # for training
if param.trainable: if param.trainable:
param.clear_gradient() param.clear_gradient()
...@@ -218,11 +224,6 @@ class PruningPlan(): ...@@ -218,11 +224,6 @@ class PruningPlan():
place = paddle.CUDAPlace(p.gpu_device_id()) place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(np.array(t_backup).astype("float32"), place) t_value.set(np.array(t_backup).astype("float32"), place)
if "_origin_groups" in sub_layer.__dict__:
if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D): sub_layer._groups = sub_layer._origin_groups
if sub_layer._groups > 1:
_logger.debug(
"Update groups of conv form {} to {}".format(
sub_layer._groups, t_value.shape()[0]))
sub_layer._groups = t_value.shape()[0]
del sub_layer._buffers[backup_name] del sub_layer._buffers[backup_name]
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import logging import logging
import paddle import paddle
from paddle.fluid.dygraph import TracedLayer from paddle.fluid.dygraph import TracedLayer
from ..core import GraphWrapper from ..core import GraphWrapper, dygraph2program
from ..prune import collect_convs from ..prune import collect_convs
from ..common import get_logger from ..common import get_logger
...@@ -12,33 +12,43 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -12,33 +12,43 @@ _logger = get_logger(__name__, level=logging.INFO)
class VarGroup(): class VarGroup():
def __init__(self, model, input_shape): """
A tool used to parse dygraph and store information of variables' relationship.
Args:
- model(nn.Layer): The dygraph to be parsed.
- inputs(Variable|list|dict): The dummy inputs of target model. It will be used in calling `model.forward(inputs)`.
"""
def __init__(self, model, inputs):
self.groups = [] self.groups = []
self._parse_model(model, input_shape) self._parse_model(model, inputs)
def _to_dict(self, group): def _to_dict(self, group):
ret = {} ret = {}
for _name, _axis, _stride in group: for _name, _axis, _transforms in group:
if isinstance(_axis, int): if isinstance(_axis, int):
_axis = [_axis] # TODO: fix _axis = [_axis]
ret[_name] = {'pruned_dims': _axis, 'stride': _stride} if _name not in ret:
ret[_name] = []
# Variable can be pruned on multiple axies.
ret[_name].append({'pruned_dims': _axis, 'transforms': _transforms})
return ret return ret
def find_group(self, var_name, axis): def find_group(self, var_name, axis):
for group in self.groups: for group in self.groups:
for _name, _axis, _stride in group: for _name, _axis, _stride in group:
if isinstance(_axis, int): if isinstance(_axis, int):
_axis = [_axis] # TODO: fix _axis = [_axis]
if _name == var_name and _axis == axis: if _name == var_name and _axis == axis:
return self._to_dict(group) return self._to_dict(group)
def _parse_model(self, model, input_shape): def _parse_model(self, model, inputs):
_logger.debug("Parsing model with input: {}".format(input_shape)) _logger.debug("Parsing model with input: {}".format(inputs))
data = np.ones(tuple(input_shape)).astype("float32")
in_var = paddle.to_tensor(data)
model.eval() model.eval()
out_dygraph, static_layer = TracedLayer.trace(model, inputs=[in_var]) program = dygraph2program(model, inputs=inputs)
graph = GraphWrapper(static_layer.program)
graph = GraphWrapper(program)
visited = {} visited = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
......
...@@ -79,7 +79,7 @@ def collect_convs(params, graph, visited={}): ...@@ -79,7 +79,7 @@ def collect_convs(params, graph, visited={}):
pruned_params=pruned_params, pruned_params=pruned_params,
visited=visited) visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0]) walker.prune(param, pruned_axis=0, pruned_idx=[])
groups.append(pruned_params) groups.append(pruned_params)
visited = set() visited = set()
uniq_groups = [] uniq_groups = []
...@@ -96,5 +96,4 @@ def collect_convs(params, graph, visited={}): ...@@ -96,5 +96,4 @@ def collect_convs(params, graph, visited={}):
simple_group.append((param, axis, pruned_idx)) simple_group.append((param, axis, pruned_idx))
if not repeat_group: if not repeat_group:
uniq_groups.append(simple_group) uniq_groups.append(simple_group)
return uniq_groups return uniq_groups
...@@ -59,11 +59,9 @@ def default_idx_selector(group, ratio): ...@@ -59,11 +59,9 @@ def default_idx_selector(group, ratio):
pruned_num = int(round(len(sorted_idx) * ratio)) pruned_num = int(round(len(sorted_idx) * ratio))
pruned_idx = sorted_idx[:pruned_num] pruned_idx = sorted_idx[:pruned_num]
idxs = [] idxs = []
for name, axis, score, offsets in group: for name, axis, score, transforms in group:
r_idx = [i + offsets[0] for i in pruned_idx] idxs.append((name, axis, pruned_idx, transforms))
idxs.append((name, axis, r_idx))
return idxs return idxs
...@@ -112,6 +110,6 @@ def optimal_threshold(group, ratio): ...@@ -112,6 +110,6 @@ def optimal_threshold(group, ratio):
pruned_idx = np.squeeze(np.argwhere(score < th)) pruned_idx = np.squeeze(np.argwhere(score < th))
idxs = [] idxs = []
for name, axis, score, _ in group: for name, axis, score, transforms in group:
idxs.append((name, axis, pruned_idx)) idxs.append((name, axis, pruned_idx, transforms))
return idxs return idxs
...@@ -10,6 +10,7 @@ __all__ = ["save_model", "load_model"] ...@@ -10,6 +10,7 @@ __all__ = ["save_model", "load_model"]
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
_SHAPES_FILE = "__shapes__" _SHAPES_FILE = "__shapes__"
_GROUPS_FILE = "__groups__"
def save_model(exe, graph, dirname): def save_model(exe, graph, dirname):
...@@ -39,6 +40,17 @@ def save_model(exe, graph, dirname): ...@@ -39,6 +40,17 @@ def save_model(exe, graph, dirname):
json.dump(shapes, f) json.dump(shapes, f)
_logger.info("Save shapes of weights into {}".format(SHAPES_FILE)) _logger.info("Save shapes of weights into {}".format(SHAPES_FILE))
groups = {}
for op in graph.ops():
if 'conv2d' in op.type():
filter_name = op.inputs('Filter')[0].name()
groups[filter_name] = op.attr('groups')
GROUPS_FILE = os.path.join(dirname, _GROUPS_FILE)
with open(GROUPS_FILE, "w") as f:
json.dump(groups, f)
_logger.info("Save groups of cnov2d into {}".format(GROUPS_FILE))
def load_model(exe, graph, dirname): def load_model(exe, graph, dirname):
""" """
...@@ -53,7 +65,6 @@ def load_model(exe, graph, dirname): ...@@ -53,7 +65,6 @@ def load_model(exe, graph, dirname):
paddle.static.Program) else graph paddle.static.Program) else graph
SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE) SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
with open(SHAPES_FILE, "r") as f: with open(SHAPES_FILE, "r") as f:
shapes = json.load(f) shapes = json.load(f)
for param_name, shape in shapes.items(): for param_name, shape in shapes.items():
...@@ -62,9 +73,17 @@ def load_model(exe, graph, dirname): ...@@ -62,9 +73,17 @@ def load_model(exe, graph, dirname):
param.set_shape(shape) param.set_shape(shape)
else: else:
_logger.info('{} is not loaded'.format(param_name)) _logger.info('{} is not loaded'.format(param_name))
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) _logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
GROUPS_FILE = os.path.join(dirname, _GROUPS_FILE)
with open(GROUPS_FILE, "r") as f:
groups = json.load(f)
for op in graph.ops():
if 'conv2d' in op.type():
filter_name = op.inputs('Filter')[0].name()
op.set_attr('groups', groups[filter_name])
_logger.info("Load groups of conv2d from {}".format(GROUPS_FILE))
paddle.static.load(program=graph.program, model_path=dirname, executor=exe) paddle.static.load(program=graph.program, model_path=dirname, executor=exe)
graph.update_groups_of_conv()
graph.infer_shape() graph.infer_shape()
_logger.info("Load weights from {}".format(dirname)) _logger.info("Load weights from {}".format(dirname))
...@@ -65,6 +65,15 @@ class PruneWorker(object): ...@@ -65,6 +65,15 @@ class PruneWorker(object):
self.visited[pruned_axis][key] = True self.visited[pruned_axis][key] = True
return True return True
def _visit_and_search(self, var, axis, transforms):
self._visit(var, axis)
pre_ops = var.inputs()
for op in pre_ops:
self._prune_op(op, var, axis, transforms)
next_ops = var.outputs()
for op in next_ops:
self._prune_op(op, var, axis, transforms)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
...@@ -85,6 +94,9 @@ class PruneWorker(object): ...@@ -85,6 +94,9 @@ class PruneWorker(object):
cls = PRUNE_WORKER.get("default_walker") cls = PRUNE_WORKER.get("default_walker")
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format( _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name())) self.op, op, pruned_axis, var.name()))
_logger.debug(
f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
)
walker = cls(op, pruned_params=self.pruned_params, visited=self.visited) walker = cls(op, pruned_params=self.pruned_params, visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx) walker.prune(var, pruned_axis, pruned_idx)
...@@ -170,11 +182,6 @@ class conv2d(PruneWorker): ...@@ -170,11 +182,6 @@ class conv2d(PruneWorker):
self.pruned_params.append( self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx)) (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class conv2d_transpose(PruneWorker): class conv2d_transpose(PruneWorker):
...@@ -250,6 +257,12 @@ class batch_norm(PruneWorker): ...@@ -250,6 +257,12 @@ class batch_norm(PruneWorker):
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class sync_batch_norm(batch_norm):
def __init__(self, op, pruned_params, visited):
super(sync_batch_norm, self).__init__(op, pruned_params, visited)
class elementwise_op(PruneWorker): class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited):
super(elementwise_op, self).__init__(op, pruned_params, visited) super(elementwise_op, self).__init__(op, pruned_params, visited)
...@@ -269,9 +282,12 @@ class elementwise_op(PruneWorker): ...@@ -269,9 +282,12 @@ class elementwise_op(PruneWorker):
in_var = self.op.inputs(name)[0] in_var = self.op.inputs(name)[0]
if len(in_var.shape()) == 1 and in_var.shape()[0] == 1: if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
continue continue
pre_ops = in_var.inputs()
for op in pre_ops: # for bias
self._prune_op(op, in_var, actual_axis, pruned_idx) if name == "Y" and actual_axis >= 0 and not (
len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
self.pruned_params.append((in_var, actual_axis, pruned_idx))
self._visit_and_search(in_var, actual_axis, pruned_idx)
else: else:
if var in self.op.inputs("X"): if var in self.op.inputs("X"):
...@@ -287,24 +303,17 @@ class elementwise_op(PruneWorker): ...@@ -287,24 +303,17 @@ class elementwise_op(PruneWorker):
in_var.shape()[0] == 1): in_var.shape()[0] == 1):
self.pruned_params.append( self.pruned_params.append(
(in_var, y_pruned_axis, pruned_idx)) (in_var, y_pruned_axis, pruned_idx))
pre_ops = in_var.inputs() self._visit_and_search(in_var, y_pruned_axis, pruned_idx)
for op in pre_ops:
self._prune_op(op, in_var, y_pruned_axis, pruned_idx)
elif var in self.op.inputs("Y"): elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0] in_var = self.op.inputs("X")[0]
if len(in_var.shape()) != len(var.shape()): if len(in_var.shape()) != len(var.shape()):
assert (len(var.shape()) < len(in_var.shape())) assert (len(var.shape()) < len(in_var.shape()))
pruned_axis = pruned_axis + axis pruned_axis = pruned_axis + axis
if pruned_axis <= len(in_var.shape()): if pruned_axis <= len(in_var.shape()):
pre_ops = in_var.inputs() self._visit_and_search(in_var, pruned_axis, pruned_idx)
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis) self._visit_and_search(out_var, pruned_axis, pruned_idx)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
...@@ -447,53 +456,70 @@ class concat(PruneWorker): ...@@ -447,53 +456,70 @@ class concat(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited):
super(concat, self).__init__(op, pruned_params, visited) super(concat, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, transforms):
idx = []
axis = self.op.attr("axis") axis = self.op.attr("axis")
if var in self.op.outputs("Out"): if var in self.op.outputs("Out"):
self._visit(var, pruned_axis)
start = 0 start = 0
if axis == pruned_axis: if axis == pruned_axis:
for _, in_var in enumerate(self.op.inputs("X")): for _, in_var in enumerate(self.op.inputs("X")):
idx = [] idx = []
for i in pruned_idx: transoform = {
r_idx = i - start 'src_start': start,
if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0: 'src_end': start + in_var.shape()[pruned_axis],
idx.append(r_idx) 'target_start': 0,
'target_end': in_var.shape()[pruned_axis],
'target_len': in_var.shape()[pruned_axis],
'stride': 1
}
start += in_var.shape()[pruned_axis] start += in_var.shape()[pruned_axis]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs() pre_ops = in_var.inputs()
for op in pre_ops: for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, idx) self._prune_op(op, in_var, pruned_axis,
idx = pruned_idx[:] transforms + [transoform])
else: else:
for _, in_var in enumerate(self.op.inputs("X")): for _, in_var in enumerate(self.op.inputs("X")):
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs() pre_ops = in_var.inputs()
for op in pre_ops: for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx) self._prune_op(op, in_var, pruned_axis, transforms)
elif var in self.op.inputs("X"): elif var in self.op.inputs("X"):
self._visit(var, pruned_axis)
if axis == pruned_axis: if axis == pruned_axis:
idx = [] idx = []
start = 0 target_start = 0
for v in self.op.inputs("X"): for v in self.op.inputs("X"):
if v.name() == var.name(): if v.name() != var.name():
idx = [i + start for i in pruned_idx] target_start += v.shape()[pruned_axis]
else: else:
start += v.shape()[pruned_axis] break
target_end = target_start + v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={}) transform = {
else: 'src_start': 0,
for v in self.op.inputs("X"): 'src_end': var.shape()[pruned_axis],
for op in v.inputs(): 'target_start': target_start,
self._prune_op(op, v, pruned_axis, pruned_idx) 'target_end': target_end,
out_var = self.op.outputs("Out")[0] 'target_len': out_var.shape()[pruned_axis],
'stride': 1
}
self._visit(out_var, pruned_axis) self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) # The output of concat can be visited repeatedly
c_visited = {}
self._prune_op(
op,
out_var,
pruned_axis,
transforms + [transform],
visited=c_visited)
# Add nodes searched from concat into global visited array.
self.visited.update(c_visited)
@PRUNE_WORKER.register @PRUNE_WORKER.register
...@@ -501,8 +527,14 @@ class depthwise_conv2d(PruneWorker): ...@@ -501,8 +527,14 @@ class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited) super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, transforms):
assert var not in self.op.inputs(
"Filter"), "Unsupport for pruning depthwise conv2d directly."
assert var not in self.op.outputs(
"Output"
), "Unsupport for pruning output of depthwise conv2d directly."
data_format = self.op.attr("data_format") data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
channel_axis = 3 channel_axis = 3
...@@ -510,60 +542,28 @@ class depthwise_conv2d(PruneWorker): ...@@ -510,60 +542,28 @@ class depthwise_conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis) pruned_axis)
groups = var.shape()[channel_axis]
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx)) transform = {
self._visit(filter_var, 0) 'src_start': 0,
'src_end': var.shape()[pruned_axis],
for op in filter_var.outputs(): 'target_start': 0,
self._prune_op(op, filter_var, 0, pruned_idx) 'target_end': filter_var.shape()[0],
'target_len': filter_var.shape()[0],
output_var = self.op.outputs("Output")[0] 'stride': 1
next_ops = output_var.outputs() }
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx) self.pruned_params.append((filter_var, 0, transforms + [transform]))
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0]
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
self.pruned_params.append((var, 0, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
self._visit(filter_var, 0) self._visit(filter_var, 0)
for op in filter_var.outputs(): for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx) self._prune_op(op, filter_var, 0, transforms + [transform])
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0]
self._visit(in_var, channel_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0] output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs() next_ops = output_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx) self._prune_op(op, output_var, channel_axis,
transforms + [transform])
@PRUNE_WORKER.register @PRUNE_WORKER.register
...@@ -679,7 +679,7 @@ class flatten_contiguous_range(PruneWorker): ...@@ -679,7 +679,7 @@ class flatten_contiguous_range(PruneWorker):
super(flatten_contiguous_range, self).__init__(op, pruned_params, super(flatten_contiguous_range, self).__init__(op, pruned_params,
visited) visited)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, transforms):
start_axis = self.op.attr("start_axis") start_axis = self.op.attr("start_axis")
stop_axis = self.op.attr("stop_axis") stop_axis = self.op.attr("stop_axis")
if var in self.op.inputs("X"): if var in self.op.inputs("X"):
...@@ -687,7 +687,6 @@ class flatten_contiguous_range(PruneWorker): ...@@ -687,7 +687,6 @@ class flatten_contiguous_range(PruneWorker):
in_var = self.op.inputs("X")[0] in_var = self.op.inputs("X")[0]
stride = 1 stride = 1
out_pruned_axis = pruned_axis out_pruned_axis = pruned_axis
out_pruned_idx = pruned_idx
if pruned_axis >= start_axis and pruned_axis <= stop_axis: if pruned_axis >= start_axis and pruned_axis <= stop_axis:
out_pruned_axis = start_axis out_pruned_axis = start_axis
for i in range(pruned_axis + 1, stop_axis + 1): for i in range(pruned_axis + 1, stop_axis + 1):
...@@ -697,7 +696,8 @@ class flatten_contiguous_range(PruneWorker): ...@@ -697,7 +696,8 @@ class flatten_contiguous_range(PruneWorker):
self._visit(in_var, pruned_axis) self._visit(in_var, pruned_axis)
self._visit(out_var, out_pruned_axis) self._visit(out_var, out_pruned_axis)
transform = {'stride': stride}
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, out_pruned_axis, [stride]) self._prune_op(op, out_var, out_pruned_axis,
transforms + [transform])
...@@ -85,8 +85,8 @@ class Pruner(): ...@@ -85,8 +85,8 @@ class Pruner():
param_backup = {} if param_backup else None param_backup = {} if param_backup else None
param_shape_backup = {} if param_shape_backup else None param_shape_backup = {} if param_shape_backup else None
visited = {}
pruned_params = [] pruned_params = []
visited = {}
for param, ratio in zip(params, ratios): for param, ratio in zip(params, ratios):
_logger.info("pruning: {}".format(param)) _logger.info("pruning: {}".format(param))
if graph.var(param) is None: if graph.var(param) is None:
...@@ -98,15 +98,6 @@ class Pruner(): ...@@ -98,15 +98,6 @@ class Pruner():
visited)[0] # [(name, axis, pruned_idx)] visited)[0] # [(name, axis, pruned_idx)]
if group is None or len(group) == 0: if group is None or len(group) == 0:
continue continue
if only_graph and self.idx_selector.__name__ == "default_idx_selector":
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
for name, axis, _ in group:
pruned_params.append((name, axis, pruned_idx))
else:
assert ((not self.pruned_weights), assert ((not self.pruned_weights),
"The weights have been pruned once.") "The weights have been pruned once.")
group_values = [] group_values = []
...@@ -116,10 +107,10 @@ class Pruner(): ...@@ -116,10 +107,10 @@ class Pruner():
values = np.array(var.get_tensor()) values = np.array(var.get_tensor())
group_values.append((name, values, axis, pruned_idx)) group_values.append((name, values, axis, pruned_idx))
scores = self.criterion( scores = self.criterion(group_values,
group_values, graph) # [(name, axis, score, pruned_idx)] graph) # [(name, axis, score, pruned_idx)]
g = self._transform(self.idx_selector(scores, ratio))
pruned_params.extend(self.idx_selector(scores, ratio)) pruned_params.extend(g)
merge_pruned_params = {} merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params: for param, pruned_axis, pruned_idx in pruned_params:
...@@ -128,7 +119,6 @@ class Pruner(): ...@@ -128,7 +119,6 @@ class Pruner():
if pruned_axis not in merge_pruned_params[param]: if pruned_axis not in merge_pruned_params[param]:
merge_pruned_params[param][pruned_axis] = [] merge_pruned_params[param][pruned_axis] = []
merge_pruned_params[param][pruned_axis].append(pruned_idx) merge_pruned_params[param][pruned_axis].append(pruned_idx)
for param_name in merge_pruned_params: for param_name in merge_pruned_params:
for pruned_axis in merge_pruned_params[param_name]: for pruned_axis in merge_pruned_params[param_name]:
pruned_idx = np.concatenate(merge_pruned_params[param_name][ pruned_idx = np.concatenate(merge_pruned_params[param_name][
...@@ -138,12 +128,26 @@ class Pruner(): ...@@ -138,12 +128,26 @@ class Pruner():
_logger.debug("{}\t{}\t{}\t{}".format( _logger.debug("{}\t{}\t{}\t{}".format(
param.name(), pruned_axis, param.name(), pruned_axis,
param.shape()[pruned_axis], len(pruned_idx))) param.shape()[pruned_axis], len(pruned_idx)))
if param_shape_backup is not None:
origin_shape = copy.deepcopy(param.shape()) origin_shape = copy.deepcopy(param.shape())
if param_shape_backup is not None:
param_shape_backup[param.name()] = origin_shape param_shape_backup[param.name()] = origin_shape
new_shape = list(param.shape()) new_shape = list(param.shape())
new_shape[pruned_axis] -= len(pruned_idx) new_shape[pruned_axis] -= len(pruned_idx)
param.set_shape(new_shape) param.set_shape(new_shape)
# update groups of depthwise conv2d
for op in param.outputs():
if op.type() in ["conv2d", "depthwise_conv2d"
] and op.attr("groups") > 1:
assert origin_shape[
1] == 1, "Only support for depthwise when groups > 1."
new_groups = int(
op.attr("groups") * new_shape[pruned_axis] /
origin_shape[pruned_axis])
_logger.debug(
f"change groups of conv({param.name()}) from {op.attr('groups')} to {new_groups}; origin_shape: {origin_shape}; new_shape: {new_shape}"
)
op.set_attr("groups", new_groups)
if not only_graph: if not only_graph:
param_t = scope.find_var(param.name()).get_tensor() param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and ( if param_backup is not None and (
...@@ -156,16 +160,35 @@ class Pruner(): ...@@ -156,16 +160,35 @@ class Pruner():
pruned_idx, pruned_idx,
pruned_axis=pruned_axis, pruned_axis=pruned_axis,
lazy=lazy) lazy=lazy)
param_t.set(pruned_param, place)
except IndexError as e: except IndexError as e:
_logger.error("Pruning {}, but get [{}]".format( _logger.error("Pruning {}, but get [{}]".format(
param.name(), e)) param.name(), e))
param_t.set(pruned_param, place)
graph.update_groups_of_conv()
graph.infer_shape() graph.infer_shape()
self.pruned_weights = (not only_graph) self.pruned_weights = (not only_graph)
return graph.program, param_backup, param_shape_backup return graph.program, param_backup, param_shape_backup
def _transform(self, group):
ret = []
for name, axis, pruned_idx, transforms in group:
src = pruned_idx
for trans in transforms:
src_start = trans['src_start']
src_end = trans['src_end']
target_start = trans['target_start']
target_end = trans['target_end']
target = []
for idx in src:
if idx >= src_start and idx < src_end:
idx -= src_start
idx += target_start
if idx < target_end:
target.append(idx)
src = target
ret.append((name, axis, src))
return ret
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
""" """
Pruning a array by indexes on given axis. Pruning a array by indexes on given axis.
......
import sys import sys
sys.path.append("../../") sys.path.append("../../")
import unittest import unittest
import numpy as np
import paddle
from paddleslim.analysis import dygraph_flops as flops from paddleslim.analysis import dygraph_flops as flops
from paddle.vision.models import mobilenet_v1, resnet50 from paddle.vision.models import mobilenet_v1, resnet50
from paddle.nn import Conv2D, Layer
class TestFlops(unittest.TestCase): class TestFlops(unittest.TestCase):
...@@ -17,9 +20,64 @@ class TestFlops(unittest.TestCase): ...@@ -17,9 +20,64 @@ class TestFlops(unittest.TestCase):
self.assertTrue(FLOPs == self._gt) self.assertTrue(FLOPs == self._gt)
class Net1(Layer):
def __init__(self):
super(Net1, self).__init__()
self.conv1 = Conv2D(3, 2, 3)
self.conv2 = Conv2D(3, 2, 3)
def forward(self, inputs):
assert isinstance(inputs, dict)
x = inputs["x"]
y = inputs["y"]
return {"x": self.conv1(x), "y": self.conv2(y), "dummy": "dummy"}
class Net2(Net1):
def __init__(self):
super(Net2, self).__init__()
def forward(self, x, y):
return [self.conv1(x), self.conv2(y), "dummy"]
class TestFLOPsCase1(unittest.TestCase):
def runTest(self):
x_shape = (1, 3, 32, 32)
y_shape = (1, 3, 16, 16)
net = Net1()
x = np.random.uniform(-1, 1, x_shape).astype('float32')
y = np.random.uniform(-1, 1, y_shape).astype('float32')
inputs = {
"x": paddle.to_tensor(x),
"y": paddle.to_tensor(y),
"z": "test"
}
FLOPs = flops(net, [inputs])
self.assertTrue(FLOPs == 59184)
class TestFLOPsCase2(unittest.TestCase):
def runTest(self):
x_shape = (1, 3, 32, 32)
y_shape = (1, 3, 16, 16)
net = Net2()
x = np.random.uniform(-1, 1, x_shape).astype('float32')
y = np.random.uniform(-1, 1, y_shape).astype('float32')
inputs = [paddle.to_tensor(x), paddle.to_tensor(y)]
FLOPs1 = flops(net, inputs)
shapes = [x_shape, y_shape]
FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"])
self.assertTrue(FLOPs1 == FLOPs2)
def add_cases(suite): def add_cases(suite):
suite.addTest(TestFlops(net=mobilenet_v1, gt=11792896.0)) suite.addTest(TestFlops(net=mobilenet_v1, gt=11792896.0))
suite.addTest(TestFlops(net=resnet50, gt=83872768.0)) suite.addTest(TestFlops(net=resnet50, gt=83872768.0))
suite.addTest(TestFLOPsCase1())
suite.addTest(TestFLOPsCase2())
def load_tests(loader, standard_tests, pattern): def load_tests(loader, standard_tests, pattern):
......
...@@ -47,6 +47,7 @@ class TestPrune(unittest.TestCase): ...@@ -47,6 +47,7 @@ class TestPrune(unittest.TestCase):
shapes = {} shapes = {}
for param in model.parameters(): for param in model.parameters():
shapes[param.name] = param.shape shapes[param.name] = param.shape
pruner.restore()
return shapes return shapes
def static_prune(self, net, ratios): def static_prune(self, net, ratios):
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
...@@ -23,7 +24,8 @@ def conv_bn_layer(input, ...@@ -23,7 +24,8 @@ def conv_bn_layer(input,
groups=1, groups=1,
act=None, act=None,
bias=False, bias=False,
use_cudnn=True): use_cudnn=True,
sync_bn=False):
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
...@@ -37,6 +39,14 @@ def conv_bn_layer(input, ...@@ -37,6 +39,14 @@ def conv_bn_layer(input,
name=name + "_out", name=name + "_out",
use_cudnn=use_cudnn) use_cudnn=use_cudnn)
bn_name = name + "_bn" bn_name = name + "_bn"
if sync_bn:
bn = paddle.nn.SyncBatchNorm(
num_filters,
weight_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'),
name=bn_name)
return bn(conv)
else:
return fluid.layers.batch_norm( return fluid.layers.batch_norm(
input=conv, input=conv,
act=act, act=act,
......
import sys
sys.path.append("../")
import unittest
import numpy as np
from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase):
def testAdd(self):
plan = PruningPlan()
mask = PruningMask([0], [0, 0, 1], 0.33)
plan.add("a", mask)
mask = PruningMask([0], [0, 1, 0], 0.33)
plan.add("a", mask)
a_mask = plan.masks["a"]
self.assertTrue(len(a_mask) == 1)
self.assertTrue(a_mask[0].mask == [0, 1, 1])
self.assertTrue(a_mask[0].dims == [0])
if __name__ == '__main__':
unittest.main()
...@@ -41,14 +41,30 @@ class TestPrune(StaticCase): ...@@ -41,14 +41,30 @@ class TestPrune(StaticCase):
sum2 = conv4 + sum1 sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6") conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
groups = collect_convs( collected_groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program) ["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
while [] in groups: while [] in collected_groups:
groups.remove([]) collected_groups.remove([])
print(groups) print(collected_groups)
self.assertTrue(len(groups) == 2)
self.assertTrue(len(groups[0]) == 20) params = set([
self.assertTrue(len(groups[1]) == 6) param.name for param in main_program.all_parameters()
if "weights" in param.name
])
expected_groups = [[('conv1_weights', 0), ('conv2_weights', 1),
('conv2_weights', 0), ('conv3_weights', 1),
('conv4_weights', 0), ('conv5_weights', 1)],
[('conv3_weights', 0), ('conv4_weights', 1)]]
self.assertTrue(len(collected_groups) == len(expected_groups))
for _collected, _expected in zip(collected_groups, expected_groups):
for _name, _axis, _ in _collected:
if _name in params:
self.assertTrue((_name, _axis) in _expected)
for _name, _axis in _expected:
if _name in params:
self.assertTrue((_name, _axis, []) in _collected)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
from static_case import StaticCase
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from static_case import StaticCase
from layers import conv_bn_layer
class TestPrune(StaticCase):
def test_concat(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X
# conv1 conv2-->concat conv3-->sum-->out
# | ^ | ^
# |____________| |____________________|
#
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(input, 8, 3, "conv2", sync_bn=True)
tmp = fluid.layers.concat([conv1, conv2], axis=1)
conv3 = conv_bn_layer(input, 16, 3, "conv3", bias=None)
out = conv3 + tmp
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
# test backward search of concat
pruned_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv3_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=True,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv3_weights": (8, 3, 3, 3),
"conv2_weights": (4, 3, 3, 3),
"conv1_weights": (4, 3, 3, 3)
}
for param in pruned_program.global_block().all_parameters():
if "weights" in param.name and "conv2d" in param.name:
self.assertTrue(shapes[param.name] == param.shape)
# test forward search of concat
pruned_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv1_weights", "conv2_weights"],
ratios=[0.5, 0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4, 3, 3, 3),
"conv1_bn_scale": (4, ),
"conv1_bn_variance": (4, ),
"conv1_bn_mean": (4, ),
"conv1_bn_offset": (4, ),
"conv2_weights": (4, 3, 3, 3),
"sync_batch_norm_0.w_0": (4, ),
"sync_batch_norm_0.w_1": (4, ),
"conv2_bn_scale": (4, ),
"conv2_bn_offset": (4, ),
"conv3_weights": (8, 3, 3, 3),
"conv3_bn_mean": (8, ),
"conv3_bn_offset": (8, ),
"conv3_bn_scale": (8, ),
"conv3_bn_variance": (8, ),
"conv3_out.b_0": (8, ),
}
for param in pruned_program.global_block().all_parameters():
if "weights" in param.name and "conv2d" in param.name:
self.assertTrue(shapes[param.name] == param.shape)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册