Unverified commit b849d609, authored by whs, committed via GitHub

Fix pruning group conv2d (#720)

Parent: d3aeda6f
@@ -197,7 +197,10 @@ class OpWrapper(object):
            bool|int|str|float|list: The attribute value. The return value
            can be any valid attribute type.
        """
-        return self._op.attr(name)
+        if self._op.has_attr(name):
+            return self._op.attr(name)
+        else:
+            return None
class GraphWrapper(object):
@@ -365,35 +368,6 @@ class GraphWrapper(object):
        Update the groups of convolution layer according to current filters.
        It is used after loading pruned parameters from file.
        """
-        head_op = []
-        visited = []
        for op in self.ops():
            if op.type() != 'conditional_block':
-                if len(self.pre_ops(op)) == 0:
-                    head_op.append(op)
-        candidate_op = self.ops()
-
-        def recursive_infer(op, infer=False):
-            if op in candidate_op:
-                if op.type() != 'conditional_block':
-                    if infer:
-                        op._op.desc.infer_shape(op._op.block.desc)
-                    else:
-                        visited.append(op)
-                candidate_op.remove(op)
-                for next_op in self.next_ops(op):
-                    recursive_infer(next_op)
-
-        # Find ops which are not in the DAG; some ops, such as optimizer ops,
-        # should be inferred before normal computation ops.
-        for op in head_op:
-            recursive_infer(op, infer=False)
-        # Infer ops which are not in the DAG first.
-        candidate_op = self.ops()
-        for op in candidate_op:
-            if op not in visited and op.type() != 'conditional_block':
                op._op.desc.infer_shape(op._op.block.desc)
-        # Infer the remaining ops in topological order.
-        for op in head_op:
-            recursive_infer(op, infer=True)
@@ -9,14 +9,14 @@ from .var_group import *
from .pruning_plan import *
from .pruner import Pruner
from paddleslim.analysis import dygraph_flops as flops
-from .var_group import VarGroup
+from .var_group import DygraphPruningCollections
__all__ = ['Status', 'FilterPruner']
_logger = get_logger(__name__, logging.INFO)
CONV_OP_TYPE = paddle.nn.Conv2D
-FILTER_DIM = [0]
+FILTER_DIM = 0
CONV_WEIGHT_NAME = "weight"
SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose)
@@ -59,16 +59,17 @@ class FilterPruner(Pruner):
    def __init__(self, model, inputs, sen_file=None):
        super(FilterPruner, self).__init__(model, inputs)
        self._status = Status(sen_file)
-        # sensitive and var_group are just used in filter pruning
-        self.var_group = VarGroup(model, inputs)
+        # sensitive and collections are just used in filter pruning
+        self.collections = DygraphPruningCollections(model, inputs)
        # skip vars in:
        # 1. depthwise conv2d layer
        self.skip_vars = []
        for sub_layer in model.sublayers():
-            if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
-                    sub_layer, paddle.nn.layer.conv.Conv2D) and
-                    sub_layer._groups > 1):
+            #if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
+            #        sub_layer, paddle.nn.layer.conv.Conv2D) and
+            #        sub_layer._groups > 1):
+            if isinstance(sub_layer, SKIP_LAYERS):
                for param in sub_layer.parameters():
                    self.skip_vars.append(param.name)
@@ -170,11 +171,11 @@ class FilterPruner(Pruner):
                break
        return ratios

-    def _round_to(self, ratios, dims=[0], factor=8):
+    def _round_to(self, ratios, dims=0, factor=8):
        ret = {}
        for name in ratios:
            ratio = ratios[name]
-            dim = self._var_shapes[name][dims[0]]
+            dim = self._var_shapes[name][dims]
            remained = round((1 - ratio) * dim / factor) * factor
            if remained == 0:
                remained = factor
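For reference, a minimal sketch of the arithmetic `_round_to` performs, with made-up layer sizes that are not taken from this diff:

.. code-block:: python

    def round_ratio(ratio, dim, factor=8):
        # Number of channels kept after rounding to a multiple of `factor`.
        remained = round((1 - ratio) * dim / factor) * factor
        if remained == 0:
            remained = factor
        return float(dim - remained) / dim

    # A conv layer with 60 output filters and a requested ratio of 0.3:
    # 42 kept filters are rounded to 40, so the effective ratio becomes 1/3.
    print(round_ratio(0.3, 60))  # 0.333...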
@@ -186,14 +187,14 @@ class FilterPruner(Pruner):
    def get_ratios_by_sensitivity(self,
                                  pruned_flops,
                                  align=None,
-                                 dims=[0],
+                                 dims=0,
                                  skip_vars=[]):
        """
        Get a group of ratios by sensitivities.
        Args:
            pruned_flops(float): The expected ratio of FLOPs to be pruned. It should be in range (0, 1).
            align(int, optional): Round the size of each pruned dimension to a multiple of 'align' if 'align' is not None. Default: None.
-            dims(list, optional): The dims to be pruned on. [0] means pruning channels of output for convolution. Default: [0].
+            dims(int, optional): The dim to be pruned on. 0 means pruning output channels of convolution. Default: 0.
            skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None.
        Returns:
@@ -201,7 +202,7 @@ class FilterPruner(Pruner):
        """
        base_flops = flops(self.model, self.inputs)
-        _logger.debug("Base FLOPs: {}".format(base_flops))
+        _logger.info("Base FLOPs: {}".format(base_flops))
        low = 0.
        up = 1.0
        history = set()
@@ -214,7 +215,6 @@ class FilterPruner(Pruner):
                ratios = self._round_to(ratios, dims=dims, factor=align)
            plan = self.prune_vars(ratios, axis=dims)
            c_flops = flops(self.model, self.inputs)
-            _logger.debug("FLOPs after pruning: {}".format(c_flops))
            c_pruned_flops = (base_flops - c_flops) / base_flops
            plan.restore(self.model)
            _logger.debug("Searching ratios, pruned FLOPs: {}".format(
@@ -240,10 +240,9 @@ class FilterPruner(Pruner):
        sensitivities = self._status.sensitivies
        baseline = None
        ratios = np.arange(0.1, 1, step=0.1)
-        for group in self.var_group.groups:
-            var_name = group[0][0]
-            dims = group[0][1]
+        for _collection in self.collections:
+            var_name = _collection.master_name
+            dims = _collection.master_axis
            if target_vars is not None and var_name not in target_vars:
                continue
            if skip_vars is not None and var_name in skip_vars:
@@ -282,7 +281,6 @@ class FilterPruner(Pruner):
        self.restore()
        ratios, pruned_flops = self.get_ratios_by_sensitivity(
            pruned_flops, align=align, dims=FILTER_DIM, skip_vars=skip_vars)
-        _logger.debug("ratios: {}".format(ratios))
        self.plan = self.prune_vars(ratios, FILTER_DIM)
        self.plan._pruned_flops = pruned_flops
        return self.plan
@@ -291,73 +289,60 @@ class FilterPruner(Pruner):
        if self.plan is not None:
            self.plan.restore(self.model)

-    def cal_mask(self, var_name, pruned_ratio, group):
-        """
-        {
-         var_name: {
-                     'layer': sub_layer,
-                     'var': variable,
-                     'value': np.array([]),
-                     'pruned_dims': [1],
-                   }
-        }
-        """
+    def cal_mask(self, pruned_ratio, collection):
        raise NotImplemented("cal_mask is not implemented")

-    def prune_var(self, var_name, pruned_dims, pruned_ratio, apply="impretive"):
+    def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"):
        """
        Pruning a variable.
        Parameters:
            var_name(str): The name of variable.
-            pruned_dims(list<int>): The axes to be pruned. For convolution with format [out_c, in_c, k, k],
-                                    'axis=[0]' means pruning filters and 'axis=[0, 1]' means pruning kernels.
+            pruned_axis(int): The axis to be pruned. For convolution with format [out_c, in_c, k, k],
+                              'axis=0' means pruning filters.
            pruned_ratio(float): The ratio of pruned values in one variable.
            apply(str): How to apply pruning plan to graph. It can be 'impretive', 'lazy' or None. None
                        means just returning an instance of 'PruningPlan' but not applying it to graph.
        Returns:
            plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'.
        """
+        pruned_axis = pruned_axis[0] if isinstance(pruned_axis,
+                                                   list) else pruned_axis
+        assert (isinstance(pruned_axis, int))
        if var_name in self.skip_vars:
            _logger.warn(
-                f"{var_name} is skiped beacause it is not support for pruning derectly."
+                f"{var_name} is skipped because it is not supported for pruning directly."
            )
            return
-        if isinstance(pruned_dims, int):
-            pruned_dims = [pruned_dims]
-        group = self.var_group.find_group(var_name, pruned_dims)
-        _logger.debug("found group with {}: {}".format(var_name, group))
+        collection = self.collections.find_collection_by_master(var_name,
+                                                                pruned_axis)
        plan = PruningPlan(self.model.full_name)
-        group_dict = {}
-        for sub_layer in self.model.sublayers():
-            for param in sub_layer.parameters(include_sublayers=False):
-                if param.name in group:
-                    group_dict[param.name] = group[param.name]
-                    # Variables can be pruned on multiple axes.
-                    for _item in group_dict[param.name]:
-                        _item.update({
-                            'layer': sub_layer,
-                            'var': param,
-                            'value': np.array(param.value().get_tensor())
-                        })
-                    _logger.debug(f"set value of {param.name} into group")
-        mask = self.cal_mask(var_name, pruned_ratio, group_dict)
-        for _name in group_dict:
-            # Variables can be pruned on multiple axes.
-            for _item in group_dict[_name]:
-                src_mask = copy.deepcopy(mask)
-                dims = _item['pruned_dims']
-                transforms = _item['transforms']
-                var_shape = _item['var'].shape
-                if isinstance(dims, int):
-                    dims = [dims]
-                for trans in transforms:
-                    src_mask = self._transform_mask(src_mask, trans)
-                current_mask = src_mask
-                assert len(current_mask) == var_shape[dims[
-                    0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
-                plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
+        if collection is None:
+            _logger.debug(
+                f"Can not find collection with master ['name': {var_name}, 'axis': {pruned_axis}]"
+            )
+            return plan
+        _logger.info(
+            f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
+        )
+        mask = self.cal_mask(pruned_ratio, collection)
+        for _detail in collection.all_pruning_details():
+            # Variables can be pruned on multiple axes.
+            src_mask = copy.deepcopy(mask)
+            var_shape = _detail.var.shape()
+            for tran in _detail.transform:
+                src_mask = self._transform_mask(src_mask, tran)
+            current_mask = src_mask
+            groups = _detail.op.attr('groups')
+            if groups is None or groups == 1:
+                assert len(current_mask) == var_shape[
+                    _detail.
+                    axis], f"The length of current_mask must be equal to the size of the dimension to be pruned on. But got: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_detail.name}; len(mask): {len(mask)}"
+            plan.add(_detail.name,
+                     PruningMask(_detail.axis, current_mask, pruned_ratio,
+                                 _detail.op))
        if apply == "lazy":
            plan.apply(self.model, lazy=True)
        elif apply == "impretive":
@@ -371,17 +356,8 @@ class FilterPruner(Pruner):
            target_start = transform['target_start']
            target_end = transform['target_end']
            target_len = transform['target_len']
-            stride = transform['stride']
            mask = mask[src_start:src_end]
-            mask = mask.repeat(stride) if stride > 1 else mask
            dst_mask = np.ones([target_len])
-            # for depthwise conv2d with:
-            #     input shape: (1, 4, 32, 32)
-            #     filter shape: (32, 1, 3, 3)
-            #     groups: 4
-            # if we prune input channels by 50% (from 4 to 2), the output channel should be 50% * 4 * 8.
            expand = int((target_end - target_start) / len(mask))
            dst_mask[target_start:target_end] = list(mask) * expand
        elif "stride" in transform:
......
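A small standalone sketch of the tiling step above; the mask and target sizes are illustrative only:

.. code-block:: python

    import numpy as np

    # A master mask over 4 input channels, of which 2 are kept.
    mask = np.array([1, 0, 1, 0])
    # Map it onto a related tensor whose pruned dimension has 8 entries
    # (e.g. a tensor whose channel count is a multiple of the master's):
    # the mask is tiled until it fills the target slice.
    target_start, target_end, target_len = 0, 8, 8
    dst_mask = np.ones([target_len])
    expand = int((target_end - target_start) / len(mask))
    dst_mask[target_start:target_end] = list(mask) * expand
    print(dst_mask)  # [1. 0. 1. 0. 1. 0. 1. 0.]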
@@ -15,24 +15,38 @@ class FPGMFilterPruner(FilterPruner):
    def __init__(self, model, inputs, sen_file=None):
        super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file)

-    def cal_mask(self, var_name, pruned_ratio, group):
-        for _item in group[var_name]:
-            if _item['pruned_dims'] == [0]:
-                value = _item['value']
-                pruned_dims = _item['pruned_dims']
+    def cal_mask(self, pruned_ratio, collection):
+        var_name = collection.master_name
+        pruned_axis = collection.master_axis
+        value = collection.values[var_name]
+        groups = 1
+        for _detail in collection.all_pruning_details():
+            assert (isinstance(_detail.axis, int))
+            if _detail.axis == 1:
+                _groups = _detail.op.attr('groups')
+                if _groups is not None and _groups > 1:
+                    groups = _groups
+                    break

        dist_sum_list = []
        for out_i in range(value.shape[0]):
            dist_sum = self.get_distance_sum(value, out_i)
            dist_sum_list.append(dist_sum)
        scores = np.array(dist_sum_list)
+        if groups > 1:
+            scores = scores.reshape([groups, -1])
+            scores = np.mean(scores, axis=1)
        sorted_idx = scores.argsort()
        pruned_num = int(round(len(sorted_idx) * pruned_ratio))
        pruned_idx = sorted_idx[:pruned_num]
-        mask_shape = [value.shape[i] for i in pruned_dims]
+        mask_shape = [value.shape[pruned_axis]]
        mask = np.ones(mask_shape, dtype="int32")
+        if groups > 1:
+            mask = mask.reshape([groups, -1])
        mask[pruned_idx] = 0
-        return mask
+        return mask.reshape(mask_shape)

    def get_distance_sum(self, value, out_idx):
        w = value.view()
......
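A numeric sketch of the group-wise scoring that `cal_mask` now performs when the convolution has `groups > 1`; the scores and ratio are made up:

.. code-block:: python

    import numpy as np

    # Per-filter scores for 8 filters of a conv with groups=4.
    scores = np.array([0.9, 0.8, 0.1, 0.2, 0.7, 0.6, 0.3, 0.4])
    groups = 4
    group_scores = scores.reshape([groups, -1]).mean(axis=1)  # one score per group
    pruned_num = int(round(len(group_scores) * 0.5))          # prune 50% of the groups
    pruned_idx = group_scores.argsort()[:pruned_num]

    mask = np.ones([8], dtype="int32").reshape([groups, -1])
    mask[pruned_idx] = 0            # drop whole channel groups, not single filters
    print(mask.reshape([8]))        # [1 1 0 0 1 1 0 0]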
@@ -16,19 +16,32 @@ class L1NormFilterPruner(FilterPruner):
        super(L1NormFilterPruner, self).__init__(
            model, inputs, sen_file=sen_file)

-    def cal_mask(self, var_name, pruned_ratio, group):
-        for _item in group[var_name]:
-            if _item['pruned_dims'] == [0]:
-                value = _item['value']
-                pruned_dims = _item['pruned_dims']
-        reduce_dims = [
-            i for i in range(len(value.shape)) if i not in pruned_dims
-        ]
+    def cal_mask(self, pruned_ratio, collection):
+        var_name = collection.master_name
+        pruned_axis = collection.master_axis
+        value = collection.values[var_name]
+        groups = 1
+        for _detail in collection.all_pruning_details():
+            assert (isinstance(_detail.axis, int))
+            if _detail.axis == 1:
+                _groups = _detail.op.attr('groups')
+                if _groups is not None and _groups > 1:
+                    groups = _groups
+                    break
+
+        reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
        l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims))
+        if groups > 1:
+            l1norm = l1norm.reshape([groups, -1])
+            l1norm = np.mean(l1norm, axis=1)
        sorted_idx = l1norm.argsort()
        pruned_num = int(round(len(sorted_idx) * pruned_ratio))
        pruned_idx = sorted_idx[:pruned_num]
-        mask_shape = [value.shape[i] for i in pruned_dims]
+        mask_shape = [value.shape[pruned_axis]]
        mask = np.ones(mask_shape, dtype="int32")
+        if groups > 1:
+            mask = mask.reshape([groups, -1])
        mask[pruned_idx] = 0
-        return mask
+        return mask.reshape(mask_shape)
@@ -16,22 +16,32 @@ class L2NormFilterPruner(FilterPruner):
        super(L2NormFilterPruner, self).__init__(
            model, inputs, sen_file=sen_file)

-    def cal_mask(self, var_name, pruned_ratio, group):
-        # find information of pruning on output channels
-        for _item in group[var_name]:
-            if _item['pruned_dims'] == [0]:
-                value = _item['value']
-                pruned_dims = _item['pruned_dims']
-        reduce_dims = [
-            i for i in range(len(value.shape)) if i not in pruned_dims
-        ]
-        # scores = np.mean(np.abs(value), axis=tuple(reduce_dims))
+    def cal_mask(self, pruned_ratio, collection):
+        var_name = collection.master_name
+        pruned_axis = collection.master_axis
+        value = collection.values[var_name]
+        groups = 1
+        for _detail in collection.all_pruning_details():
+            assert (isinstance(_detail.axis, int))
+            if _detail.axis == 1:
+                _groups = _detail.op.attr('groups')
+                if _groups is not None and _groups > 1:
+                    groups = _groups
+                    break
+
+        reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
        scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
+        if groups > 1:
+            scores = scores.reshape([groups, -1])
+            scores = np.mean(scores, axis=1)
        sorted_idx = scores.argsort()
        pruned_num = int(round(len(sorted_idx) * pruned_ratio))
        pruned_idx = sorted_idx[:pruned_num]
-        mask_shape = [value.shape[i] for i in pruned_dims]
+        mask_shape = [value.shape[pruned_axis]]
        mask = np.ones(mask_shape, dtype="int32")
+        if groups > 1:
+            mask = mask.reshape([groups, -1])
        mask[pruned_idx] = 0
-        return mask
+        return mask.reshape(mask_shape)
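For reference, a minimal sketch of the per-filter norm computation shared by the L1 and L2 pruners above, with an assumed weight shape:

.. code-block:: python

    import numpy as np

    value = np.random.rand(16, 8, 3, 3)    # conv weight: [out_c, in_c, kh, kw]
    pruned_axis = 0                        # rank output filters
    reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]

    l1_scores = np.mean(np.abs(value), axis=tuple(reduce_dims))             # shape: (16,)
    l2_scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))  # shape: (16,)
    assert l1_scores.shape == l2_scores.shape == (16,)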
@@ -39,11 +39,12 @@ class Pruner(object):
        Args:
            ratios(dict<str, float>): The key is the name of variable to be pruned and the
                                      value is the pruned ratio.
-            axis(list): The dimensions to be pruned on.
+            axis(int): The dimension to be pruned on.
        Returns:
            plan(PruningPlan): The pruning plan.
        """
+        axis = axis[0] if isinstance(axis, list) else axis
        global_plan = PruningPlan(self.model.full_name)
        for var, ratio in ratios.items():
            if not global_plan.contains(var, axis):
......
@@ -10,27 +10,17 @@ __all__ = ['PruningPlan', 'PruningMask']
class PruningMask():
-    def __init__(self, dims, mask, ratio):
+    def __init__(self, dims, mask, ratio, op):
+        assert (isinstance(dims, int))
        self._dims = dims
        self._mask = mask
        self._pruned_ratio = ratio
+        self._op = op

    @property
    def dims(self):
        return self._dims

-    @dims.setter
-    def dims(self, value):
-        if not isinstance(value, collections.Iterator):
-            raise ValueError(
-                "The dims of PruningMask must be instance of collections.Iterator."
-            )
-        if self._mask is not None:
-            assert len(self._mask.shape) == len(
-                value
-            ), "The length of value must be same with length of mask's shape in current PruningMask instance."
-        self._dims = list(value)

    @property
    def mask(self):
        return self._mask
@@ -128,8 +118,7 @@ class PruningPlan():
                    _logger.debug("Backup values of {} into buffers.".
                                  format(param.name))
                    expand_mask_shape = [1] * len(value.shape)
-                    for i in dims:
-                        expand_mask_shape[i] = value.shape[i]
+                    expand_mask_shape[dims] = value.shape[dims]
                    _logger.debug("Expanded mask shape: {}".format(
                        expand_mask_shape))
                    expand_mask = mask.reshape(expand_mask_shape).astype(
@@ -158,13 +147,25 @@ class PruningPlan():
                if param.name in self._masks:
                    for _mask in self._masks[param.name]:
                        dims = _mask.dims
+                        assert (isinstance(dims, int))
                        mask = _mask.mask
-                        assert len(
-                            dims
-                        ) == 1, "Imperative mode only support for pruning on one dimension, but get dims {} when pruning parameter {}".format(
-                            dims, param.name)
+                        bool_mask = np.array(mask).astype(bool)
                        t_value = param.value().get_tensor()
                        value = np.array(t_value).astype("float32")
+                        groups = _mask._op.attr('groups')
+                        if dims == 1 and groups is not None and groups > 1 and len(
+                                value.shape) == 4:
+                            filter_size = value.shape[1]
+                            except_num = np.sum(bool_mask)
+                            assert (except_num % filter_size == 0)
+                            new_groups = int(except_num / filter_size)
+                            sub_layer._origin_groups = sub_layer._groups
+                            sub_layer._groups = new_groups
+                            _logger.info("change groups from {} to {} for {}.".
+                                         format(groups, new_groups, param.name))
+                            continue

                        # The name of buffer can not contain "."
                        backup_name = param.name.replace(".", "_") + "_backup"
                        if backup_name not in sub_layer._buffers:
@@ -172,9 +173,8 @@ class PruningPlan():
                                paddle.to_tensor(value))
                            _logger.debug("Backup values of {} into buffers.".
                                          format(param.name))
-                        bool_mask = np.array(mask).astype(bool)
                        pruned_value = np.apply_along_axis(
-                            lambda data: data[bool_mask], dims[0], value)
+                            lambda data: data[bool_mask], dims, value)
                        p = t_value._place()
                        if p.is_cpu_place():
                            place = paddle.CPUPlace()
@@ -186,18 +186,6 @@ class PruningPlan():
                            place = paddle.CUDAPlace(p.gpu_device_id())
                        t_value.set(pruned_value, place)
-                        if isinstance(
-                                sub_layer, paddle.nn.layer.conv.Conv2D
-                        ) and sub_layer._groups > 1 and len(param.shape) == 4:
-                            assert param.shape[
-                                1] == 1, "It just supports depthwise conv2d when groups > 1."
-                            new_groups = int(bool_mask.sum() *
-                                             sub_layer._groups / len(bool_mask))
-                            _logger.debug(
-                                "Update groups of depthwise conv2d from {} to {}".
-                                format(sub_layer._groups, new_groups))
-                            sub_layer._origin_groups = sub_layer._groups
-                            sub_layer._groups = new_groups
                        # for training
                        if param.trainable:
......
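A numeric sketch of how the `new_groups` value above is derived for a grouped convolution pruned on axis 1; the shapes are illustrative:

.. code-block:: python

    import numpy as np

    # Depthwise-style conv: weight shape [32, 1, 3, 3], groups == 32.
    groups = 32
    weight_shape = [32, 1, 3, 3]
    # Mask on axis 1 produced from the master tensor: 16 of 32 channels survive.
    bool_mask = np.array([True, False] * 16)

    filter_size = weight_shape[1]          # input channels seen by each group
    kept = int(np.sum(bool_mask))          # surviving input channels overall
    assert kept % filter_size == 0
    new_groups = kept // filter_size
    print(new_groups)                      # 16: groups shrink together with the input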
@@ -3,15 +3,15 @@ import logging
import paddle
from paddle.fluid.dygraph import TracedLayer
from paddleslim.core import GraphWrapper, dygraph2program
-from paddleslim.prune import collect_convs
+from paddleslim.prune import PruningCollections
from paddleslim.common import get_logger

-__all__ = ["VarGroup"]
+__all__ = ["DygraphPruningCollections"]

_logger = get_logger(__name__, level=logging.INFO)

-class VarGroup():
+class DygraphPruningCollections(PruningCollections):
    """
    A tool used to parse dygraph and store information of variables' relationship.
    Args:
@@ -20,40 +20,29 @@ class VarGroup():
    """

    def __init__(self, model, inputs):
-        self.groups = []
-        self._parse_model(model, inputs)
-
-    def _to_dict(self, group):
-        ret = {}
-        for _name, _axis, _transforms in group:
-            if isinstance(_axis, int):
-                _axis = [_axis]
-            if _name not in ret:
-                ret[_name] = []
-            # Variables can be pruned on multiple axes.
-            ret[_name].append({'pruned_dims': _axis, 'transforms': _transforms})
-        return ret
-
-    def find_group(self, var_name, axis):
-        for group in self.groups:
-            for _name, _axis, _stride in group:
-                if isinstance(_axis, int):
-                    _axis = [_axis]
-                if _name == var_name and _axis == axis:
-                    return self._to_dict(group)
-
-    def _parse_model(self, model, inputs):
        _logger.debug("Parsing model with input: {}".format(inputs))
        # model can be in training mode, because some models contain auxiliary parameters for training.
        program = dygraph2program(model, inputs=inputs)
        graph = GraphWrapper(program)
-        visited = {}
-        for name, param in model.named_parameters():
-            group = collect_convs([param.name], graph,
-                                  visited)[0]  # [(name, axis, pruned_idx)]
-            if len(group) > 0:
-                self.groups.append(group)
-        _logger.info("Found {} groups.".format(len(self.groups)))
+        params = [
+            _param.name for _param in model.parameters()
+            if len(_param.shape) == 4
+        ]
+        self._collections = self.create_pruning_collections(params, graph)
+        _logger.info("Found {} collections.".format(len(self._collections)))
+
+        _name2values = {}
+        for param in model.parameters():
+            _name2values[param.name] = np.array(param.value().get_tensor())
+        for collection in self._collections:
+            collection.values = _name2values
+
+    def find_collection_by_master(self, var_name, axis):
+        for _collection in self._collections:
+            if _collection.master['name'] == var_name and _collection.master[
+                    'axis'] == axis:
+                return _collection

    def __str__(self):
-        return "\n".join([str(group) for group in self.groups])
+        return "\n".join(
+            [str(_collection) for _collection in self._collections])
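A minimal usage sketch of the collection parsing above; the model, input spec, and import path are assumptions for illustration only:

.. code-block:: python

    import paddle
    from paddle.vision.models import mobilenet_v1
    # Import path assumed for illustration; adjust to where the class lives in your install.
    from paddleslim.dygraph.prune.var_group import DygraphPruningCollections

    model = mobilenet_v1(pretrained=False)
    # Parse the dygraph model and group every 4-D parameter with its related tensors.
    collections = DygraphPruningCollections(model, inputs=[1, 3, 224, 224])
    for collection in collections:
        print(collection.master_name, collection.master_axis, collection.variables())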
@@ -19,17 +19,16 @@ from .auto_pruner import *
from ..prune import auto_pruner
from .sensitive import *
from ..prune import sensitive
-from .prune_walker import *
-from ..prune import prune_walker
+from .prune_worker import *
+from ..prune import prune_worker
from .prune_io import *
from ..prune import prune_io
-from .group_param import *
-from ..prune import group_param
from .criterion import *
from ..prune import criterion
+from .collections import *
+from ..prune import collections
from .unstructured_pruner import *
from ..prune import unstructured_pruner
from .idx_selector import *
from ..prune import idx_selector
__all__ = []
@@ -37,9 +36,9 @@ __all__ = []
__all__ += pruner.__all__
__all__ += auto_pruner.__all__
__all__ += sensitive.__all__
-__all__ += prune_walker.__all__
+__all__ += prune_worker.__all__
__all__ += prune_io.__all__
-__all__ += group_param.__all__
__all__ += criterion.__all__
__all__ += unstructured_pruner.__all__
__all__ += idx_selector.__all__
+__all__ += collections.__all__
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from ..core import GraphWrapper, VarWrapper
from ..common import get_logger
from .prune_worker import PRUNE_WORKER, UnsupportOpError
__all__ = [
'PruningDetails', 'PruningCollection', 'PruningCollections',
'StaticPruningCollections'
]
_logger = get_logger(__name__, level=logging.INFO)
class PruningDetails(object):
"""
The description of one pruning operation.
Args:
var(VarWrapper): The variable to be pruned.
axis(int): The axis to be pruned on.
transform(dict): Information used to convert pruned indices of master
tensor to indices of current tensor.
op(OpWrapper): The operator with current tensor as input.
is_parameter(bool): whether the tensor is parameter. Default: True.
"""
    def __init__(self, var, axis, transform, op, is_parameter=True):
        assert isinstance(
            var, VarWrapper
        ), "var should be VarWrapper, but got type = {}".format(type(var))
assert (isinstance(axis, int))
self.name = var.name()
self.var = var
self.axis = axis
self.transform = transform
self.op = op
self.is_parameter = is_parameter
def __eq__(self, other):
if self.name != other.name:
return False
if self.axis != other.axis:
return False
if self.transform != other.transform:
return False
return True
class PruningCollection(object):
"""
A group of pruning operations.
    conv1-->conv2-->batch_norm
    For the network defined above, if the weight of conv1 is pruned on axis 0,
    the weight of 'conv2' should be pruned on axis 1. The pruning operations on axis 0 of
    'conv1' and those on axis 1 of 'conv2' form a collection, and {'name': conv1.weight_name, 'axis': 0}
    is the master of the collection.
Args:
master(dict): The master pruning operation.
"""
def __init__(self, master=None):
self._master = master
self.master_name = master['name']
self.master_axis = master['axis']
self._nodes = {}
def variables(self):
"""
Get all tensors to be pruned in current collection.
Returns:
list<str>: Names of tensor to be pruned.
"""
return list(self._nodes.keys())
def add(self, node):
"""
        Add a pruning operation into the current collection.
Args:
node(PruningDetails): Pruning operation to be added into current collection.
"""
assert (isinstance(node, PruningDetails))
        # The first added pruning operation will be the master.
        self._master = {
            "name": node.name,
            "axis": node.axis
        } if self._master is None else self._master
if node.name not in self._nodes:
self._nodes[node.name] = []
if node not in self._nodes[node.name]:
self._nodes[node.name].append(node)
@property
def master(self):
return self._master
def all_pruning_details(self):
"""
Get all pruning operations in current collection.
Returns:
list<PruningDetails>: Pruning operations.
"""
ret = []
for _items in self._nodes.values():
ret.extend(_items)
return ret
class PruningCollections(object):
def __init__(self):
self._collections = None
def __iter__(self):
return iter(self._collections)
def create_pruning_collections(self, params, graph, skip_stranger=True):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip current tensor when visit unregistered operators that not in OPS_UNCHANGE_SHAPE. False means visit all unregistered operators by default worker. Default: True.
Returns:
list<Group>: The groups.
"""
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
visited = {}
collections = []
unsupported_warnings = set()
for _param in params:
pruned_params = []
param = graph.var(_param)
if param is None:
_logger.warning(
f"Couldn't find relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correct mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
continue
target_op = param.outputs()[0]
if target_op.type() == 'conditional_block':
for op in param.outputs():
if op.type() in PRUNE_WORKER._module_dict.keys():
cls = PRUNE_WORKER.get(op.type())
worker = cls(op,
pruned_params=pruned_params,
visited=visited,
skip_stranger=skip_stranger)
break
else:
cls = PRUNE_WORKER.get(target_op.type())
if cls is None:
_logger.warning("No worker for operator: {}".format(
target_op.type()))
continue
worker = cls(target_op,
pruned_params=pruned_params,
visited=visited,
skip_stranger=skip_stranger)
try:
visited_backup = copy.deepcopy(worker.visited)
worker.prune(param, pruned_axis=0, pruned_idx=[])
except UnsupportOpError as e:
visited.clear()
visited.update(visited_backup)
unsupported_warnings.add(e.args)
else:
if len(pruned_params) != 0:
collection = PruningCollection(master=({
"name": param.name(),
"axis": 0
}))
for _param, _axis, _transform, _op in pruned_params:
collection.add(
PruningDetails(_param, _axis, _transform, _op))
collections.append(collection)
for warn in unsupported_warnings:
_logger.warning(warn)
self._collections = collections
return self._collections
class StaticPruningCollections(PruningCollections):
def __init__(self, params, graph, skip_stranger=True):
super(StaticPruningCollections, self).__init__()
self._collections = self.create_pruning_collections(
params, graph, skip_stranger=skip_stranger)
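A minimal usage sketch of `StaticPruningCollections` on a tiny static program; the explicit parameter name `conv1_weights` is set here purely for illustration:

.. code-block:: python

    import paddle

    paddle.enable_static()
    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog):
        image = paddle.static.data(name="image", shape=[None, 3, 32, 32])
        conv1 = paddle.static.nn.conv2d(
            image, num_filters=8, filter_size=3,
            param_attr=paddle.ParamAttr(name="conv1_weights"))
        conv2 = paddle.static.nn.conv2d(
            conv1, num_filters=16, filter_size=3,
            param_attr=paddle.ParamAttr(name="conv2_weights"))

    # Collect every tensor that must be pruned together with conv1's weight.
    collections = StaticPruningCollections(["conv1_weights"], main_prog)
    for collection in collections:
        print(collection.master, collection.variables())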
@@ -27,7 +27,7 @@ CRITERION = Registry('criterion')

@CRITERION.register
-def l1_norm(group, graph):
+def l1_norm(group, values, graph):
    """Compute l1-norm scores of parameter on given axis.

    This function returns a list of parameters' l1-norm scores on given axis.
@@ -35,28 +35,44 @@ def l1_norm(group, graph):
    and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of the structure on `axis`.

    Args:
-       group(list): A group of parameters. The first parameter of the group is convolution layer's weight
-                    while the others are parameters affected by pruning the first one. Each parameter in group
-                    is represented as tuple '(name, values, axis)' in which `name` is the parameter's name
-                    and `values` is the values of parameter and `axis` is the axis reducing on pruning on.
+       group(Group): A group of pruning operations.
+       values(dict): The key is the name of tensor in group, and the value of dict is the
+                     values of tensor.
+       graph(GraphWrapper): The graph stores structure information of network.

    Returns:
-       list: A list of tuples storing l1-norm on given axis.
+       dict: The key is name of tensor, the value is a dict
+             with axis as key and l1-norm scores as value.
    """
-    scores = []
-    for name, value, axis, pruned_idx in group:
+    scores = {}
+    for pruning_details in group.all_pruning_details():
+        name = pruning_details.name
+        if name not in values:
+            _logger.warning(
+                "The value of tensor '{}' is not found, so it will not be used when evaluating importance of pruned structures.".
+                format(name))
+            continue
+        value = values[name]
+        axis = pruning_details.axis
        reduce_dims = [i for i in range(len(value.shape)) if i != axis]
        score = np.sum(np.abs(value), axis=tuple(reduce_dims))
-        scores.append((name, axis, score, pruned_idx))
+        if name not in scores:
+            scores[name] = {}
+        scores[name][axis] = score
    return scores
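To make the returned structure concrete, a hypothetical `l1_norm` result for a collection with two related tensors might look like this (names and numbers are made up):

.. code-block:: python

    import numpy as np

    # Hypothetical output: scores keyed by tensor name, then by the axis scored on.
    scores = {
        "conv1.w_0": {0: np.array([0.12, 0.34, 0.08, 0.21])},  # one score per output filter
        "conv2.w_0": {1: np.array([0.30, 0.25, 0.11, 0.40])},  # one score per input channel
    }
    print(scores["conv1.w_0"][0].argsort())  # filters ranked from least to most important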
@CRITERION.register
-def geometry_median(group, graph):
-    scores = []
-    name, value, axis, _ = group[0]
-    assert (len(value.shape) == 4)
+def geometry_median(group, values, graph):
+    name = group.master["name"]
+    axis = group.master["axis"]
+    if name not in values:
+        _logger.warning("The value of tensor '{}' is not found.".format(name))
+        return None
+    value = values[name]
+    assert len(
+        value.shape) == 4, "geometry_median only supports the weight of conv2d."

    def get_distance_sum(value, out_idx):
        w = value.view()
@@ -73,31 +89,26 @@ def geometry_median(group, graph):
    tmp = np.array(dist_sum_list)

-    for name, value, axis, idx in group:
-        scores.append((name, axis, tmp, idx))
+    scores = {}
+    for pruning_details in group.all_pruning_details():
+        name = pruning_details.name
+        axis = pruning_details.axis
+        if name not in scores:
+            scores[name] = {}
+        scores[name][axis] = tmp
    return scores
@CRITERION.register
-def bn_scale(group, graph):
-    """Compute l1-norm scores of parameter on given axis.
-
-    This function returns a list of parameters' l1-norm scores on given axis.
-    Each element of the list is a tuple with format (name, axis, score) in which 'name' is parameter's name
-    and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of the structure on `axis`.
-
-    Args:
-       group(list): A group of parameters. The first parameter of the group is convolution layer's weight
-                    while the others are parameters affected by pruning the first one. Each parameter in group
-                    is represented as tuple '(name, values, axis)' in which `name` is the parameter's name
-                    and `values` is the values of parameter and `axis` is the axis reducing on pruning on.
-
-    Returns:
-       list: A list of tuples storing l1-norm on given axis.
+def bn_scale(group, values, graph):
+    """Compute scores by scales of batch_norm layer.
    """
    assert (isinstance(graph, GraphWrapper))

    # step1: Get first convolution
-    conv_weight, value, axis, _ = group[0]
+    conv_weight = group.master["name"]
+    axis = group.master["axis"]
+    value = values[conv_weight]
    param_var = graph.var(conv_weight)
    conv_op = param_var.outputs()[0]
@@ -111,12 +122,16 @@ def bn_scale(group, graph):
    # step3: Find scale of bn
    score = None
-    for name, value, axis, _ in group:
-        if bn_scale_param == name:
-            score = np.abs(value.reshape([-1]))
+    if bn_scale_param not in values:
+        raise SystemExit("Can't find values of scales in BatchNorm.")
+    value = values[bn_scale_param]
+    score = np.abs(value.reshape([-1]))

-    scores = []
-    for name, value, axis, idx in group:
-        scores.append((name, axis, score, idx))
+    scores = {}
+    for pruning_details in group.all_pruning_details():
+        name = pruning_details.name
+        axis = pruning_details.axis
+        if name not in scores:
+            scores[name] = {}
+        scores[name][axis] = score
    return scores
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from ..core import GraphWrapper
from ..common import get_logger
from .prune_walker import PRUNE_WORKER
__all__ = ["collect_convs"]
_logger = get_logger(__name__, level=logging.INFO)
def collect_convs(params, graph, visited={}):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
Returns:
list<list<tuple>>: The groups.
"""
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
groups = []
for _param in params:
pruned_params = []
param = graph.var(_param)
if param is None:
_logger.warning(
f"Cann't found relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correctly mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
groups.append([])
continue
target_op = param.outputs()[0]
if target_op.type() == 'conditional_block':
for op in param.outputs():
if op.type() in PRUNE_WORKER._module_dict.keys():
cls = PRUNE_WORKER.get(op.type())
walker = cls(op,
pruned_params=pruned_params,
visited=visited)
break
else:
cls = PRUNE_WORKER.get(target_op.type())
if cls is None:
_logger.info("No walker for operator: {}".format(target_op.type(
)))
groups.append(pruned_params)
continue
walker = cls(target_op,
pruned_params=pruned_params,
visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[])
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
repeat_group = False
simple_group = []
for param, axis, pruned_idx in group:
param = param.name()
if axis == 0:
if param in visited:
repeat_group = True
else:
visited.add(param)
simple_group.append((param, axis, pruned_idx))
if not repeat_group:
uniq_groups.append(simple_group)
return uniq_groups
@@ -26,75 +26,80 @@ IDX_SELECTOR = Registry('idx_selector')

@IDX_SELECTOR.register
-def default_idx_selector(group, ratio):
+def default_idx_selector(group, scores, ratios):
    """Get the pruned indices by scores of the master tensor.

    This function returns a list of parameters' pruned indices on a given axis.
    Each element of the list is a tuple with format (name, axis, indices)
    in which 'name' is the parameter's name, 'axis' is the axis pruning on and
    `indices` is the indices to be pruned.

    Args:
-       group(list): A group of parameters. The first parameter of the group is convolution layer's weight
-                    while the others are parameters affected by pruning the first one. Each parameter in group
-                    is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and
-                    `axis` is the axis pruning on and `score` is a np.array storing the importance of structure
-                    on `axis`. Shown as below:
-
-       .. code-block: text
-
-           [("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
-
-       The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so
-       `[0.7, 0.5, 0.6]` are the importance scores of each output channel in "conv1_weights"
-       while axis is 0.
+       group(Group): A group of pruning operations.
+       scores(dict): The key is name of tensor, the value is a dict with axis as key and scores as value.
+       ratios(dict): The pruned ratio of each tensor. The key is name of tensor and the value is the pruned ratio.

    Returns:
-       list: pruned indexes
+       list: pruned indices with format (name, axis, pruned_indices).
    """
-    name, axis, score, _ = group[
-        0]  # sort channels by the first convolution's score
+    # sort channels by the master convolution's score
+    name = group.master["name"]
+    axis = group.master["axis"]
+    score = scores[name][axis]

+    # get the maximum 'groups' attribute among the convolutions in the group
+    max_groups = 1
+    for pruning_details in group.all_pruning_details():
+        groups = pruning_details.op.attr("groups")
+        if groups is not None and groups > max_groups:
+            max_groups = groups
+    if max_groups > 1:
+        score = score.reshape([max_groups, -1])
+        group_size = score.shape[1]
+        # get score for each group of channels
+        score = np.mean(score, axis=1)
    sorted_idx = score.argsort()
+    ratio = ratios[name]
    pruned_num = int(round(len(sorted_idx) * ratio))
    pruned_idx = sorted_idx[:pruned_num]

-    idxs = []
-    for name, axis, score, transforms in group:
-        idxs.append((name, axis, pruned_idx, transforms))
-    return idxs
+    # convert indices of channel groups to indices of channels.
+    if max_groups > 1:
+        correct_idx = []
+        for idx in pruned_idx:
+            for offset in range(group_size):
+                correct_idx.append(idx * group_size + offset)
+        pruned_idx = correct_idx[:]
+
+    ret = []
+    for _pruning_details in group.all_pruning_details():
+        ret.append((_pruning_details.name, _pruning_details.axis, pruned_idx,
+                    _pruning_details.transform))
+    return ret
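A numeric sketch of the group-index expansion above: whole channel groups are ranked, then the chosen group indices are expanded back to channel indices (values are made up):

.. code-block:: python

    import numpy as np

    score = np.array([0.9, 0.8, 0.1, 0.2, 0.7, 0.6, 0.3, 0.4])  # 8 channels, groups=4
    max_groups, ratio = 4, 0.5
    group_score = score.reshape([max_groups, -1])
    group_size = group_score.shape[1]
    group_score = group_score.mean(axis=1)

    pruned_num = int(round(len(group_score) * ratio))
    pruned_groups = group_score.argsort()[:pruned_num]          # e.g. [1, 3]
    pruned_idx = [g * group_size + offset
                  for g in pruned_groups for offset in range(group_size)]
    print(pruned_idx)                                            # [2, 3, 6, 7]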
@IDX_SELECTOR.register
-def optimal_threshold(group, ratio):
+def optimal_threshold(group, scores, ratios):
    """Get the pruned indices by scores of the master tensor.

    This function returns a list of parameters' pruned indices on a given axis.
    Each element of the list is a tuple with format (name, axis, indices)
    in which 'name' is the parameter's name, 'axis' is the axis pruning on and
    `indices` is the indices to be pruned.

    Args:
-       group(list): A group of parameters. The first parameter of the group is convolution layer's weight
-                    while the others are parameters affected by pruning the first one. Each parameter in group
-                    is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and
-                    `axis` is the axis pruning on and `score` is a np.array storing the importance of structure
-                    on `axis`.
+       group(Group): A group of pruning operations.
+       scores(dict): The key is name of tensor, the value is a dict with axis as key and scores as value.
+       ratios(dict): The pruned ratio of each tensor. The key is name of tensor and the value is the pruned ratio.

    Returns:
-       list: pruned indexes
+       list: pruned indices with format (name, axis, pruned_indices).
    """
-    name, axis, score, _ = group[
-        0]  # sort channels by the first convolution's score
+    # sort channels by the master tensor's score
+    name = group.master["name"]
+    axis = group.master["axis"]
+    score = scores[name][axis]
+    ratio = ratios[name]

    score[score < 1e-18] = 1e-18
    score_sorted = np.sort(score)
@@ -110,6 +115,7 @@ def optimal_threshold(group, ratio):
    pruned_idx = np.squeeze(np.argwhere(score < th))

    idxs = []
-    for name, axis, score, transforms in group:
-        idxs.append((name, axis, pruned_idx, transforms))
+    for _pruning_details in group.all_pruning_details():
+        idxs.append((_pruning_details.name, _pruning_details.axis, pruned_idx,
+                     _pruning_details.transform))
    return idxs
@@ -12,35 +12,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import os
import logging
import numpy as np
from ..core import Registry
from ..common import get_logger

-__all__ = ["PRUNE_WORKER", "conv2d"]
+__all__ = ["PRUNE_WORKER", "conv2d", "UnsupportOpError"]

_logger = get_logger(__name__, level=logging.INFO)

PRUNE_WORKER = Registry('prune_worker')

-SKIP_OPS = ["conditional_block"]
+SKIPPED_OPS = ['shape', 'reduce_mean']
+
+# Operators in OPS_UNCHANGE_SHAPE will be visited by the default worker,
+# which keeps the shape of the output the same as the shape of the input.
+OPS_UNCHANGE_SHAPE = os.getenv('OPS_UNCHANGE_SHAPE', None)
+OPS_UNCHANGE_SHAPE = [] if OPS_UNCHANGE_SHAPE is None else OPS_UNCHANGE_SHAPE.strip(
+).split(",")
+OPS_UNCHANGE_SHAPE += [
+    'nearest_interp_v2',
+    'roi_align',
+    'sigmoid',
+    'swish',
+    'pad3d',
+    'bilinear_interp_v2',
+    'dropout',
+    'cast',
+    'hard_swish',
+    'hard_sigmoid',
+]
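The list can be extended without code changes through the `OPS_UNCHANGE_SHAPE` environment variable, which is read when the module is imported; a minimal sketch of the parsing (operator names are placeholders):

.. code-block:: python

    import os

    # Must be set before paddleslim is imported, since the list is built at import time.
    os.environ['OPS_UNCHANGE_SHAPE'] = "my_custom_act,my_custom_resize"  # placeholder op names

    ops = os.getenv('OPS_UNCHANGE_SHAPE', None)
    ops = [] if ops is None else ops.strip().split(",")
    print(ops)  # ['my_custom_act', 'my_custom_resize']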
+class UnsupportOpError(Exception):
+    pass

class PruneWorker(object):
-    def __init__(self, op, pruned_params=[], visited={}):
+    def __init__(self, op, pruned_params, visited, skip_stranger=True):
        """
        A wrapper of operator used to infer the information of all the related variables.

        Args:
            op(Operator): The operator to be pruned.
            pruned_params(list): The list to store the information of pruning that is inferred by the worker.
            visited(dict): The auxiliary dict to record the visited operators and variables. The key is an encoded string of operator id and variable name.
+            skip_stranger(bool): Whether to raise an exception when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True.

        Return: An instance of PruneWorker.
        """
        self.op = op
        self.pruned_params = pruned_params
        self.visited = visited
+        self.skip_stranger = skip_stranger
+        self.ops_unsupported = os.getenv('OPS_UNSUPPORTED', None)
+        self.ops_unsupported = [] if self.ops_unsupported is None else self.ops_unsupported.strip(
+        ).split(",")
    def prune(self, var, pruned_axis, pruned_idx):
        """
@@ -49,7 +77,7 @@ class PruneWorker(object):
        Args:
            var(Variable): The root variable of searching. It can be the input or output of current operator.
            pruned_axis(int): The axis to be pruned of root variable.
-            pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable.
+            pruned_idx(int): The indices to be pruned in `pruned_axis` of root variable.
        """
        if self._visit(var, pruned_axis):
            self._prune(var, pruned_axis, pruned_idx)
@@ -82,29 +110,36 @@ class PruneWorker(object):
            return
        if visited is not None:
            self.visited = visited
+        if op.type() in self.ops_unsupported:
+            raise UnsupportOpError("Unsupported operator named {}".format(
+                op.type()))
        cls = PRUNE_WORKER.get(op.type())
        if cls is None:
-            if op.type() in SKIP_OPS:
-                _logger.warn("Skip operator [{}]".format(op.type()))
+            if op.type() in SKIPPED_OPS:
                return
+            if op.type() in OPS_UNCHANGE_SHAPE or not self.skip_stranger:
+                cls = PRUNE_WORKER.get("default_worker")
+            else:
+                raise UnsupportOpError("Unsupported operator named {}".format(
+                    op.type()))

-            # _logger.warn(
-            #     "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
-            #     format(op.type()))
-            cls = PRUNE_WORKER.get("default_walker")
-        _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
-            self.op, op, pruned_axis, var.name()))
+        _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}".
+                      format(self.op, op, pruned_axis, var.name(), pruned_idx))
        _logger.debug(
            f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
        )
-        walker = cls(op, pruned_params=self.pruned_params, visited=self.visited)
-        walker.prune(var, pruned_axis, pruned_idx)
+        worker = cls(op, self.pruned_params, self.visited, self.skip_stranger)
+        worker.prune(var, pruned_axis, pruned_idx)
+
+    def append_pruned_vars(self, var, axis, transforms):
+        self.pruned_params.append((var, axis, transforms, self.op))
@PRUNE_WORKER.register @PRUNE_WORKER.register
class conv2d(PruneWorker): class conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(conv2d, self).__init__(op, pruned_params, visited) super(conv2d, self).__init__(op, pruned_params, visited, skip_stranger)
def _is_depthwise_conv(self, op): def _is_depthwise_conv(self, op):
data_format = self.op.attr("data_format") data_format = self.op.attr("data_format")
...@@ -121,15 +156,17 @@ class conv2d(PruneWorker): ...@@ -121,15 +156,17 @@ class conv2d(PruneWorker):
num_filters % num_channels == 0) num_filters % num_channels == 0)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if self._is_depthwise_conv(self.op): if self._is_depthwise_conv(self.op):
_logger.debug(f"Meet conv2d who is depthwise conv2d actually.") _logger.debug(f"Meet conv2d who is depthwise conv2d actually.")
walker = depthwise_conv2d( worker = depthwise_conv2d(
self.op, self.pruned_params, visited=self.visited) self.op,
walker._prune(var, pruned_axis, pruned_idx) self.pruned_params,
return visited=self.visited,
skip_stranger=self.skip_stranger)
return worker._prune(var, pruned_axis, pruned_idx)
data_format = self.op.attr("data_format") data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
channel_axis = 3 channel_axis = 3
...@@ -137,56 +174,49 @@ class conv2d(PruneWorker): ...@@ -137,56 +174,49 @@ class conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name()) pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1) self.append_pruned_vars(filter_var, 1, pruned_idx)
self.pruned_params.append((filter_var, 1, pruned_idx)) if groups is None or groups == 1:
for op in filter_var.outputs(): self._visit_and_search(filter_var, 1, pruned_idx)
self._prune_op(op, filter_var, 1, pruned_idx)
elif var in self.op.inputs("Filter"): elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1] assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx)) self.append_pruned_vars(var, pruned_axis, pruned_idx)
for op in var.outputs(): if groups is None or groups == 1 or pruned_axis == 0:
self._prune_op(op, var, pruned_axis, pruned_idx) self._visit_and_search(var, pruned_axis, pruned_idx)
if pruned_axis == 0: if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0: if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append( self.append_pruned_vars(
(self.op.inputs("Bias"), channel_axis, pruned_idx)) self.op.inputs("Bias"), channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0] output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis) self._visit_and_search(output_var, channel_axis, pruned_idx)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 1: elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0] input_var = self.op.inputs("Input")[0]
self._visit(input_var, channel_axis) self._visit_and_search(input_var, channel_axis, pruned_idx)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
elif var in self.op.outputs("Output"): elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format( assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name()) pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0) self._visit(filter_var, 0)
self.append_pruned_vars(filter_var, 0, pruned_idx)
self.pruned_params.append((filter_var, 0, pruned_idx))
for op in filter_var.outputs(): for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx) self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0: if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append( self.append_pruned_vars(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx)) self.op.inputs("Bias")[0], channel_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class conv2d_transpose(PruneWorker): class conv2d_transpose(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(conv2d_transpose, self).__init__(op, pruned_params, visited) super(conv2d_transpose, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format") data_format = self.op.attr("data_format")
...@@ -198,7 +228,7 @@ class conv2d_transpose(PruneWorker): ...@@ -198,7 +228,7 @@ class conv2d_transpose(PruneWorker):
pruned_axis, var.name()) pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0) self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx)) self.append_pruned_vars(filter_var, 0, pruned_idx)
for op in filter_var.outputs(): for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx) self._prune_op(op, filter_var, 0, pruned_idx)
...@@ -212,14 +242,14 @@ class conv2d_transpose(PruneWorker): ...@@ -212,14 +242,14 @@ class conv2d_transpose(PruneWorker):
filter_var = self.op.inputs("Filter")[0] filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1) self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx)) self.append_pruned_vars(filter_var, 1, pruned_idx)
for op in filter_var.outputs(): for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx) self._prune_op(op, filter_var, 1, pruned_idx)
if len(self.op.inputs("Bias")) > 0: if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append( self.append_pruned_vars(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx)) self.op.inputs("Bias")[0], channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0] output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs() next_ops = output_var.outputs()
...@@ -229,8 +259,9 @@ class conv2d_transpose(PruneWorker): ...@@ -229,8 +259,9 @@ class conv2d_transpose(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class batch_norm(PruneWorker): class batch_norm(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(batch_norm, self).__init__(op, pruned_params, visited) super(batch_norm, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Y")) and ( if (var not in self.op.outputs("Y")) and (
...@@ -248,7 +279,7 @@ class batch_norm(PruneWorker): ...@@ -248,7 +279,7 @@ class batch_norm(PruneWorker):
param_var = self.op.inputs(param)[0] param_var = self.op.inputs(param)[0]
for op in param_var.outputs(): for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx) self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx)) self.append_pruned_vars(param_var, 0, pruned_idx)
out_var = self.op.outputs("Y")[0] out_var = self.op.outputs("Y")[0]
self._visit(out_var, pruned_axis) self._visit(out_var, pruned_axis)
...@@ -259,13 +290,15 @@ class batch_norm(PruneWorker): ...@@ -259,13 +290,15 @@ class batch_norm(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class sync_batch_norm(batch_norm): class sync_batch_norm(batch_norm):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(sync_batch_norm, self).__init__(op, pruned_params, visited) super(sync_batch_norm, self).__init__(op, pruned_params, visited,
skip_stranger)
class elementwise_op(PruneWorker): class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_op, self).__init__(op, pruned_params, visited) super(elementwise_op, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
axis = self.op.attr("axis") axis = self.op.attr("axis")
...@@ -286,7 +319,7 @@ class elementwise_op(PruneWorker): ...@@ -286,7 +319,7 @@ class elementwise_op(PruneWorker):
# for bias # for bias
if name == "Y" and actual_axis >= 0 and not ( if name == "Y" and actual_axis >= 0 and not (
len(in_var.shape()) == 1 and in_var.shape()[0] == 1): len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
self.pruned_params.append((in_var, actual_axis, pruned_idx)) self.append_pruned_vars(in_var, actual_axis, pruned_idx)
self._visit_and_search(in_var, actual_axis, pruned_idx) self._visit_and_search(in_var, actual_axis, pruned_idx)
else: else:
...@@ -301,8 +334,7 @@ class elementwise_op(PruneWorker): ...@@ -301,8 +334,7 @@ class elementwise_op(PruneWorker):
if y_pruned_axis >= 0 and not (len(in_var.shape()) == 1 and if y_pruned_axis >= 0 and not (len(in_var.shape()) == 1 and
in_var.shape()[0] == 1): in_var.shape()[0] == 1):
self.pruned_params.append( self.append_pruned_vars(in_var, y_pruned_axis, pruned_idx)
(in_var, y_pruned_axis, pruned_idx))
self._visit_and_search(in_var, y_pruned_axis, pruned_idx) self._visit_and_search(in_var, y_pruned_axis, pruned_idx)
elif var in self.op.inputs("Y"): elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0] in_var = self.op.inputs("X")[0]
...@@ -318,26 +350,30 @@ class elementwise_op(PruneWorker): ...@@ -318,26 +350,30 @@ class elementwise_op(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class elementwise_add(elementwise_op): class elementwise_add(elementwise_op):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_add, self).__init__(op, pruned_params, visited) super(elementwise_add, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class elementwise_sub(elementwise_op): class elementwise_sub(elementwise_op):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_sub, self).__init__(op, pruned_params, visited) super(elementwise_sub, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class elementwise_mul(elementwise_op): class elementwise_mul(elementwise_op):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_mul, self).__init__(op, pruned_params, visited) super(elementwise_mul, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class activation(PruneWorker): class activation(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(activation, self).__init__(op, pruned_params, visited) super(activation, self).__init__(op, pruned_params, visited,
skip_stranger)
self.input_name = "X" self.input_name = "X"
self.output_name = "Out" self.output_name = "Out"
...@@ -351,9 +387,10 @@ class activation(PruneWorker): ...@@ -351,9 +387,10 @@ class activation(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class default_walker(PruneWorker): class default_worker(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(default_walker, self).__init__(op, pruned_params, visited) super(default_worker, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.all_outputs(): if var in self.op.all_outputs():
...@@ -367,59 +404,62 @@ class default_walker(PruneWorker): ...@@ -367,59 +404,62 @@ class default_walker(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class uniform_random_batch_size_like(activation): class uniform_random_batch_size_like(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(uniform_random_batch_size_like, self).__init__(op, pruned_params, super(uniform_random_batch_size_like, self).__init__(
visited) op, pruned_params, visited, skip_stranger)
self.input_name = "Input" self.input_name = "Input"
self.output_name = "Out" self.output_name = "Out"
@PRUNE_WORKER.register @PRUNE_WORKER.register
class bilinear_interp(activation): class bilinear_interp(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(bilinear_interp, self).__init__(op, pruned_params, visited) super(bilinear_interp, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class nearest_interp(activation): class nearest_interp(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(nearest_interp, self).__init__(op, pruned_params, visited) super(nearest_interp, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class relu(activation): class relu(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(relu, self).__init__(op, pruned_params, visited) super(relu, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class leaky_relu(activation): class leaky_relu(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(leaky_relu, self).__init__(op, pruned_params, visited) super(leaky_relu, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class floor(activation): class floor(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(floor, self).__init__(op, pruned_params, visited) super(floor, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class relu6(activation): class relu6(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(relu6, self).__init__(op, pruned_params, visited) super(relu6, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class pool2d(activation): class pool2d(activation):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(pool2d, self).__init__(op, pruned_params, visited) super(pool2d, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class sum(PruneWorker): class sum(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(sum, self).__init__(op, pruned_params, visited) super(sum, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"): if var in self.op.outputs("Out"):
...@@ -440,10 +480,46 @@ class sum(PruneWorker): ...@@ -440,10 +480,46 @@ class sum(PruneWorker):
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class split(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(split, self).__init__(op, pruned_params, visited, skip_stranger)
self.in_var = op.inputs("X")[0]
self.out_vars = op.outputs("Out")
self.axis = op.attr("axis")
self.num = op.attr("num")
def _prune(self, var, pruned_axis, transforms):
if var == self.in_var:
if pruned_axis != self.axis:
for out_var in self.out_vars:
self._visit_and_search(out_var, pruned_axis, transforms)
else:
raise UnsupportOpError(
"Unsupport pruning input of split operator directly.")
elif var in self.out_vars:
if pruned_axis != self.axis:
self._visit_and_search(self.in_var, pruned_axis, transforms)
else:
trans = {
"src_start": 0,
"src_end": var.shape()[pruned_axis],
"target_start": 0,
"target_end": self.in_var.shape()[pruned_axis],
"target_len": self.in_var.shape()[pruned_axis]
}
self._visit_and_search(self.in_var, pruned_axis,
transforms + [trans])
for out_var in self.out_vars:
if var != out_var:
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class concat(PruneWorker): class concat(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(concat, self).__init__(op, pruned_params, visited) super(concat, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, transforms): def _prune(self, var, pruned_axis, transforms):
axis = self.op.attr("axis") axis = self.op.attr("axis")
...@@ -513,52 +589,56 @@ class concat(PruneWorker): ...@@ -513,52 +589,56 @@ class concat(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class depthwise_conv2d(PruneWorker): class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited) super(depthwise_conv2d, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms): def _prune(self, var, pruned_axis, transforms):
assert var not in self.op.inputs(
"Filter"), "Unsupport for pruning depthwise conv2d directly." _filter = self.op.inputs("Filter")[0]
assert var not in self.op.outputs( _out = self.op.outputs("Output")[0]
"Output" _in_var = self.op.inputs("Input")[0]
), "Unsupport for pruning output of depthwise conv2d directly."
data_format = self.op.attr("data_format") data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
channel_axis = 3 channel_axis = 3
if var in self.op.inputs("Input"):
if var == _in_var:
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis) pruned_axis)
# pruning number of filters
groups = var.shape()[channel_axis] self.append_pruned_vars(_filter, 0, transforms)
filter_var = self.op.inputs("Filter")[0] # kernel_number * groups will be pruned by reducing groups
transform = { self.append_pruned_vars(_filter, 1, transforms)
'src_start': 0, self._visit_and_search(_filter, 0, transforms)
'src_end': var.shape()[pruned_axis], # It will not prune the number of kernels in depthwise conv2d,
'target_start': 0, # so it is not necessary to search succeeding operators.
'target_end': filter_var.shape()[0], # self._visit_and_search(_filter, 1, transforms)
'target_len': filter_var.shape()[0], self._visit(_filter, 1)
'stride': 1 self._visit_and_search(_out, channel_axis, transforms)
} elif var == _filter:
assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
self.pruned_params.append((filter_var, 0, transforms + [transform])) self.append_pruned_vars(_filter, 1, transforms)
self._visit(filter_var, 0) self._visit_and_search(_in_var, channel_axis, transforms)
self._visit_and_search(_out, channel_axis, transforms)
for op in filter_var.outputs(): elif var == _out:
self._prune_op(op, filter_var, 0, transforms + [transform]) assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis)
output_var = self.op.outputs("Output")[0] self.append_pruned_vars(_filter, 0, transforms)
next_ops = output_var.outputs() self.append_pruned_vars(_filter, 1, transforms)
for op in next_ops: self._visit_and_search(_filter, 0, transforms)
self._prune_op(op, output_var, channel_axis, # It will not prune the number of kernels in depthwise conv2d,
transforms + [transform]) # so it is not necessary to search succeeding operators.
# self._visit_and_search(_filter, 1, transforms)
self._visit(_filter, 1)
self._visit_and_search(_in_var, channel_axis, transforms)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class mul(PruneWorker): class mul(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(mul, self).__init__(op, pruned_params, visited) super(mul, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"): if var in self.op.inputs("X"):
...@@ -570,7 +650,7 @@ class mul(PruneWorker): ...@@ -570,7 +650,7 @@ class mul(PruneWorker):
for i in pruned_idx: for i in pruned_idx:
idx += list(range_idx + i * feature_map_size) idx += list(range_idx + i * feature_map_size)
param_var = self.op.inputs("Y")[0] param_var = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, idx)) self.append_pruned_vars(param_var, 0, idx)
for op in param_var.outputs(): for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx) self._prune_op(op, param_var, 0, pruned_idx)
...@@ -578,22 +658,36 @@ class mul(PruneWorker): ...@@ -578,22 +658,36 @@ class mul(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class matmul(PruneWorker): class matmul(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(matmul, self).__init__(op, pruned_params, visited) super(matmul, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X") and pruned_axis == 1: x = self.op.inputs("X")[0]
param_var = self.op.inputs("Y")[0] y = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, pruned_idx)) out = self.op.outputs("Out")[0]
if var == x and pruned_axis == 1:
self.append_pruned_vars(y, 0, pruned_idx)
self._visit_and_search(y, 0, pruned_idx)
if var == out:
if pruned_axis == 0:
self.append_pruned_vars(x, 0, pruned_idx)
self._visit_and_search(x, 0, pruned_idx)
elif pruned_axis == 1:
self.append_pruned_vars(y, 1, pruned_idx)
self._visit_and_search(y, 1, pruned_idx)
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx) @PRUNE_WORKER.register
class matmul_v2(matmul):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(matmul_v2, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class scale(PruneWorker): class scale(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(scale, self).__init__(op, pruned_params, visited) super(scale, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"): if var in self.op.inputs("X"):
...@@ -608,34 +702,34 @@ class scale(PruneWorker): ...@@ -608,34 +702,34 @@ class scale(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class momentum(PruneWorker): class momentum(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(momentum, self).__init__(op, pruned_params, visited) super(momentum, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"): if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
velocity_var = self.op.inputs("Velocity")[0] velocity_var = self.op.inputs("Velocity")[0]
self.pruned_params.append((velocity_var, pruned_axis, pruned_idx)) self.append_pruned_vars(velocity_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class adam(PruneWorker): class adam(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited, skip_stranger):
super(adam, self).__init__(op, pruned_params, visited) super(adam, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"): if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
moment1_var = self.op.inputs("Moment1")[0] moment1_var = self.op.inputs("Moment1")[0]
self.pruned_params.append((moment1_var, pruned_axis, pruned_idx)) self.append_pruned_vars(moment1_var, pruned_axis, pruned_idx)
moment2_var = self.op.inputs("Moment2")[0] moment2_var = self.op.inputs("Moment2")[0]
self.pruned_params.append((moment2_var, pruned_axis, pruned_idx)) self.append_pruned_vars(moment2_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class affine_channel(PruneWorker): class affine_channel(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(affine_channel, self).__init__(op, pruned_params, visited) super(affine_channel, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Out")) and ( if (var not in self.op.outputs("Out")) and (
...@@ -653,7 +747,7 @@ class affine_channel(PruneWorker): ...@@ -653,7 +747,7 @@ class affine_channel(PruneWorker):
param_var = self.op.inputs(param)[0] param_var = self.op.inputs(param)[0]
for op in param_var.outputs(): for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx) self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx)) self.append_pruned_vars(param_var, 0, pruned_idx)
out_var = self.op.outputs("Out")[0] out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis) self._visit(out_var, pruned_axis)
...@@ -664,11 +758,12 @@ class affine_channel(PruneWorker): ...@@ -664,11 +758,12 @@ class affine_channel(PruneWorker):
@PRUNE_WORKER.register @PRUNE_WORKER.register
class flatten_contiguous_range(PruneWorker): class flatten_contiguous_range(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited, skip_stranger):
super(flatten_contiguous_range, self).__init__(op, pruned_params, super(flatten_contiguous_range, self).__init__(op, pruned_params,
visited) visited, skip_stranger)
def _prune(self, var, pruned_axis, transforms): def _prune(self, var, pruned_axis, transforms):
start_axis = self.op.attr("start_axis") start_axis = self.op.attr("start_axis")
stop_axis = self.op.attr("stop_axis") stop_axis = self.op.attr("stop_axis")
if var in self.op.inputs("X"): if var in self.op.inputs("X"):
...@@ -690,3 +785,58 @@ class flatten_contiguous_range(PruneWorker): ...@@ -690,3 +785,58 @@ class flatten_contiguous_range(PruneWorker):
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, out_pruned_axis, self._prune_op(op, out_var, out_pruned_axis,
transforms + [transform]) transforms + [transform])
@PRUNE_WORKER.register
class squeeze2(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(squeeze2, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axes = self.op.attr("axes")
in_var = self.op.inputs("X")[0]
out_var = self.op.outputs("Out")[0]
if axes is None or len(axes) == 0:
axes = [i for i, axis in enumerate(in_var.shape()) if axis == 1]
squeeze_num = 0
if in_var == var:
for axis in axes:
assert axis != pruned_axis, "Cannot prune an axis that will be squeezed."
if axis < pruned_axis:
squeeze_num += 1
pruned_axis -= squeeze_num
self._visit_and_search(out_var, pruned_axis, transforms)
elif out_var == var:
for axis in axes:
if axis <= pruned_axis:
pruned_axis += 1
self._visit_and_search(in_var, pruned_axis, transforms)
@PRUNE_WORKER.register
class unsqueeze2(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(unsqueeze2, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axes = self.op.attr("axes")
in_var = self.op.inputs("X")[0]
out_var = self.op.outputs("Out")[0]
assert (axes is not None)
squeeze_num = 0
if in_var == var:
for axis in axes:
if axis <= pruned_axis:
pruned_axis += 1
self._visit_and_search(out_var, pruned_axis, transforms)
elif out_var == var:
for axis in axes:
if axis < pruned_axis:
squeeze_num += 1
pruned_axis -= squeeze_num
self._visit_and_search(in_var, pruned_axis, transforms)
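The axis bookkeeping in squeeze2/unsqueeze2 is a plain index shift; a small self-contained sketch (values chosen to match the unsqueeze test later in this diff, where pruning axis 2 of the unsqueezed output maps back to the channel axis of conv1):

def map_output_axis_to_input(pruned_axis, inserted_axes):
    # Mirror of the out_var branch of unsqueeze2: every axis inserted
    # below the pruned axis shifts the original axis index down by one.
    shift = sum(1 for a in inserted_axes if a < pruned_axis)
    return pruned_axis - shift

print(map_output_axis_to_input(2, inserted_axes=[0]))  # 1, i.e. the NCHW channel axis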
...@@ -18,7 +18,7 @@ import copy ...@@ -18,7 +18,7 @@ import copy
import numpy as np import numpy as np
from functools import reduce from functools import reduce
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from .group_param import collect_convs from .collections import StaticPruningCollections
from .criterion import CRITERION from .criterion import CRITERION
from .idx_selector import IDX_SELECTOR from .idx_selector import IDX_SELECTOR
from ..common import get_logger from ..common import get_logger
...@@ -79,38 +79,28 @@ class Pruner(): ...@@ -79,38 +79,28 @@ class Pruner():
Returns: Returns:
tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters. tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters.
""" """
self.pruned_list = [] self.pruned_list = []
graph = GraphWrapper(program.clone()) graph = GraphWrapper(program.clone())
param_backup = {} if param_backup else None param_backup = {} if param_backup else None
param_shape_backup = {} if param_shape_backup else None param_shape_backup = {} if param_shape_backup else None
pruned_params = [] pruned_params = []
visited = {} collections = StaticPruningCollections(params, graph)
for param, ratio in zip(params, ratios): ratios = dict(zip(params, ratios))
_logger.info("pruning: {}".format(param)) values = {}
if graph.var(param) is None: for _collection in collections:
_logger.warn( for _var_name in _collection.variables():
"Variable[{}] to be pruned is not in current graph.".format( var = scope.find_var(_var_name)
param))
continue
group = collect_convs([param], graph,
visited)[0] # [(name, axis, pruned_idx)]
if group is None or len(group) == 0:
continue
assert (
not self.pruned_weights), "The weights have been pruned once."
group_values = []
for name, axis, pruned_idx in group:
var = scope.find_var(name)
if var is not None: if var is not None:
values = np.array(var.get_tensor()) value = np.array(var.get_tensor())
group_values.append((name, values, axis, pruned_idx)) values[_var_name] = value
scores = self.criterion(group_values, for _collection in collections:
graph) # [(name, axis, score, pruned_idx)] scores = self.criterion(_collection, values, graph)
g = self._transform(self.idx_selector(scores, ratio)) idx = self.idx_selector(_collection, scores,
pruned_params.extend(g) ratios) # name, axis, idx, transform
idx = self._transform(idx)
pruned_params.extend(idx)
merge_pruned_params = {} merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params: for param, pruned_axis, pruned_idx in pruned_params:
...@@ -124,32 +114,35 @@ class Pruner(): ...@@ -124,32 +114,35 @@ class Pruner():
pruned_idx = np.concatenate(merge_pruned_params[param_name][ pruned_idx = np.concatenate(merge_pruned_params[param_name][
pruned_axis]) pruned_axis])
param = graph.var(param_name) param = graph.var(param_name)
_groups = 1
if not lazy: if not lazy:
_logger.debug("{}\t{}\t{}\t{}".format( # update groups of conv2d
param.name(), pruned_axis, if pruned_axis == 1:
param.shape()[pruned_axis], len(pruned_idx))) for op in param.outputs():
origin_shape = copy.deepcopy(param.shape()) if op.type() in ["conv2d", "depthwise_conv2d"
if param_shape_backup is not None: ] and op.attr("groups") > 1:
param_shape_backup[param.name()] = origin_shape _groups = op.attr("groups")
new_shape = list(param.shape()) _filter_num = param.shape()[1]
new_shape[pruned_axis] -= len(pruned_idx) new_groups = int(
param.set_shape(new_shape) (_groups * _filter_num - len(pruned_idx)) /
# update groups of depthwise conv2d _filter_num)
for op in param.outputs(): _logger.info(
if op.type() in ["conv2d", "depthwise_conv2d" f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};"
] and op.attr("groups") > 1: )
assert origin_shape[ op.set_attr("groups", new_groups)
1] == 1, "Only support for depthwise when groups > 1." if _groups == 1:
new_groups = int( origin_shape = copy.deepcopy(param.shape())
op.attr("groups") * new_shape[pruned_axis] / if param_shape_backup is not None:
origin_shape[pruned_axis]) param_shape_backup[param.name()] = origin_shape
_logger.debug( new_shape = list(param.shape())
f"change groups of conv({param.name()}) from {op.attr('groups')} to {new_groups}; origin_shape: {origin_shape}; new_shape: {new_shape}" new_shape[pruned_axis] -= len(pruned_idx)
) param.set_shape(new_shape)
op.set_attr("groups", new_groups)
if not only_graph and (_groups == 1 or pruned_axis != 1):
if not only_graph: _var = scope.find_var(param.name())
param_t = scope.find_var(param.name()).get_tensor() if _var is None:
continue
param_t = _var.get_tensor()
if param_backup is not None and ( if param_backup is not None and (
param.name() not in param_backup): param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy( param_backup[param.name()] = copy.deepcopy(
...@@ -162,40 +155,42 @@ class Pruner(): ...@@ -162,40 +155,42 @@ class Pruner():
lazy=lazy) lazy=lazy)
param_t.set(pruned_param, place) param_t.set(pruned_param, place)
except IndexError as e: except IndexError as e:
_logger.error("Pruning {}, but get [{}]".format( _logger.error(
param.name(), e)) "Pruning {} with shape {} on axis {}, but get [{}]; ".
format(param.name(),
param_t.shape(), pruned_axis, e))
graph.infer_shape() graph.infer_shape()
self.pruned_weights = (not only_graph) self.pruned_weights = (not only_graph)
return graph.program, param_backup, param_shape_backup return graph.program, param_backup, param_shape_backup
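To make the groups update in the loop above concrete, a worked example with assumed shapes (not taken from the patch): a conv2d with groups=4 and weight shape [32, 8, 3, 3] loses 16 of its 32 input channels, so axis 1 of the weight keeps its size and the groups attribute absorbs the reduction.

groups, filter_num, num_pruned = 4, 8, 16   # filter_num corresponds to param.shape()[1]
new_groups = int((groups * filter_num - num_pruned) / filter_num)
print(new_groups)  # 2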
def _transform(self, group): def _transform(self, items):
ret = [] ret = []
for name, axis, pruned_idx, transforms in group: for name, axis, pruned_idx, transforms in items:
src = pruned_idx src = pruned_idx
for trans in transforms: for trans in transforms:
src_start = trans['src_start'] src_start = trans['src_start']
src_end = trans['src_end'] src_end = trans['src_end']
src_len = src_end - src_start
target_start = trans['target_start'] target_start = trans['target_start']
target_end = trans['target_end'] target_end = trans['target_end']
starts = np.array(range(target_start, target_end, src_len))
target = [] target = []
for idx in src: for idx in src:
if idx >= src_start and idx < src_end: if idx >= src_start and idx < src_end:
idx -= src_start idx -= src_start
idx += target_start target.extend(list(idx + starts))
if idx < target_end:
target.append(idx)
src = target src = target
ret.append((name, axis, src)) ret.append((name, axis, src))
return ret return ret
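The remapping in _transform is easiest to see on a concrete split: below is a single-transform sketch with assumed sizes (an output segment of length 4 taken from an input of length 8, i.e. the trans dict the split worker above would emit), showing how indices pruned from the split output are replicated into every segment of the input.

import numpy as np

def apply_transform(pruned_idx, trans):
    # Single-transform mirror of _transform above (illustrative sketch).
    src_len = trans['src_end'] - trans['src_start']
    starts = np.array(range(trans['target_start'], trans['target_end'], src_len))
    target = []
    for idx in pruned_idx:
        if trans['src_start'] <= idx < trans['src_end']:
            target.extend((idx - trans['src_start'] + starts).tolist())
    return target

trans = {"src_start": 0, "src_end": 4,
         "target_start": 0, "target_end": 8, "target_len": 8}
print(apply_transform([1, 3], trans))  # [1, 5, 3, 7]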
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
""" """
Pruning an array by indexes on given axis. Pruning an array by indices on the given axis.
Args: Args:
tensor(numpy.array): The target array to be pruned. tensor(numpy.array): The target array to be pruned.
pruned_idx(list<int>): The indexes to be pruned. pruned_idx(list<int>): The indices to be pruned.
pruned_axis(int): The axis of given array to be pruned on. pruned_axis(int): The axis of given array to be pruned on.
lazy(bool): True means setting the pruned elements to zero. lazy(bool): True means setting the pruned elements to zero.
False means remove the pruned elements from memory. False means remove the pruned elements from memory.
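The body of _prune_tensor is elided from this diff; the two modes described in the docstring can be sketched in plain numpy as follows (a hedged illustration, not the actual implementation):

import numpy as np

def prune_tensor_sketch(tensor, pruned_idx, pruned_axis, lazy=False):
    # lazy=True: keep the shape and zero out the pruned slices;
    # lazy=False: physically drop them along pruned_axis.
    if lazy:
        out = tensor.copy()
        index = [slice(None)] * tensor.ndim
        index[pruned_axis] = pruned_idx
        out[tuple(index)] = 0
        return out
    return np.delete(tensor, pruned_idx, axis=pruned_axis)

w = np.arange(24, dtype="float32").reshape(4, 2, 3)
print(prune_tensor_sketch(w, [1, 3], 0).shape)             # (2, 2, 3)
print(prune_tensor_sketch(w, [1, 3], 0, lazy=True).shape)  # (4, 2, 3)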
......
...@@ -98,7 +98,7 @@ def sensitivity(program, ...@@ -98,7 +98,7 @@ def sensitivity(program,
params=[name], params=[name],
ratios=[ratio], ratios=[ratio],
place=place, place=place,
lazy=True, lazy=False,
only_graph=False, only_graph=False,
param_backup=True) param_backup=True)
if eval_args is None: if eval_args is None:
...@@ -108,7 +108,6 @@ def sensitivity(program, ...@@ -108,7 +108,6 @@ def sensitivity(program,
loss = (baseline - pruned_metric) / baseline loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format(name, ratio, _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
loss)) loss))
sensitivities[name][ratio] = loss sensitivities[name][ratio] = loss
_save_sensitivities(sensitivities, sensitivities_file) _save_sensitivities(sensitivities, sensitivities_file)
......
...@@ -99,13 +99,74 @@ class TestFilterPruner(unittest.TestCase): ...@@ -99,13 +99,74 @@ class TestFilterPruner(unittest.TestCase):
plan = pruner.sensitive_prune(0.01, align=4) plan = pruner.sensitive_prune(0.01, align=4)
for param in net.parameters(): for param in net.parameters():
if param.name in self._param_names: if param.name in self._param_names:
print(f"name: {param.name}; shape: {param.shape}")
self.assertTrue(param.shape[0] % 4 == 0) self.assertTrue(param.shape[0] % 4 == 0)
pruner.restore() pruner.restore()
class TestPruningGroupConv2d(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(TestPruningGroupConv2d, self).__init__(methodName)
def runTest(self):
with fluid.unique_name.guard():
net = paddle.vision.models.mobilenet_v1()
ratios = {}
for param in net.parameters():
if len(param.shape) == 4:
ratios[param.name] = 0.5
pruners = []
pruner = L1NormFilterPruner(net, [1, 3, 128, 128])
pruners.append(pruner)
pruner = FPGMFilterPruner(net, [1, 3, 128, 128])
pruners.append(pruner)
pruner = L2NormFilterPruner(net, [1, 3, 128, 128])
pruners.append(pruner)
shapes = {}
for pruner in pruners:
plan = pruner.prune_vars(ratios, 0)
for param in net.parameters():
if param.name not in shapes:
shapes[param.name] = param.shape
assert (shapes[param.name] == param.shape)
pruner.restore()
#class TestStrideTransform(unittest.TestCase):
# def __init__(self, methodName='runTest'):
# super(TestStrideTransform, self).__init__(methodName)
#
# def runTest(self):
# with fluid.unique_name.guard():
#
# net = paddle.vision.models.mobilenet_v1()
# ratios = {}
# for param in net.parameters():
# if len(param.shape) == 4:
# ratios[param.name] = 0.5
# pruners = []
# pruner = L1NormFilterPruner(net, [1, 3, 128, 128])
# pruners.append(pruner)
# pruner = FPGMFilterPruner(net, [1, 3, 128, 128])
# pruners.append(pruner)
# pruner = L2NormFilterPruner(net, [1, 3, 128, 128])
# pruners.append(pruner)
#
# shapes = {}
# for pruner in pruners:
# plan = pruner.prune_vars(ratios, 0)
# for param in net.parameters():
# if param.name not in shapes:
# shapes[param.name] = param.shape
# assert(shapes[param.name] == param.shape)
# pruner.restore()
def add_cases(suite): def add_cases(suite):
suite.addTest(TestStatus()) # suite.addTest(TestStatus())
suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"])) # suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"]))
suite.addTest(TestPruningGroupConv2d())
def load_tests(loader, standard_tests, pattern): def load_tests(loader, standard_tests, pattern):
......
...@@ -43,7 +43,7 @@ class TestPrune(unittest.TestCase): ...@@ -43,7 +43,7 @@ class TestPrune(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
model = net(pretrained=False) model = net(pretrained=False)
pruner = L1NormFilterPruner(model, [1, 3, 16, 16]) pruner = L1NormFilterPruner(model, [1, 3, 16, 16])
pruner.prune_vars(ratios, [0]) pruner.prune_vars(ratios, 0)
shapes = {} shapes = {}
for param in model.parameters(): for param in model.parameters():
shapes[param.name] = param.shape shapes[param.name] = param.shape
......
...@@ -25,7 +25,7 @@ class TestWalker(unittest.TestCase): ...@@ -25,7 +25,7 @@ class TestWalker(unittest.TestCase):
net = Net() net = Net()
x = np.random.uniform(-1, 1, x_shape).astype('float32') x = np.random.uniform(-1, 1, x_shape).astype('float32')
pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)]) pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0]) pruner.prune_vars({"conv2d_0.w_0": 0.2}, 0)
self.assertTrue(net.linear.weight.shape == [5400, 5]) self.assertTrue(net.linear.weight.shape == [5400, 5])
......
...@@ -8,14 +8,14 @@ from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask ...@@ -8,14 +8,14 @@ from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase): class TestPruningPlan(unittest.TestCase):
def testAdd(self): def testAdd(self):
plan = PruningPlan() plan = PruningPlan()
mask = PruningMask([0], [0, 1, 1], 0.33) mask = PruningMask(0, [0, 1, 1], 0.33, None)
plan.add("a", mask) plan.add("a", mask)
mask = PruningMask([0], [0, 1, 0], 0.33) mask = PruningMask(0, [0, 1, 0], 0.33, None)
plan.add("a", mask) plan.add("a", mask)
a_mask = plan.masks["a"] a_mask = plan.masks["a"]
self.assertTrue(len(a_mask) == 1) self.assertTrue(len(a_mask) == 1)
self.assertTrue(a_mask[0].mask == [0, 1, 0]) self.assertTrue(a_mask[0].mask == [0, 1, 0])
self.assertTrue(a_mask[0].dims == [0]) self.assertTrue(a_mask[0].dims == 0)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,7 +16,7 @@ sys.path.append("../") ...@@ -16,7 +16,7 @@ sys.path.append("../")
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from layers import conv_bn_layer from layers import conv_bn_layer
from paddleslim.prune import collect_convs from paddleslim.prune import StaticPruningCollections
from static_case import StaticCase from static_case import StaticCase
...@@ -41,12 +41,9 @@ class TestPrune(StaticCase): ...@@ -41,12 +41,9 @@ class TestPrune(StaticCase):
sum2 = conv4 + sum1 sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6") conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
collected_groups = collect_convs( collections = StaticPruningCollections(
["conv1_weights", "conv2_weights", "conv3_weights", "dummy"], ["conv1_weights", "conv2_weights", "conv3_weights", "dummy"],
main_program) main_program)
while [] in collected_groups:
collected_groups.remove([])
print(collected_groups)
params = set([ params = set([
param.name for param in main_program.all_parameters() param.name for param in main_program.all_parameters()
...@@ -58,14 +55,13 @@ class TestPrune(StaticCase): ...@@ -58,14 +55,13 @@ class TestPrune(StaticCase):
('conv4_weights', 0), ('conv5_weights', 1)], ('conv4_weights', 0), ('conv5_weights', 1)],
[('conv3_weights', 0), ('conv4_weights', 1)]] [('conv3_weights', 0), ('conv4_weights', 1)]]
self.assertTrue(len(collected_groups) == len(expected_groups)) self.assertTrue(len(collections._collections) == len(expected_groups))
for _collected, _expected in zip(collected_groups, expected_groups): for _collected, _expected in zip(collections, expected_groups):
for _name, _axis, _ in _collected: for _info in _collected.all_pruning_details():
_name = _info.name
_axis = _info.axis
if _name in params: if _name in params:
self.assertTrue((_name, _axis) in _expected) self.assertTrue((_name, _axis) in _expected)
for _name, _axis in _expected:
if _name in params:
self.assertTrue((_name, _axis, []) in _collected)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
import os
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import numpy as np import numpy as np
...@@ -22,6 +23,7 @@ from static_case import StaticCase ...@@ -22,6 +23,7 @@ from static_case import StaticCase
from layers import conv_bn_layer from layers import conv_bn_layer
import random import random
from paddleslim.core import GraphWrapper from paddleslim.core import GraphWrapper
from paddleslim.prune.prune_worker import *
class TestPrune(StaticCase): class TestPrune(StaticCase):
...@@ -35,53 +37,54 @@ class TestPrune(StaticCase): ...@@ -35,53 +37,54 @@ class TestPrune(StaticCase):
# #
# X: prune output channels # X: prune output channels
# O: prune input channels # O: prune input channels
with fluid.program_guard(main_program, startup_program): with fluid.unique_name.guard():
input = fluid.data(name="image", shape=[None, 3, 16, 16]) with fluid.program_guard(main_program, startup_program):
label = fluid.data(name='label', shape=[None, 1], dtype='int64') input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu') label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu') conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu')
sum1 = conv1 + conv2 conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu')
conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6') sum1 = conv1 + conv2
conv4 = conv_bn_layer(conv3, 8, 3, "conv4") conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6')
sum2 = conv4 + sum1 conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
conv5 = conv_bn_layer(sum2, 8, 3, "conv5") sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
flag = fluid.layers.fill_constant([1], value=1, dtype='int32')
rand_flag = paddle.randint(2, dtype='int32') flag = fluid.layers.fill_constant([1], value=1, dtype='int32')
cond = fluid.layers.less_than(x=flag, y=rand_flag) rand_flag = paddle.randint(2, dtype='int32')
cond_output = fluid.layers.create_global_var( cond = fluid.layers.less_than(x=flag, y=rand_flag)
shape=[1], cond_output = fluid.layers.create_global_var(
value=0.0, shape=[1],
dtype='float32', value=0.0,
persistable=False, dtype='float32',
name='cond_output') persistable=False,
name='cond_output')
def cond_block1():
cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1") def cond_block1():
fluid.layers.assign(input=cond_conv, output=cond_output) cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1")
fluid.layers.assign(input=cond_conv, output=cond_output)
def cond_block2():
cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1") def cond_block2():
cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2") cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1")
fluid.layers.assign(input=cond_conv2, output=cond_output) cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2")
fluid.layers.assign(input=cond_conv2, output=cond_output)
fluid.layers.cond(cond, cond_block1, cond_block2)
sum3 = fluid.layers.sum([sum2, cond_output]) fluid.layers.cond(cond, cond_block1, cond_block2)
sum3 = fluid.layers.sum([sum2, cond_output])
conv6 = conv_bn_layer(sum3, 8, 3, "conv6")
sub1 = conv6 - sum3 conv6 = conv_bn_layer(sum3, 8, 3, "conv6")
mult = sub1 * sub1 sub1 = conv6 - sum3
conv7 = conv_bn_layer( mult = sub1 * sub1
mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False) conv7 = conv_bn_layer(
floored = fluid.layers.floor(conv7) mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False)
scaled = fluid.layers.scale(floored) floored = fluid.layers.floor(conv7)
concated = fluid.layers.concat([scaled, mult], axis=1) scaled = fluid.layers.scale(floored)
conv8 = conv_bn_layer(concated, 8, 3, "conv8") concated = fluid.layers.concat([scaled, mult], axis=1)
predict = fluid.layers.fc(input=conv8, size=10, act='softmax') conv8 = conv_bn_layer(concated, 8, 3, "conv8")
cost = fluid.layers.cross_entropy(input=predict, label=label) predict = fluid.layers.fc(input=conv8, size=10, act='softmax')
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01) cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost) adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
adam_optimizer.minimize(avg_cost) avg_cost = fluid.layers.mean(cost)
adam_optimizer.minimize(avg_cost)
params = [] params = []
for param in main_program.all_parameters(): for param in main_program.all_parameters():
...@@ -117,5 +120,439 @@ class TestPrune(StaticCase): ...@@ -117,5 +120,439 @@ class TestPrune(StaticCase):
fetch_list=[cost.name]) fetch_list=[cost.name])
class TestUnsqueeze2(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu')
out = paddle.unsqueeze(conv1, axis=[0])
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("unsqueeze2")
out_var = graph.var(out.name)
in_var = graph.var(conv1.name)
op = out_var.inputs()[0]
# pruning out
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(out_var, 2, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {
'conv1_weights': 0,
'conv1_bn_scale': 0,
'conv1_bn_offset': 0,
'conv1_bn_mean': 0,
'conv1_bn_variance': 0
})
# pruning in
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(in_var, 1, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {})
class TestSqueeze2(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1]
out = paddle.squeeze(conv1)
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("squeeze2")
out_var = graph.var(out.name)
in_var = graph.var(conv1.name)
op = out_var.inputs()[0]
# pruning out
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(out_var, 0, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {
'conv1_weights': 0,
'conv1_bn_scale': 0,
'conv1_bn_offset': 0,
'conv1_bn_mean': 0,
'conv1_bn_variance': 0
})
# pruning in
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(in_var, 1, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {})
class TestUnsupportAndDefault(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1]
# hit default pruning worker
cast1 = paddle.cast(conv1, dtype="int32")
# hit unsupported pruning worker
out = paddle.reshape(cast1, shape=[1, -1])
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("conv2d")
in_var = graph.var("conv1_weights")
op = in_var.outputs()[0]
# pruning input of conv op
pruned_params = []
ret = {}
os.environ['OPS_UNSUPPORTED'] = "reshape2"
worker = cls(op, pruned_params, {}, True)
hit_unsupported_op = False
try:
worker.prune(in_var, 0, [])
except UnsupportOpError as e:
hit_unsupported_op = True
print(e)
self.assertTrue(hit_unsupported_op)
class TestConv2d(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 6, 3, "conv1", groups=1, bias=True, act='relu')
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("conv2d")
weight_var = graph.var("conv1_weights")
in_var = graph.var("image")
op = in_var.outputs()[0]
out_var = op.outputs("Output")[0]
# pruning weights of conv op
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(weight_var, 0, [])
worker.prune(weight_var, 1, [])
for var, axis, _, _ in pruned_params:
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
self.assertTrue(ret == {
'conv1_weights': [0, 1],
'conv1_out.b_0': [0],
'conv1_bn_scale': [0],
'conv1_bn_offset': [0],
'conv1_bn_mean': [0],
'conv1_bn_variance': [0]
})
# pruning out of conv op
pruned_params = []
ret = {}
worker = cls(op, pruned_params, visited={}, skip_stranger=True)
worker.prune(out_var, 1, [])
for var, axis, _, _ in pruned_params:
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
self.assertTrue(ret == {'conv1_weights': [0]})
class TestPruneWorker(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.create_graph()
self.cases = []
self.set_cases()
def define_layer(self, input):
pass
def set_cases(self):
pass
def create_graph(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with paddle.fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[8, 8, 16, 16])
self.define_layer(input)
self.graph = GraphWrapper(main_program)
self.in_var = self.graph.var(self.input.name)
self.out_var = self.graph.var(self.output.name)
self.op = self.in_var.outputs()[0]
def check_in_out(self):
cls = PRUNE_WORKER.get(self.op.type())
if cls is None:
cls = PRUNE_WORKER.get("default_worker")
# run each case: prune from the given var/axis and compare the recorded (name, axis) pairs
for _var, _axis, _ret in self.cases:
pruned_params = []
ret = {}
worker = cls(self.op, pruned_params, visited={}, skip_stranger=True)
try:
worker.prune(_var, _axis, [])
except UnsupportOpError as e:
print(e)
continue
for var, axis, _, _ in pruned_params:
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
self.assertTrue(ret == _ret)
class TestConv2dTranspose(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d_transpose(
input, 6, 16, 3, name="conv1", bias_attr=False)
self.output = conv1
return conv1
def set_cases(self):
self.cases.append((self.in_var, 1, {'conv1.w_0': [0]}))
self.cases.append((self.out_var, 1, {'conv1.w_0': [1]}))
def test_prune(self):
self.check_in_out()
class TestElementwiseMul(TestPruneWorker):
def define_layer(self, input):
conv1 = paddle.static.nn.conv2d(
input, 3, 3, name="conv1", bias_attr=False)
conv2 = paddle.static.nn.conv2d(
input, 3, 3, name="conv2", bias_attr=False)
self.input = conv1
out = conv1 * conv2
conv3 = paddle.static.nn.conv2d(
out, 3, 3, name="conv3", bias_attr=False)
self.output = out
def set_cases(self):
self.cases.append((self.in_var, 1, {
'conv2.tmp_0': [1],
'conv2.w_0': [0],
'conv3.w_0': [1]
}))
self.cases.append((self.out_var, 1, {
'conv1.w_0': [0],
'conv2.tmp_0': [1],
'conv2.w_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestActivation(TestPruneWorker):
def __init__(self, methodName="test_prune",
op=paddle.nn.functional.sigmoid):
super(TestActivation, self).__init__(methodName)
self.act = op
def define_layer(self, input):
conv1 = paddle.static.nn.conv2d(
input, 3, 3, name="conv1", bias_attr=False)
self.input = conv1
tmp = self.act(conv1)
self.output = tmp
conv2 = paddle.static.nn.conv2d(
tmp, 3, 3, name="conv2", bias_attr=False)
def set_cases(self):
self.cases.append((self.in_var, 1, {'conv2.w_0': [1]}))
self.cases.append((self.out_var, 1, {
'conv1.w_0': [0],
'conv2.w_0': [1]
}))
def test_prune(self):
self.check_in_out()
suite = unittest.TestSuite()
suite.addTest(TestActivation(op=paddle.fluid.layers.resize_bilinear))
suite.addTest(TestActivation(op=paddle.fluid.layers.resize_nearest))
suite.addTest(TestActivation(op=paddle.floor))
suite.addTest(TestActivation(op=paddle.scale))
suite.addTest(
TestActivation(op=paddle.fluid.layers.nn.uniform_random_batch_size_like))
class TestDepthwiseConv2d(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestDepthwiseConv2d, self).__init__(methodName)
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input,
input.shape[1],
3,
groups=input.shape[1],
name="conv1",
bias_attr=False)
self.output = conv1
def set_cases(self):
weight_var = self.graph.var('conv1.w_0')
self.cases.append((self.in_var, 1, {'conv1.w_0': [0, 1]}))
self.cases.append((self.out_var, 1, {'conv1.w_0': [0, 1]}))
self.cases.append((weight_var, 0, {'conv1.w_0': [1]}))
def test_prune(self):
self.check_in_out()
class TestMul(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestMul, self).__init__(methodName)
def define_layer(self, input):
x = fluid.data(name="x", shape=[1, 4, 3, 3])
y = fluid.data(name="y", shape=[36, 7])
self.input = x
out = paddle.fluid.layers.mul(x, y)
self.output = out
def set_cases(self):
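        # mul flattens x from [1, 4, 3, 3] to [1, 36], so pruning x's channel
        # axis removes the corresponding rows (axis 0) of y.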
self.cases.append((self.in_var, 1, {'y': [0]}))
def test_prune(self):
self.check_in_out()
class TestMatmul(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestMatmul, self).__init__(methodName)
def define_layer(self, input):
x = fluid.data(name="x", shape=[6, 8])
y = fluid.data(name="y", shape=[8, 7])
self.input = x
out = paddle.matmul(x, y)
self.output = out
def set_cases(self):
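        # Pruning the reduced dimension of x prunes the rows of y; pruning the
        # output's rows or columns prunes x's rows or y's columns.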
self.cases.append((self.in_var, 1, {'y': [0]}))
self.cases.append((self.out_var, 0, {'x': [0]}))
self.cases.append((self.out_var, 1, {'y': [1]}))
def test_prune(self):
self.check_in_out()
class TestSplit(TestPruneWorker):
def define_layer(self, input):
self.input = input
split1 = paddle.split(input, num_or_sections=2, axis=1, name=None)
self.output = split1[0]
def set_cases(self):
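        # No prunable parameters are reachable from the split in this graph,
        # so every case expects an empty result.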
self.cases.append((self.in_var, 1, {}))
self.cases.append((self.in_var, 0, {}))
self.cases.append((self.out_var, 1, {}))
self.cases.append((self.out_var, 0, {}))
def test_prune(self):
self.check_in_out()
class TestMomentum(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.output = conv1
out = paddle.mean(conv1)
opt = paddle.optimizer.Momentum()
opt.minimize(out)
def set_cases(self):
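        # Pruning conv1's filters must also prune the optimizer's velocity
        # tensor along the same axis.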
weight_var = self.graph.var('conv1.w_0')
self.cases.append((weight_var, 0, {
'conv1.w_0': [0],
'conv1.w_0_velocity_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestAdam(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.output = conv1
out = paddle.mean(conv1)
opt = paddle.optimizer.Adam()
opt.minimize(out)
def set_cases(self):
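        # Pruning conv1's filters must also prune Adam's first and second
        # moment tensors along the same axis.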
weight_var = self.graph.var('conv1.w_0')
self.cases.append((weight_var, 0, {
'conv1.w_0': [0],
'conv1.w_0_moment1_0': [0],
'conv1.w_0_moment2_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestAffineChannel(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestAffineChannel, self).__init__(methodName)
def define_layer(self, input):
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.input = conv1
scale = fluid.data(name="scale", shape=[conv1.shape[1]])
bias = fluid.data(name="bias", shape=[conv1.shape[1]])
out = paddle.fluid.layers.affine_channel(conv1, scale=scale, bias=bias)
self.output = out
def set_cases(self):
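        # affine_channel applies a per-channel scale and bias, so pruning the
        # channel axis prunes both; pruning the output also reaches conv1's
        # filters.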
self.cases.append((self.in_var, 1, {'scale': [0], 'bias': [0]}))
self.cases.append((self.out_var, 1, {
'conv1.w_0': [0],
'scale': [0],
'bias': [0]
}))
def test_prune(self):
self.check_in_out()
if __name__ == '__main__':
    unittest.main()
...@@ -107,6 +107,7 @@ class TestSensitivity(StaticCase):
            sensitivities_file="./sensitivities_file_2",
            pruned_ratios=[0.1, 0.2, 0.3, 0.4])
        self.assertTrue(params_sens == origin_sens)
        self.assertTrue(sens == origin_sens)
...