Unverified · Commit b849d609, authored by: W whs, committed by: GitHub

Fix pruning group conv2d (#720)

Parent d3aeda6f
......@@ -197,7 +197,10 @@ class OpWrapper(object):
bool|int|str|float|list: The attribute value. The return value
can be any valid attribute type.
"""
return self._op.attr(name)
if self._op.has_attr(name):
return self._op.attr(name)
else:
return None
class GraphWrapper(object):
......@@ -365,35 +368,6 @@ class GraphWrapper(object):
Update the groups of convolution layer according to current filters.
It is used after loading pruned parameters from file.
"""
head_op = []
visited = []
for op in self.ops():
if op.type() != 'conditional_block':
if len(self.pre_ops(op)) == 0:
head_op.append(op)
candidate_op = self.ops()
def recursive_infer(op, infer=False):
if op in candidate_op:
if op.type() != 'conditional_block':
if infer:
op._op.desc.infer_shape(op._op.block.desc)
else:
visited.append(op)
candidate_op.remove(op)
for next_op in self.next_ops(op):
recursive_infer(next_op)
# Find ops which are not in the DAG. Some ops, such as optimizer ops,
# should be inferred before normal computation ops.
for op in head_op:
recursive_infer(op, infer=False)
# Infer the ops which are not in the DAG first.
candidate_op = self.ops()
for op in candidate_op:
if op not in visited and op.type() != 'conditional_block':
op._op.desc.infer_shape(op._op.block.desc)
# Infer the remain ops in topological order.
for op in head_op:
recursive_infer(op, infer=True)
......@@ -9,14 +9,14 @@ from .var_group import *
from .pruning_plan import *
from .pruner import Pruner
from paddleslim.analysis import dygraph_flops as flops
from .var_group import VarGroup
from .var_group import DygraphPruningCollections
__all__ = ['Status', 'FilterPruner']
_logger = get_logger(__name__, logging.INFO)
CONV_OP_TYPE = paddle.nn.Conv2D
FILTER_DIM = [0]
FILTER_DIM = 0
CONV_WEIGHT_NAME = "weight"
SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose)
......@@ -59,16 +59,17 @@ class FilterPruner(Pruner):
def __init__(self, model, inputs, sen_file=None):
super(FilterPruner, self).__init__(model, inputs)
self._status = Status(sen_file)
# sensitive and var_group are just used in filter pruning
self.var_group = VarGroup(model, inputs)
# sensitive and collections are just used in filter pruning
self.collections = DygraphPruningCollections(model, inputs)
# skip vars in:
# 1. depthwise conv2d layer
self.skip_vars = []
for sub_layer in model.sublayers():
if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
sub_layer, paddle.nn.layer.conv.Conv2D) and
sub_layer._groups > 1):
#if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
# sub_layer, paddle.nn.layer.conv.Conv2D) and
# sub_layer._groups > 1):
if isinstance(sub_layer, SKIP_LAYERS):
for param in sub_layer.parameters():
self.skip_vars.append(param.name)
......@@ -170,11 +171,11 @@ class FilterPruner(Pruner):
break
return ratios
def _round_to(self, ratios, dims=[0], factor=8):
def _round_to(self, ratios, dims=0, factor=8):
ret = {}
for name in ratios:
ratio = ratios[name]
dim = self._var_shapes[name][dims[0]]
dim = self._var_shapes[name][dims]
remained = round((1 - ratio) * dim / factor) * factor
if remained == 0:
remained = factor
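# Worked example (illustrative numbers): with ratio=0.5, dim=60 and factor=8,
# remained = round((1 - 0.5) * 60 / 8) * 8 = 32, i.e. the number of kept
# channels is rounded to a multiple of `factor`.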
......@@ -186,14 +187,14 @@ class FilterPruner(Pruner):
def get_ratios_by_sensitivity(self,
pruned_flops,
align=None,
dims=[0],
dims=0,
skip_vars=[]):
"""
Get a group of ratios by sensitivities.
Args:
pruned_flops(float): The expected ratio of FLOPs to be pruned. It should be in range (0, 1).
align(int, optional): Round the size of each pruned dimension to multiple of 'align' if 'align' is not None. Default: None.
dims(list, optional): The dims to be pruned on. [0] means pruning channels of output for convolution. Default: [0].
dims(int, optional): The dims to be pruned on. 0 means pruning channels of output for convolution. Default: 0.
skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None.
Returns:
......@@ -201,7 +202,7 @@ class FilterPruner(Pruner):
"""
base_flops = flops(self.model, self.inputs)
_logger.debug("Base FLOPs: {}".format(base_flops))
_logger.info("Base FLOPs: {}".format(base_flops))
low = 0.
up = 1.0
history = set()
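# A binary search between `low` and `up`: each candidate yields a set of ratios
# from the sensitivities; the ratios are applied temporarily, the FLOPs of the
# pruned model are measured, and the result is compared with `pruned_flops`.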
......@@ -214,7 +215,6 @@ class FilterPruner(Pruner):
ratios = self._round_to(ratios, dims=dims, factor=align)
plan = self.prune_vars(ratios, axis=dims)
c_flops = flops(self.model, self.inputs)
_logger.debug("FLOPs after pruning: {}".format(c_flops))
c_pruned_flops = (base_flops - c_flops) / base_flops
plan.restore(self.model)
_logger.debug("Seaching ratios, pruned FLOPs: {}".format(
......@@ -240,10 +240,9 @@ class FilterPruner(Pruner):
sensitivities = self._status.sensitivies
baseline = None
ratios = np.arange(0.1, 1, step=0.1)
for group in self.var_group.groups:
var_name = group[0][0]
dims = group[0][1]
for _collection in self.collections:
var_name = _collection.master_name
dims = _collection.master_axis
if target_vars is not None and var_name not in target_vars:
continue
if skip_vars is not None and var_name in skip_vars:
......@@ -282,7 +281,6 @@ class FilterPruner(Pruner):
self.restore()
ratios, pruned_flops = self.get_ratios_by_sensitivity(
pruned_flops, align=align, dims=FILTER_DIM, skip_vars=skip_vars)
_logger.debug("ratios: {}".format(ratios))
self.plan = self.prune_vars(ratios, FILTER_DIM)
self.plan._pruned_flops = pruned_flops
return self.plan
......@@ -291,73 +289,60 @@ class FilterPruner(Pruner):
if self.plan is not None:
self.plan.restore(self.model)
def cal_mask(self, var_name, pruned_ratio, group):
"""
{
var_name: {
'layer': sub_layer,
'var': variable,
'value': np.array([]),
'pruned_dims': [1],
}
}
"""
def cal_mask(self, pruned_ratio, collection):
raise NotImplemented("cal_mask is not implemented")
def prune_var(self, var_name, pruned_dims, pruned_ratio, apply="impretive"):
def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"):
"""
Pruning a variable.
Parameters:
var_name(str): The name of variable.
pruned_dims(list<int>): The axies to be pruned. For convolution with format [out_c, in_c, k, k],
'axis=[0]' means pruning filters and 'axis=[0, 1]' means pruning kernels.
pruned_axis(int): The axis to be pruned. For convolution with format [out_c, in_c, k, k],
'axis=0' means pruning filters.
pruned_ratio(float): The ratio of pruned values in one variable.
apply(str): How to apply pruning plan to graph. It can be 'impretive', 'lazy' or None. None
means just returning an instance of 'PruningPlan' but not applying it to graph.
Returns:
plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'.
"""
pruned_axis = pruned_axis[0] if isinstance(pruned_axis,
list) else pruned_axis
assert (isinstance(pruned_axis, int))
if var_name in self.skip_vars:
_logger.warn(
f"{var_name} is skiped beacause it is not support for pruning derectly."
f"{var_name} is skiped beacause it is not supported for pruning directly."
)
return
if isinstance(pruned_dims, int):
pruned_dims = [pruned_dims]
group = self.var_group.find_group(var_name, pruned_dims)
_logger.debug("found group with {}: {}".format(var_name, group))
collection = self.collections.find_collection_by_master(var_name,
pruned_axis)
plan = PruningPlan(self.model.full_name)
group_dict = {}
for sub_layer in self.model.sublayers():
for param in sub_layer.parameters(include_sublayers=False):
if param.name in group:
group_dict[param.name] = group[param.name]
# Variables can be pruned on multiple axes.
for _item in group_dict[param.name]:
_item.update({
'layer': sub_layer,
'var': param,
'value': np.array(param.value().get_tensor())
})
_logger.debug(f"set value of {param.name} into group")
mask = self.cal_mask(var_name, pruned_ratio, group_dict)
for _name in group_dict:
if collection is None:
_logger.debug(
f"Can not find collection with master ['name': {var_name}, 'axis': {pruned_axis}]"
)
return plan
_logger.info(
f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
)
mask = self.cal_mask(pruned_ratio, collection)
for _detail in collection.all_pruning_details():
# Variables can be pruned on multiple axes.
for _item in group_dict[_name]:
src_mask = copy.deepcopy(mask)
dims = _item['pruned_dims']
transforms = _item['transforms']
var_shape = _item['var'].shape
if isinstance(dims, int):
dims = [dims]
for trans in transforms:
src_mask = self._transform_mask(src_mask, trans)
current_mask = src_mask
assert len(current_mask) == var_shape[dims[
0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
src_mask = copy.deepcopy(mask)
var_shape = _detail.var.shape()
for tran in _detail.transform:
src_mask = self._transform_mask(src_mask, tran)
current_mask = src_mask
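# For grouped conv2d the size of the pruned axis of a relative filter may be
# in_channels / groups, so the mask-length check below is only enforced when
# groups is None or 1.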
groups = _detail.op.attr('groups')
if groups is None or groups == 1:
assert len(current_mask) == var_shape[
_detail.
axis], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; axis: {_detail.axis}; var name: {_name}; len(mask): {len(mask)}"
plan.add(_detail.name,
PruningMask(_detail.axis, current_mask, pruned_ratio,
_detail.op))
if apply == "lazy":
plan.apply(self.model, lazy=True)
elif apply == "impretive":
......@@ -371,17 +356,8 @@ class FilterPruner(Pruner):
target_start = transform['target_start']
target_end = transform['target_end']
target_len = transform['target_len']
stride = transform['stride']
mask = mask[src_start:src_end]
mask = mask.repeat(stride) if stride > 1 else mask
dst_mask = np.ones([target_len])
# for depthwise conv2d with:
# input shape: (1, 4, 32, 32)
# filter shape: (32, 1, 3, 3)
# groups: 4
# if we prune the input channels by 50% (from 4 to 2), the number of output channels should become 50% * 4 * 8 = 16.
expand = int((target_end - target_start) / len(mask))
dst_mask[target_start:target_end] = list(mask) * expand
elif "stride" in transform:
......
......@@ -15,24 +15,38 @@ class FPGMFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None):
super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group):
for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
dist_sum_list = []
for out_i in range(value.shape[0]):
dist_sum = self.get_distance_sum(value, out_i)
dist_sum_list.append(dist_sum)
scores = np.array(dist_sum_list)
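# For grouped conv2d, average the per-filter distance scores within each
# channel group so that whole groups are ranked and pruned together.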
if groups > 1:
scores = scores.reshape([groups, -1])
scores = np.mean(scores, axis=1)
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[i] for i in pruned_dims]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask
return mask.reshape(mask_shape)
def get_distance_sum(self, value, out_idx):
w = value.view()
......
......@@ -16,19 +16,32 @@ class L1NormFilterPruner(FilterPruner):
super(L1NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group):
for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims
]
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims))
if groups > 1:
l1norm = l1norm.reshape([groups, -1])
l1norm = np.mean(l1norm, axis=1)
sorted_idx = l1norm.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[i] for i in pruned_dims]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask
return mask.reshape(mask_shape)
......@@ -16,22 +16,32 @@ class L2NormFilterPruner(FilterPruner):
super(L2NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group):
# find information of pruning on output channels
for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims
]
# scores = np.mean(np.abs(value), axis=tuple(reduce_dims))
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
if groups > 1:
scores = scores.reshape([groups, -1])
scores = np.mean(scores, axis=1)
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[i] for i in pruned_dims]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask
return mask.reshape(mask_shape)
......@@ -39,11 +39,12 @@ class Pruner(object):
Args:
ratios(dict<str, float>): The key is the name of variable to be pruned and the
value is the pruned ratio.
axis(list): The dimensions to be pruned on.
axis(int): The dimension to be pruned on.
Returns:
plan(PruningPlan): The pruning plan.
"""
axis = axis[0] if isinstance(axis, list) else axis
global_plan = PruningPlan(self.model.full_name)
for var, ratio in ratios.items():
if not global_plan.contains(var, axis):
......
......@@ -10,27 +10,17 @@ __all__ = ['PruningPlan', 'PruningMask']
class PruningMask():
def __init__(self, dims, mask, ratio):
def __init__(self, dims, mask, ratio, op):
assert (isinstance(dims, int))
self._dims = dims
self._mask = mask
self._pruned_ratio = ratio
self._op = op
@property
def dims(self):
return self._dims
@dims.setter
def dims(self, value):
if not isinstance(value, collections.Iterator):
raise ValueError(
"The dims of PruningMask must be instance of collections.Iterator."
)
if self._mask is not None:
assert len(self._mask.shape) == len(
value
), "The length of value must be same with length of mask's shape in current PruningMask instance."
self._dims = list(value)
@property
def mask(self):
return self._mask
......@@ -128,8 +118,7 @@ class PruningPlan():
_logger.debug("Backup values of {} into buffers.".
format(param.name))
expand_mask_shape = [1] * len(value.shape)
for i in dims:
expand_mask_shape[i] = value.shape[i]
expand_mask_shape[dims] = value.shape[dims]
_logger.debug("Expanded mask shape: {}".format(
expand_mask_shape))
expand_mask = mask.reshape(expand_mask_shape).astype(
......@@ -158,13 +147,25 @@ class PruningPlan():
if param.name in self._masks:
for _mask in self._masks[param.name]:
dims = _mask.dims
assert (isinstance(dims, int))
mask = _mask.mask
assert len(
dims
) == 1, "Imperative mode only support for pruning on one dimension, but get dims {} when pruning parameter {}".format(
dims, param.name)
bool_mask = np.array(mask).astype(bool)
t_value = param.value().get_tensor()
value = np.array(t_value).astype("float32")
groups = _mask._op.attr('groups')
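# When pruning the input-channel axis (dims == 1) of a grouped conv2d, the
# weight tensor keeps its shape; instead the layer's `groups` attribute is
# reduced so the remaining input channels still divide evenly into the groups.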
if dims == 1 and groups is not None and groups > 1 and len(
value.shape) == 4:
filter_size = value.shape[1]
except_num = np.sum(bool_mask)
assert (except_num % filter_size == 0)
new_groups = int(except_num / filter_size)
sub_layer._origin_groups = sub_layer._groups
sub_layer._groups = new_groups
_logger.info("change groups from {} to {} for {}.".
format(groups, new_groups, param.name))
continue
# The name of a buffer can not contain "."
backup_name = param.name.replace(".", "_") + "_backup"
if backup_name not in sub_layer._buffers:
......@@ -172,9 +173,8 @@ class PruningPlan():
paddle.to_tensor(value))
_logger.debug("Backup values of {} into buffers.".
format(param.name))
bool_mask = np.array(mask).astype(bool)
pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims[0], value)
lambda data: data[bool_mask], dims, value)
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
......@@ -186,18 +186,6 @@ class PruningPlan():
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_value, place)
if isinstance(
sub_layer, paddle.nn.layer.conv.Conv2D
) and sub_layer._groups > 1 and len(param.shape) == 4:
assert param.shape[
1] == 1, "It just supports depthwise conv2d when groups > 1."
new_groups = int(bool_mask.sum() *
sub_layer._groups / len(bool_mask))
_logger.debug(
"Update groups of depthwise conv2d form {} to {}".
format(sub_layer._groups, new_groups))
sub_layer._origin_groups = sub_layer._groups
sub_layer._groups = new_groups
# for training
if param.trainable:
......
......@@ -3,15 +3,15 @@ import logging
import paddle
from paddle.fluid.dygraph import TracedLayer
from paddleslim.core import GraphWrapper, dygraph2program
from paddleslim.prune import collect_convs
from paddleslim.prune import PruningCollections
from paddleslim.common import get_logger
__all__ = ["VarGroup"]
__all__ = ["DygraphPruningCollections"]
_logger = get_logger(__name__, level=logging.INFO)
class VarGroup():
class DygraphPruningCollections(PruningCollections):
"""
A tool used to parse dygraph and store information of variables' relationship.
Args:
......@@ -20,40 +20,29 @@ class VarGroup():
"""
def __init__(self, model, inputs):
self.groups = []
self._parse_model(model, inputs)
def _to_dict(self, group):
ret = {}
for _name, _axis, _transforms in group:
if isinstance(_axis, int):
_axis = [_axis]
if _name not in ret:
ret[_name] = []
# Variable can be pruned on multiple axies.
ret[_name].append({'pruned_dims': _axis, 'transforms': _transforms})
return ret
def find_group(self, var_name, axis):
for group in self.groups:
for _name, _axis, _stride in group:
if isinstance(_axis, int):
_axis = [_axis]
if _name == var_name and _axis == axis:
return self._to_dict(group)
def _parse_model(self, model, inputs):
_logger.debug("Parsing model with input: {}".format(inputs))
# The model can be in training mode, because some models contain auxiliary parameters that are only used for training.
program = dygraph2program(model, inputs=inputs)
graph = GraphWrapper(program)
visited = {}
for name, param in model.named_parameters():
group = collect_convs([param.name], graph,
visited)[0] # [(name, axis, pruned_idx)]
if len(group) > 0:
self.groups.append(group)
_logger.info("Found {} groups.".format(len(self.groups)))
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 4
]
self._collections = self.create_pruning_collections(params, graph)
_logger.info("Found {} collections.".format(len(self._collections)))
_name2values = {}
for param in model.parameters():
_name2values[param.name] = np.array(param.value().get_tensor())
for collection in self._collections:
collection.values = _name2values
def find_collection_by_master(self, var_name, axis):
for _collection in self._collections:
if _collection.master['name'] == var_name and _collection.master[
'axis'] == axis:
return _collection
def __str__(self):
return "\n".join([str(group) for group in self.groups])
return "\n".join(
[str(_collection) for _collection in self._collections])
......@@ -19,17 +19,16 @@ from .auto_pruner import *
from ..prune import auto_pruner
from .sensitive import *
from ..prune import sensitive
from .prune_walker import *
from ..prune import prune_walker
from .prune_worker import *
from ..prune import prune_worker
from .prune_io import *
from ..prune import prune_io
from .group_param import *
from ..prune import group_param
from .criterion import *
from ..prune import criterion
from .collections import *
from ..prune import collections
from .unstructured_pruner import *
from ..prune import unstructured_pruner
from .idx_selector import *
from ..prune import idx_selector
__all__ = []
......@@ -37,9 +36,9 @@ __all__ = []
__all__ += pruner.__all__
__all__ += auto_pruner.__all__
__all__ += sensitive.__all__
__all__ += prune_walker.__all__
__all__ += prune_worker.__all__
__all__ += prune_io.__all__
__all__ += group_param.__all__
__all__ += criterion.__all__
__all__ += unstructured_pruner.__all__
__all__ += idx_selector.__all__
__all__ += collections.__all__
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from ..core import GraphWrapper, VarWrapper
from ..common import get_logger
from .prune_worker import PRUNE_WORKER, UnsupportOpError
__all__ = [
'PruningDetails', 'PruningCollection', 'PruningCollections',
'StaticPruningCollections'
]
_logger = get_logger(__name__, level=logging.INFO)
class PruningDetails(object):
"""
The description of one pruning operation.
Args:
var(VarWrapper): The variable to be pruned.
axis(int): The axis to be pruned on.
transform(dict): Information used to convert pruned indices of master
tensor to indices of current tensor.
op(OpWrapper): The operator with current tensor as input.
is_parameter(bool): whether the tensor is parameter. Default: True.
"""
def __init__(self, var, axis, transform, op, is_parameter=True):
assert isinstance(var, VarWrapper), \
"var should be a VarWrapper, but got type = {}".format(type(var))
assert (isinstance(axis, int))
self.name = var.name()
self.var = var
self.axis = axis
self.transform = transform
self.op = op
self.is_parameter = is_parameter
def __eq__(self, other):
if self.name != other.name:
return False
if self.axis != other.axis:
return False
if self.transform != other.transform:
return False
return True
class PruningCollection(object):
"""
A group of pruning operations.
conv1-->conv2-->batch_norm
For the network defined above, if weight of conv1 is pruned on 0-axis,
weight of'conv2' should be pruned on 1-axis. The pruning operations on 0-axis of
'conv1' and those on 1-axis of 'conv2' is a collection. And the {'name': conv1.weight_name, 'axis': 0}
is the master of current collection.
Args:
master(dict): The master pruning operation.
"""
def __init__(self, master=None):
self._master = master
self.master_name = master['name']
self.master_axis = master['axis']
self._nodes = {}
def variables(self):
"""
Get all tensors to be pruned in current collection.
Returns:
list<str>: Names of tensor to be pruned.
"""
return list(self._nodes.keys())
def add(self, node):
"""
Add a pruning operation into current collention.
Args:
node(PruningDetails): Pruning operation to be added into current collection.
"""
assert (isinstance(node, PruningDetails))
# the first added pruning operation will be master.
self._master = {
"name": node.name,
"axis": node.aixs
} if self._master is None else self._master
if node.name not in self._nodes:
self._nodes[node.name] = []
if node not in self._nodes[node.name]:
self._nodes[node.name].append(node)
@property
def master(self):
return self._master
def all_pruning_details(self):
"""
Get all pruning operations in current collection.
Returns:
list<PruningDetails>: Pruning operations.
"""
ret = []
for _items in self._nodes.values():
ret.extend(_items)
return ret
class PruningCollections(object):
def __init__(self):
self._collections = None
def __iter__(self):
return iter(self._collections)
def create_pruning_collections(self, params, graph, skip_stranger=True):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers, and the shape of a convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If the parameter of `conv1` is pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So `conv1` and `conv2` form a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layers' parameter names. It will collect all the groups that contain any of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
skip_stranger(bool): Whether to skip the current tensor when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means all unregistered operators are visited by the default worker. Default: True.
Returns:
list<Group>: The groups.
"""
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
visited = {}
collections = []
unsupported_warnings = set()
for _param in params:
pruned_params = []
param = graph.var(_param)
if param is None:
_logger.warning(
f"Couldn't find relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correct mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
continue
target_op = param.outputs()[0]
if target_op.type() == 'conditional_block':
for op in param.outputs():
if op.type() in PRUNE_WORKER._module_dict.keys():
cls = PRUNE_WORKER.get(op.type())
worker = cls(op,
pruned_params=pruned_params,
visited=visited,
skip_stranger=skip_stranger)
break
else:
cls = PRUNE_WORKER.get(target_op.type())
if cls is None:
_logger.warning("No worker for operator: {}".format(
target_op.type()))
continue
worker = cls(target_op,
pruned_params=pruned_params,
visited=visited,
skip_stranger=skip_stranger)
try:
visited_backup = copy.deepcopy(worker.visited)
worker.prune(param, pruned_axis=0, pruned_idx=[])
except UnsupportOpError as e:
visited.clear()
visited.update(visited_backup)
unsupported_warnings.add(e.args)
else:
if len(pruned_params) != 0:
collection = PruningCollection(master=({
"name": param.name(),
"axis": 0
}))
for _param, _axis, _transform, _op in pruned_params:
collection.add(
PruningDetails(_param, _axis, _transform, _op))
collections.append(collection)
for warn in unsupported_warnings:
_logger.warning(warn)
self._collections = collections
return self._collections
class StaticPruningCollections(PruningCollections):
def __init__(self, params, graph, skip_stranger=True):
super(StaticPruningCollections, self).__init__()
self._collections = self.create_pruning_collections(
params, graph, skip_stranger=skip_stranger)
......@@ -27,7 +27,7 @@ CRITERION = Registry('criterion')
@CRITERION.register
def l1_norm(group, graph):
def l1_norm(group, values, graph):
"""Compute l1-norm scores of parameter on given axis.
This function returns a list of parameters' l1-norm scores on a given axis.
......@@ -35,28 +35,44 @@ def l1_norm(group, graph):
and 'axis' is the axis to reduce on and `score` is a np.array storing the l1-norm of the structure on `axis`.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, values, axis)' in which `name` is the parameter's name and
and `values` is the values of parameter and `axis` is the axis reducing on pruning on.
group(Group): A group of pruning operations.
values(dict): The key is the name of tensor in group, and the value of dict is the
values of tensor.
graph(GraphWrapper): The graph stores structure information of network.
Returns:
list: A list of tuple storing l1-norm on given axis.
dict: The key is name of tensor, the value is a dict
with axis as key and l1-norm scores as value.
"""
scores = []
for name, value, axis, pruned_idx in group:
scores = {}
for pruning_details in group.all_pruning_details():
name = pruning_details.name
if name not in values:
_logger.warning(
"The value of tensor '{}' is not found, so it will not be used when evaluating importance of pruned structures.".
format(name))
continue
value = values[name]
axis = pruning_details.axis
reduce_dims = [i for i in range(len(value.shape)) if i != axis]
score = np.sum(np.abs(value), axis=tuple(reduce_dims))
scores.append((name, axis, score, pruned_idx))
if name not in scores:
scores[name] = {}
scores[name][axis] = score
return scores
@CRITERION.register
def geometry_median(group, graph):
scores = []
name, value, axis, _ = group[0]
assert (len(value.shape) == 4)
def geometry_median(group, values, graph):
name = group.master["name"]
axis = group.master["axis"]
if name not in values:
_logger.warning("The value of tensor '{}' is not found.")
return None
value = values[name]
assert len(value.shape) == 4, \
"geometry_median only supports the weight of conv2d."
def get_distance_sum(value, out_idx):
w = value.view()
......@@ -73,31 +89,26 @@ def geometry_median(group, graph):
tmp = np.array(dist_sum_list)
for name, value, axis, idx in group:
scores.append((name, axis, tmp, idx))
scores = {}
for pruning_details in group.all_pruning_details():
name = pruning_details.name
axis = pruning_details.axis
if name not in scores:
scores[name] = {}
scores[name][axis] = tmp
return scores
@CRITERION.register
def bn_scale(group, graph):
"""Compute l1-norm scores of parameter on given axis.
This function return a list of parameters' l1-norm scores on given axis.
Each element of list is a tuple with format (name, axis, score) in which 'name' is parameter's name
and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of strucure on `axis`.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, values, axis)' in which `name` is the parameter's name and
and `values` is the values of parameter and `axis` is the axis reducing on pruning on.
Returns:
list: A list of tuple storing l1-norm on given axis.
def bn_scale(group, values, graph):
"""Compute scores by scales of batch_norm layer.
"""
assert (isinstance(graph, GraphWrapper))
# step1: Get first convolution
conv_weight, value, axis, _ = group[0]
conv_weight = group.master["name"]
axis = group.master["axis"]
value = values[conv_weight]
param_var = graph.var(conv_weight)
conv_op = param_var.outputs()[0]
......@@ -111,12 +122,16 @@ def bn_scale(group, graph):
# steps3: Find scale of bn
score = None
for name, value, aixs, _ in group:
if bn_scale_param == name:
score = np.abs(value.reshape([-1]))
scores = []
for name, value, axis, idx in group:
scores.append((name, axis, score, idx))
if bn_scale_param not in values:
raise SystemExit("Can't find values of scales in BatchNorm.")
value = values[bn_scale_param]
score = np.abs(value.reshape([-1]))
scores = {}
for pruning_details in group.all_pruning_details():
name = pruning_details.name
axis = pruning_details.axis
if name not in scores:
scores[name] = {}
scores[name][axis] = score
return scores
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from ..core import GraphWrapper
from ..common import get_logger
from .prune_walker import PRUNE_WORKER
__all__ = ["collect_convs"]
_logger = get_logger(__name__, level=logging.INFO)
def collect_convs(params, graph, visited={}):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
Returns:
list<list<tuple>>: The groups.
"""
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
groups = []
for _param in params:
pruned_params = []
param = graph.var(_param)
if param is None:
_logger.warning(
f"Cann't found relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correctly mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
groups.append([])
continue
target_op = param.outputs()[0]
if target_op.type() == 'conditional_block':
for op in param.outputs():
if op.type() in PRUNE_WORKER._module_dict.keys():
cls = PRUNE_WORKER.get(op.type())
walker = cls(op,
pruned_params=pruned_params,
visited=visited)
break
else:
cls = PRUNE_WORKER.get(target_op.type())
if cls is None:
_logger.info("No walker for operator: {}".format(target_op.type(
)))
groups.append(pruned_params)
continue
walker = cls(target_op,
pruned_params=pruned_params,
visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[])
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
repeat_group = False
simple_group = []
for param, axis, pruned_idx in group:
param = param.name()
if axis == 0:
if param in visited:
repeat_group = True
else:
visited.add(param)
simple_group.append((param, axis, pruned_idx))
if not repeat_group:
uniq_groups.append(simple_group)
return uniq_groups
......@@ -26,75 +26,80 @@ IDX_SELECTOR = Registry('idx_selector')
@IDX_SELECTOR.register
def default_idx_selector(group, ratio):
"""Get the pruned indexes by given ratio.
def default_idx_selector(group, scores, ratios):
"""Get the pruned indices by scores of master tensor.
This function return a list of parameters' pruned indexes on given axis.
Each element of list is a tuple with format (name, axis, indexes) in which 'name' is parameter's name
and 'axis' is the axis pruning on and `indexes` is indexes to be pruned.
This function returns a list of parameters' pruned indices on a given axis.
Each element of the list is a tuple with format (name, axis, indices),
in which 'name' is the parameter's name, 'axis' is the axis to prune on and
`indices` are the indices to be pruned.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and
`axis` is the axis pruning on and `score` is a np.array storing the importance of strucure
on `axis`. Show as below:
.. code-block: text
[("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so
`[0.7, 0.5, 0.6]` are the importance sores of each output channel in "conv1_weights"
while axis is 0.
group(Group): A group of pruning operations.
scores(dict): The key is name of tensor, the value is a dict with axis as key and scores as value.
ratios(dict): The pruned ratio of each tensor. The key is name of tensor and the value is the pruned ratio.
Returns:
list: pruned indexes
list: pruned indices with format (name, axis, pruned_indices).
"""
name, axis, score, _ = group[
0] # sort channels by the first convolution's score
# sort channels by the master convolution's score
name = group.master["name"]
axis = group.master["axis"]
score = scores[name][axis]
# get the maximum `groups` attribute among the convolutions in the collection
max_groups = 1
for pruning_details in group.all_pruning_details():
groups = pruning_details.op.attr("groups")
if groups is not None and groups > max_groups:
max_groups = groups
if max_groups > 1:
score = score.reshape([max_groups, -1])
group_size = score.shape[1]
# get score for each group of channels
score = np.mean(score, axis=1)
sorted_idx = score.argsort()
ratio = ratios[name]
pruned_num = int(round(len(sorted_idx) * ratio))
pruned_idx = sorted_idx[:pruned_num]
idxs = []
for name, axis, score, transforms in group:
idxs.append((name, axis, pruned_idx, transforms))
return idxs
# convert indices of channel groups to indices of channels.
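# e.g. with max_groups=2 and group_size=3, a pruned group index 1 expands to
# channel indices [3, 4, 5].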
if max_groups > 1:
correct_idx = []
for idx in pruned_idx:
for offset in range(group_size):
correct_idx.append(idx * group_size + offset)
pruned_idx = correct_idx[:]
ret = []
for _pruning_details in group.all_pruning_details():
ret.append((_pruning_details.name, _pruning_details.axis, pruned_idx,
_pruning_details.transform))
return ret
@IDX_SELECTOR.register
def optimal_threshold(group, ratio):
"""Get the pruned indexes by given ratio.
def optimal_threshold(group, scores, ratios):
"""Get the pruned indices by scores of master tensor.
This function return a list of parameters' pruned indexes on given axis.
Each element of list is a tuple with format (name, axis, indexes) in which 'name' is parameter's name
and 'axis' is the axis pruning on and `indexes` is indexes to be pruned.
This function returns a list of parameters' pruned indices on a given axis.
Each element of the list is a tuple with format (name, axis, indices),
in which 'name' is the parameter's name, 'axis' is the axis to prune on and
`indices` are the indices to be pruned.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and
`axis` is the axis pruning on and `score` is a np.array storing the importance of strucure
on `axis`. Show as below:
.. code-block: text
[("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so
`[0.7, 0.5, 0.6]` are the importance sores of each output channel in "conv1_weights"
while axis is 0.
group(Group): A group of pruning operations.
scores(dict): The key is name of tensor, the value is a dict with axis as key and scores as value.
ratios(dict): The pruned ratio of each tensor. The key is name of tensor and the value is the pruned ratio.
Returns:
list: pruned indexes
list: pruned indices with format (name, axis, pruned_indices).
"""
name, axis, score, _ = group[
0] # sort channels by the first convolution's score
# sort channels by the master tensor
name = group.master["name"]
axis = group.master["axis"]
score = scores[name][axis]
ratio = ratios[name]
score[score < 1e-18] = 1e-18
score_sorted = np.sort(score)
......@@ -110,6 +115,7 @@ def optimal_threshold(group, ratio):
pruned_idx = np.squeeze(np.argwhere(score < th))
idxs = []
for name, axis, score, transforms in group:
idxs.append((name, axis, pruned_idx, transforms))
for _pruning_details in group.all_pruning_details():
idxs.append((_pruning_details.name, _pruning_details.axis, pruned_idx,
_pruning_details.transform))
return idxs
......@@ -12,35 +12,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import numpy as np
from ..core import Registry
from ..common import get_logger
__all__ = ["PRUNE_WORKER", "conv2d"]
__all__ = ["PRUNE_WORKER", "conv2d", "UnsupportOpError"]
_logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker')
SKIP_OPS = ["conditional_block"]
SKIPPED_OPS = ['shape', 'reduce_mean']
# Operators in OPS_UNCHANGE_SHAPE will be visited by the default worker,
# which keeps the shape of the output the same as the shape of the input.
OPS_UNCHANGE_SHAPE = os.getenv('OPS_UNCHANGE_SHAPE', None)
OPS_UNCHANGE_SHAPE = [] if OPS_UNCHANGE_SHAPE is None else OPS_UNCHANGE_SHAPE.strip(
).split(",")
OPS_UNCHANGE_SHAPE += [
'nearest_interp_v2',
'roi_align',
'sigmoid',
'swish',
'pad3d',
'bilinear_interp_v2',
'dropout',
'cast',
'hard_swish',
'hard_sigmoid',
]
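# Extra shape-preserving operators can be appended at runtime through the
# environment variable, e.g. (assumed usage):
#   export OPS_UNCHANGE_SHAPE="my_custom_act,my_custom_resize"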
class UnsupportOpError(Exception):
pass
class PruneWorker(object):
def __init__(self, op, pruned_params=[], visited={}):
def __init__(self, op, pruned_params, visited, skip_stranger=True):
"""
A wrapper of operator used to infer the information of all the related variables.
Args:
op(Operator): The operator to be pruned.
pruned_params(list): The list to store the information of pruning that infered by walker.
pruned_params(list): The list to store the pruning information inferred by the worker.
visited(dict): The auxiliary dict to record the visited operators and variables. The key is an encoded string of operator id and variable name.
skip_stranger(bool): Whether to raise an exception when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means all unregistered operators are visited by the default worker. Default: True.
Return: A instance of PruneWalker.
Return: An instance of PruneWorker.
"""
self.op = op
self.pruned_params = pruned_params
self.visited = visited
self.skip_stranger = skip_stranger
self.ops_unsupported = os.getenv('OPS_UNSUPPORTED', None)
self.ops_unsupported = [] if self.ops_unsupported is None else self.ops_unsupported.strip(
).split(",")
def prune(self, var, pruned_axis, pruned_idx):
"""
......@@ -49,7 +77,7 @@ class PruneWorker(object):
Args:
var(Variable): The root variable of searching. It can be the input or output of current operator.
pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable.
pruned_idx(list): The indices to be pruned in `pruned_axis` of the root variable.
"""
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)
......@@ -82,29 +110,36 @@ class PruneWorker(object):
return
if visited is not None:
self.visited = visited
if op.type() in self.ops_unsupported:
raise UnsupportOpError("Unsupported operator named {}".format(
op.type()))
cls = PRUNE_WORKER.get(op.type())
if cls is None:
if op.type() in SKIP_OPS:
_logger.warn("Skip operator [{}]".format(op.type()))
if op.type() in SKIPPED_OPS:
return
if op.type() in OPS_UNCHANGE_SHAPE or not self.skip_stranger:
cls = PRUNE_WORKER.get("default_worker")
else:
raise UnsupportOpError("Unsupported operator named {}".format(
op.type()))
# _logger.warn(
# "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
# format(op.type()))
cls = PRUNE_WORKER.get("default_walker")
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}".
format(self.op, op, pruned_axis, var.name(), pruned_idx))
_logger.debug(
f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
)
walker = cls(op, pruned_params=self.pruned_params, visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx)
worker = cls(op, self.pruned_params, self.visited, self.skip_stranger)
worker.prune(var, pruned_axis, pruned_idx)
def append_pruned_vars(self, var, axis, transforms):
self.pruned_params.append((var, axis, transforms, self.op))
@PRUNE_WORKER.register
class conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(conv2d, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(conv2d, self).__init__(op, pruned_params, visited, skip_stranger)
def _is_depthwise_conv(self, op):
data_format = self.op.attr("data_format")
......@@ -121,15 +156,17 @@ class conv2d(PruneWorker):
num_filters % num_channels == 0)
def _prune(self, var, pruned_axis, pruned_idx):
if self._is_depthwise_conv(self.op):
_logger.debug(f"Meet conv2d who is depthwise conv2d actually.")
walker = depthwise_conv2d(
self.op, self.pruned_params, visited=self.visited)
walker._prune(var, pruned_axis, pruned_idx)
return
worker = depthwise_conv2d(
self.op,
self.pruned_params,
visited=self.visited,
skip_stranger=self.skip_stranger)
return worker._prune(var, pruned_axis, pruned_idx)
data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
......@@ -137,56 +174,49 @@ class conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
self.append_pruned_vars(filter_var, 1, pruned_idx)
if groups is None or groups == 1:
self._visit_and_search(filter_var, 1, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx))
self.append_pruned_vars(var, pruned_axis, pruned_idx)
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
if groups is None or groups == 1 or pruned_axis == 0:
self._visit_and_search(var, pruned_axis, pruned_idx)
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
self.append_pruned_vars(
self.op.inputs("Bias"), channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
self._visit_and_search(output_var, channel_axis, pruned_idx)
elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0]
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
self._visit_and_search(input_var, channel_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx))
self.append_pruned_vars(filter_var, 0, pruned_idx)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
self.append_pruned_vars(
self.op.inputs("Bias")[0], channel_axis, pruned_idx)
@PRUNE_WORKER.register
class conv2d_transpose(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(conv2d_transpose, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(conv2d_transpose, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
data_format = self.op.attr("data_format")
......@@ -198,7 +228,7 @@ class conv2d_transpose(PruneWorker):
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.pruned_params.append((filter_var, 0, pruned_idx))
self.append_pruned_vars(filter_var, 0, pruned_idx)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
......@@ -212,14 +242,14 @@ class conv2d_transpose(PruneWorker):
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1)
self.pruned_params.append((filter_var, 1, pruned_idx))
self.append_pruned_vars(filter_var, 1, pruned_idx)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 1, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
self.append_pruned_vars(
self.op.inputs("Bias")[0], channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
......@@ -229,8 +259,9 @@ class conv2d_transpose(PruneWorker):
@PRUNE_WORKER.register
class batch_norm(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(batch_norm, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(batch_norm, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Y")) and (
......@@ -248,7 +279,7 @@ class batch_norm(PruneWorker):
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx))
self.append_pruned_vars(param_var, 0, pruned_idx)
out_var = self.op.outputs("Y")[0]
self._visit(out_var, pruned_axis)
......@@ -259,13 +290,15 @@ class batch_norm(PruneWorker):
@PRUNE_WORKER.register
class sync_batch_norm(batch_norm):
def __init__(self, op, pruned_params, visited):
super(sync_batch_norm, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(sync_batch_norm, self).__init__(op, pruned_params, visited,
skip_stranger)
class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(elementwise_op, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_op, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
axis = self.op.attr("axis")
......@@ -286,7 +319,7 @@ class elementwise_op(PruneWorker):
# for bias
if name == "Y" and actual_axis >= 0 and not (
len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
self.pruned_params.append((in_var, actual_axis, pruned_idx))
self.append_pruned_vars(in_var, actual_axis, pruned_idx)
self._visit_and_search(in_var, actual_axis, pruned_idx)
else:
......@@ -301,8 +334,7 @@ class elementwise_op(PruneWorker):
if y_pruned_axis >= 0 and not (len(in_var.shape()) == 1 and
in_var.shape()[0] == 1):
self.pruned_params.append(
(in_var, y_pruned_axis, pruned_idx))
self.append_pruned_vars(in_var, y_pruned_axis, pruned_idx)
self._visit_and_search(in_var, y_pruned_axis, pruned_idx)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
......@@ -318,26 +350,30 @@ class elementwise_op(PruneWorker):
@PRUNE_WORKER.register
class elementwise_add(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_add, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_add, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class elementwise_sub(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_sub, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_sub, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class elementwise_mul(elementwise_op):
def __init__(self, op, pruned_params, visited):
super(elementwise_mul, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(elementwise_mul, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class activation(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(activation, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(activation, self).__init__(op, pruned_params, visited,
skip_stranger)
self.input_name = "X"
self.output_name = "Out"
......@@ -351,9 +387,10 @@ class activation(PruneWorker):
@PRUNE_WORKER.register
class default_walker(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(default_walker, self).__init__(op, pruned_params, visited)
class default_worker(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(default_worker, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.all_outputs():
......@@ -367,59 +404,62 @@ class default_walker(PruneWorker):
@PRUNE_WORKER.register
class uniform_random_batch_size_like(activation):
def __init__(self, op, pruned_params, visited):
super(uniform_random_batch_size_like, self).__init__(op, pruned_params,
visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(uniform_random_batch_size_like, self).__init__(
op, pruned_params, visited, skip_stranger)
self.input_name = "Input"
self.output_name = "Out"
@PRUNE_WORKER.register
class bilinear_interp(activation):
def __init__(self, op, pruned_params, visited):
super(bilinear_interp, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(bilinear_interp, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class nearest_interp(activation):
def __init__(self, op, pruned_params, visited):
super(nearest_interp, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(nearest_interp, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class relu(activation):
def __init__(self, op, pruned_params, visited):
super(relu, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(relu, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register
class leaky_relu(activation):
def __init__(self, op, pruned_params, visited):
super(leaky_relu, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(leaky_relu, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class floor(activation):
def __init__(self, op, pruned_params, visited):
super(floor, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(floor, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register
class relu6(activation):
def __init__(self, op, pruned_params, visited):
super(relu6, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(relu6, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register
class pool2d(activation):
def __init__(self, op, pruned_params, visited):
super(pool2d, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(pool2d, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register
class sum(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(sum, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(sum, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
......@@ -440,10 +480,46 @@ class sum(PruneWorker):
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class split(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(split, self).__init__(op, pruned_params, visited, skip_stranger)
self.in_var = op.inputs("X")[0]
self.out_vars = op.outputs("Out")
self.axis = op.attr("axis")
self.num = op.attr("num")
def _prune(self, var, pruned_axis, transforms):
if var == self.in_var:
if pruned_axis != self.axis:
for out_var in self.out_vars:
self._visit_and_search(out_var, pruned_axis, transforms)
else:
                raise UnsupportOpError(
                    "Pruning the input of the split operator along the split axis is not supported.")
elif var in self.out_vars:
if pruned_axis != self.axis:
self._visit_and_search(self.in_var, pruned_axis, transforms)
else:
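                # Build a transform that maps the pruned indices of this output slice
                # into the index space of the split input along the same axis.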
trans = {
"src_start": 0,
"src_end": var.shape()[pruned_axis],
"target_start": 0,
"target_end": self.in_var.shape()[pruned_axis],
"target_len": self.in_var.shape()[pruned_axis]
}
self._visit_and_search(self.in_var, pruned_axis,
transforms + [trans])
for out_var in self.out_vars:
if var != out_var:
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
class concat(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(concat, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(concat, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axis = self.op.attr("axis")
......@@ -513,52 +589,56 @@ class concat(PruneWorker):
@PRUNE_WORKER.register
class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
assert var not in self.op.inputs(
"Filter"), "Unsupport for pruning depthwise conv2d directly."
assert var not in self.op.outputs(
"Output"
), "Unsupport for pruning output of depthwise conv2d directly."
_filter = self.op.inputs("Filter")[0]
_out = self.op.outputs("Output")[0]
_in_var = self.op.inputs("Input")[0]
data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
if var in self.op.inputs("Input"):
if var == _in_var:
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis)
groups = var.shape()[channel_axis]
filter_var = self.op.inputs("Filter")[0]
transform = {
'src_start': 0,
'src_end': var.shape()[pruned_axis],
'target_start': 0,
'target_end': filter_var.shape()[0],
'target_len': filter_var.shape()[0],
'stride': 1
}
self.pruned_params.append((filter_var, 0, transforms + [transform]))
self._visit(filter_var, 0)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, transforms + [transform])
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis,
transforms + [transform])
# pruning number of filters
self.append_pruned_vars(_filter, 0, transforms)
            # The pruned channels are removed by reducing the groups attribute,
            # so axis 1 of the filter is recorded as pruned as well.
self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_filter, 0, transforms)
            # The number of kernels of depthwise conv2d will not be pruned,
            # so there is no need to search the succeeding operators.
            # self._visit_and_search(_filter, 1, transforms)
self._visit(_filter, 1)
self._visit_and_search(_out, channel_axis, transforms)
elif var == _filter:
assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_in_var, channel_axis, transforms)
self._visit_and_search(_out, channel_axis, transforms)
elif var == _out:
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis)
self.append_pruned_vars(_filter, 0, transforms)
self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_filter, 0, transforms)
            # The number of kernels of depthwise conv2d will not be pruned,
            # so there is no need to search the succeeding operators.
            # self._visit_and_search(_filter, 1, transforms)
self._visit(_filter, 1)
self._visit_and_search(_in_var, channel_axis, transforms)
@PRUNE_WORKER.register
class mul(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(mul, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(mul, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
......@@ -570,7 +650,7 @@ class mul(PruneWorker):
for i in pruned_idx:
idx += list(range_idx + i * feature_map_size)
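            # idx now holds the rows of the FC weight that correspond to the pruned
            # channels of X after the feature map is flattened.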
param_var = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, idx))
self.append_pruned_vars(param_var, 0, idx)
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
......@@ -578,22 +658,36 @@ class mul(PruneWorker):
@PRUNE_WORKER.register
class matmul(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(matmul, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(matmul, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X") and pruned_axis == 1:
param_var = self.op.inputs("Y")[0]
self.pruned_params.append((param_var, 0, pruned_idx))
x = self.op.inputs("X")[0]
y = self.op.inputs("Y")[0]
out = self.op.outputs("Out")[0]
if var == x and pruned_axis == 1:
self.append_pruned_vars(y, 0, pruned_idx)
self._visit_and_search(y, 0, pruned_idx)
if var == out:
if pruned_axis == 0:
self.append_pruned_vars(x, 0, pruned_idx)
self._visit_and_search(x, 0, pruned_idx)
elif pruned_axis == 1:
self.append_pruned_vars(y, 1, pruned_idx)
self._visit_and_search(y, 1, pruned_idx)
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
@PRUNE_WORKER.register
class matmul_v2(matmul):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(matmul_v2, self).__init__(op, pruned_params, visited,
skip_stranger)
@PRUNE_WORKER.register
class scale(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(scale, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(scale, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("X"):
......@@ -608,34 +702,34 @@ class scale(PruneWorker):
@PRUNE_WORKER.register
class momentum(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(momentum, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(momentum, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
velocity_var = self.op.inputs("Velocity")[0]
self.pruned_params.append((velocity_var, pruned_axis, pruned_idx))
self.append_pruned_vars(velocity_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class adam(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(adam, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(adam, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.inputs("Param"):
_logger.debug("pruning momentum, var:{}".format(var.name()))
moment1_var = self.op.inputs("Moment1")[0]
self.pruned_params.append((moment1_var, pruned_axis, pruned_idx))
self.append_pruned_vars(moment1_var, pruned_axis, pruned_idx)
moment2_var = self.op.inputs("Moment2")[0]
self.pruned_params.append((moment2_var, pruned_axis, pruned_idx))
self.append_pruned_vars(moment2_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class affine_channel(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(affine_channel, self).__init__(op, pruned_params, visited)
def __init__(self, op, pruned_params, visited, skip_stranger):
super(affine_channel, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
if (var not in self.op.outputs("Out")) and (
......@@ -653,7 +747,7 @@ class affine_channel(PruneWorker):
param_var = self.op.inputs(param)[0]
for op in param_var.outputs():
self._prune_op(op, param_var, 0, pruned_idx)
self.pruned_params.append((param_var, 0, pruned_idx))
self.append_pruned_vars(param_var, 0, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
......@@ -664,11 +758,12 @@ class affine_channel(PruneWorker):
@PRUNE_WORKER.register
class flatten_contiguous_range(PruneWorker):
def __init__(self, op, pruned_params, visited):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(flatten_contiguous_range, self).__init__(op, pruned_params,
visited)
visited, skip_stranger)
def _prune(self, var, pruned_axis, transforms):
start_axis = self.op.attr("start_axis")
stop_axis = self.op.attr("stop_axis")
if var in self.op.inputs("X"):
......@@ -690,3 +785,58 @@ class flatten_contiguous_range(PruneWorker):
for op in next_ops:
self._prune_op(op, out_var, out_pruned_axis,
transforms + [transform])
@PRUNE_WORKER.register
class squeeze2(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(squeeze2, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axes = self.op.attr("axes")
in_var = self.op.inputs("X")[0]
out_var = self.op.outputs("Out")[0]
if axes is None or len(axes) == 0:
axes = [i for i, axis in enumerate(in_var.shape()) if axis == 1]
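        # Squeezed axes disappear from the output, so walking from input to output
        # shifts the pruned axis left past them; walking from output back to input
        # shifts it right.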
squeeze_num = 0
if in_var == var:
for axis in axes:
                assert axis != pruned_axis, "Cannot prune an axis that will be squeezed."
if axis < pruned_axis:
squeeze_num += 1
pruned_axis -= squeeze_num
self._visit_and_search(out_var, pruned_axis, transforms)
elif out_var == var:
for axis in axes:
if axis <= pruned_axis:
pruned_axis += 1
self._visit_and_search(in_var, pruned_axis, transforms)
@PRUNE_WORKER.register
class unsqueeze2(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(unsqueeze2, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axes = self.op.attr("axes")
in_var = self.op.inputs("X")[0]
out_var = self.op.outputs("Out")[0]
assert (axes is not None)
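        # unsqueeze2 inserts the listed axes, so the mapping mirrors squeeze2:
        # walking from input to output shifts the pruned axis right past inserted
        # axes, and from output back to input shifts it left.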
squeeze_num = 0
if in_var == var:
for axis in axes:
if axis <= pruned_axis:
pruned_axis += 1
self._visit_and_search(out_var, pruned_axis, transforms)
elif out_var == var:
for axis in axes:
if axis < pruned_axis:
squeeze_num += 1
pruned_axis -= squeeze_num
self._visit_and_search(in_var, pruned_axis, transforms)
......@@ -18,7 +18,7 @@ import copy
import numpy as np
from functools import reduce
from ..core import VarWrapper, OpWrapper, GraphWrapper
from .group_param import collect_convs
from .collections import StaticPruningCollections
from .criterion import CRITERION
from .idx_selector import IDX_SELECTOR
from ..common import get_logger
......@@ -79,38 +79,28 @@ class Pruner():
Returns:
tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters.
"""
self.pruned_list = []
graph = GraphWrapper(program.clone())
param_backup = {} if param_backup else None
param_shape_backup = {} if param_shape_backup else None
pruned_params = []
visited = {}
for param, ratio in zip(params, ratios):
_logger.info("pruning: {}".format(param))
if graph.var(param) is None:
_logger.warn(
"Variable[{}] to be pruned is not in current graph.".format(
param))
continue
group = collect_convs([param], graph,
visited)[0] # [(name, axis, pruned_idx)]
if group is None or len(group) == 0:
continue
assert (
not self.pruned_weights), "The weights have been pruned once."
group_values = []
for name, axis, pruned_idx in group:
var = scope.find_var(name)
collections = StaticPruningCollections(params, graph)
ratios = dict(zip(params, ratios))
values = {}
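        # Cache the current tensor values of all variables in the pruning
        # collections; they are passed to the criterion below to score channels.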
for _collection in collections:
for _var_name in _collection.variables():
var = scope.find_var(_var_name)
if var is not None:
values = np.array(var.get_tensor())
group_values.append((name, values, axis, pruned_idx))
value = np.array(var.get_tensor())
values[_var_name] = value
scores = self.criterion(group_values,
graph) # [(name, axis, score, pruned_idx)]
g = self._transform(self.idx_selector(scores, ratio))
pruned_params.extend(g)
for _collection in collections:
scores = self.criterion(_collection, values, graph)
idx = self.idx_selector(_collection, scores,
ratios) # name, axis, idx, transform
idx = self._transform(idx)
pruned_params.extend(idx)
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
......@@ -124,32 +114,35 @@ class Pruner():
pruned_idx = np.concatenate(merge_pruned_params[param_name][
pruned_axis])
param = graph.var(param_name)
_groups = 1
if not lazy:
_logger.debug("{}\t{}\t{}\t{}".format(
param.name(), pruned_axis,
param.shape()[pruned_axis], len(pruned_idx)))
origin_shape = copy.deepcopy(param.shape())
if param_shape_backup is not None:
param_shape_backup[param.name()] = origin_shape
new_shape = list(param.shape())
new_shape[pruned_axis] -= len(pruned_idx)
param.set_shape(new_shape)
# update groups of depthwise conv2d
for op in param.outputs():
if op.type() in ["conv2d", "depthwise_conv2d"
] and op.attr("groups") > 1:
assert origin_shape[
1] == 1, "Only support for depthwise when groups > 1."
new_groups = int(
op.attr("groups") * new_shape[pruned_axis] /
origin_shape[pruned_axis])
_logger.debug(
f"change groups of conv({param.name()}) from {op.attr('groups')} to {new_groups}; origin_shape: {origin_shape}; new_shape: {new_shape}"
)
op.set_attr("groups", new_groups)
if not only_graph:
param_t = scope.find_var(param.name()).get_tensor()
# update groups of conv2d
if pruned_axis == 1:
for op in param.outputs():
if op.type() in ["conv2d", "depthwise_conv2d"
] and op.attr("groups") > 1:
_groups = op.attr("groups")
_filter_num = param.shape()[1]
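                            # Hypothetical example: groups=8 with one input channel
                            # per group (depthwise); pruning 4 input channels gives
                            # new_groups = (8 * 1 - 4) / 1 = 4.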
new_groups = int(
(_groups * _filter_num - len(pruned_idx)) /
_filter_num)
_logger.info(
f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};"
)
op.set_attr("groups", new_groups)
if _groups == 1:
origin_shape = copy.deepcopy(param.shape())
if param_shape_backup is not None:
param_shape_backup[param.name()] = origin_shape
new_shape = list(param.shape())
new_shape[pruned_axis] -= len(pruned_idx)
param.set_shape(new_shape)
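            # For grouped conv pruned on axis 1, only the groups attribute changes;
            # the stored filter tensor keeps its shape, so it is not rewritten below.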
if not only_graph and (_groups == 1 or pruned_axis != 1):
_var = scope.find_var(param.name())
if _var is None:
continue
param_t = _var.get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
......@@ -162,40 +155,42 @@ class Pruner():
lazy=lazy)
param_t.set(pruned_param, place)
except IndexError as e:
_logger.error("Pruning {}, but get [{}]".format(
param.name(), e))
_logger.error(
"Pruning {} with shape {} on axis {}, but get [{}]; ".
format(param.name(),
param_t.shape(), pruned_axis, e))
graph.infer_shape()
self.pruned_weights = (not only_graph)
return graph.program, param_backup, param_shape_backup
def _transform(self, group):
def _transform(self, items):
ret = []
for name, axis, pruned_idx, transforms in group:
for name, axis, pruned_idx, transforms in items:
src = pruned_idx
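            # Push the pruned indices through each recorded transform so that they
            # end up addressing the parameter that is actually stored in memory.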
for trans in transforms:
src_start = trans['src_start']
src_end = trans['src_end']
src_len = src_end - src_start
target_start = trans['target_start']
target_end = trans['target_end']
starts = np.array(range(target_start, target_end, src_len))
target = []
for idx in src:
if idx >= src_start and idx < src_end:
idx -= src_start
idx += target_start
if idx < target_end:
target.append(idx)
target.extend(list(idx + starts))
src = target
ret.append((name, axis, src))
return ret
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
"""
Pruning a array by indexes on given axis.
        Pruning an array by indices on a given axis.
Args:
tensor(numpy.array): The target array to be pruned.
pruned_idx(list<int>): The indexes to be pruned.
pruned_idx(list<int>): The indices to be pruned.
pruned_axis(int): The axis of given array to be pruned on.
lazy(bool): True means setting the pruned elements to zero.
False means remove the pruned elements from memory.
......
......@@ -98,7 +98,7 @@ def sensitivity(program,
params=[name],
ratios=[ratio],
place=place,
lazy=True,
lazy=False,
only_graph=False,
param_backup=True)
if eval_args is None:
......@@ -108,7 +108,6 @@ def sensitivity(program,
loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
loss))
sensitivities[name][ratio] = loss
_save_sensitivities(sensitivities, sensitivities_file)
......
......@@ -99,13 +99,74 @@ class TestFilterPruner(unittest.TestCase):
plan = pruner.sensitive_prune(0.01, align=4)
for param in net.parameters():
if param.name in self._param_names:
print(f"name: {param.name}; shape: {param.shape}")
self.assertTrue(param.shape[0] % 4 == 0)
pruner.restore()
class TestPruningGroupConv2d(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(TestPruningGroupConv2d, self).__init__(methodName)
def runTest(self):
with fluid.unique_name.guard():
net = paddle.vision.models.mobilenet_v1()
ratios = {}
for param in net.parameters():
if len(param.shape) == 4:
ratios[param.name] = 0.5
pruners = []
pruner = L1NormFilterPruner(net, [1, 3, 128, 128])
pruners.append(pruner)
pruner = FPGMFilterPruner(net, [1, 3, 128, 128])
pruners.append(pruner)
pruner = L2NormFilterPruner(net, [1, 3, 128, 128])
pruners.append(pruner)
shapes = {}
for pruner in pruners:
plan = pruner.prune_vars(ratios, 0)
for param in net.parameters():
if param.name not in shapes:
shapes[param.name] = param.shape
assert (shapes[param.name] == param.shape)
pruner.restore()
#class TestStrideTransform(unittest.TestCase):
# def __init__(self, methodName='runTest'):
# super(TestStrideTransform, self).__init__(methodName)
#
# def runTest(self):
# with fluid.unique_name.guard():
#
# net = paddle.vision.models.mobilenet_v1()
# ratios = {}
# for param in net.parameters():
# if len(param.shape) == 4:
# ratios[param.name] = 0.5
# pruners = []
# pruner = L1NormFilterPruner(net, [1, 3, 128, 128])
# pruners.append(pruner)
# pruner = FPGMFilterPruner(net, [1, 3, 128, 128])
# pruners.append(pruner)
# pruner = L2NormFilterPruner(net, [1, 3, 128, 128])
# pruners.append(pruner)
#
# shapes = {}
# for pruner in pruners:
# plan = pruner.prune_vars(ratios, 0)
# for param in net.parameters():
# if param.name not in shapes:
# shapes[param.name] = param.shape
# assert(shapes[param.name] == param.shape)
# pruner.restore()
def add_cases(suite):
suite.addTest(TestStatus())
suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"]))
# suite.addTest(TestStatus())
# suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"]))
suite.addTest(TestPruningGroupConv2d())
def load_tests(loader, standard_tests, pattern):
......
......@@ -43,7 +43,7 @@ class TestPrune(unittest.TestCase):
paddle.disable_static()
model = net(pretrained=False)
pruner = L1NormFilterPruner(model, [1, 3, 16, 16])
pruner.prune_vars(ratios, [0])
pruner.prune_vars(ratios, 0)
shapes = {}
for param in model.parameters():
shapes[param.name] = param.shape
......
......@@ -25,7 +25,7 @@ class TestWalker(unittest.TestCase):
net = Net()
x = np.random.uniform(-1, 1, x_shape).astype('float32')
pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0])
pruner.prune_vars({"conv2d_0.w_0": 0.2}, 0)
self.assertTrue(net.linear.weight.shape == [5400, 5])
......
......@@ -8,14 +8,14 @@ from paddleslim.dygraph.prune.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase):
def testAdd(self):
plan = PruningPlan()
mask = PruningMask([0], [0, 1, 1], 0.33)
mask = PruningMask(0, [0, 1, 1], 0.33, None)
plan.add("a", mask)
mask = PruningMask([0], [0, 1, 0], 0.33)
mask = PruningMask(0, [0, 1, 0], 0.33, None)
plan.add("a", mask)
a_mask = plan.masks["a"]
self.assertTrue(len(a_mask) == 1)
self.assertTrue(a_mask[0].mask == [0, 1, 0])
self.assertTrue(a_mask[0].dims == [0])
self.assertTrue(a_mask[0].dims == 0)
if __name__ == '__main__':
......
......@@ -16,7 +16,7 @@ sys.path.append("../")
import unittest
import paddle.fluid as fluid
from layers import conv_bn_layer
from paddleslim.prune import collect_convs
from paddleslim.prune import StaticPruningCollections
from static_case import StaticCase
......@@ -41,12 +41,9 @@ class TestPrune(StaticCase):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
collected_groups = collect_convs(
collections = StaticPruningCollections(
["conv1_weights", "conv2_weights", "conv3_weights", "dummy"],
main_program)
while [] in collected_groups:
collected_groups.remove([])
print(collected_groups)
params = set([
param.name for param in main_program.all_parameters()
......@@ -58,14 +55,13 @@ class TestPrune(StaticCase):
('conv4_weights', 0), ('conv5_weights', 1)],
[('conv3_weights', 0), ('conv4_weights', 1)]]
self.assertTrue(len(collected_groups) == len(expected_groups))
for _collected, _expected in zip(collected_groups, expected_groups):
for _name, _axis, _ in _collected:
self.assertTrue(len(collections._collections) == len(expected_groups))
for _collected, _expected in zip(collections, expected_groups):
for _info in _collected.all_pruning_details():
_name = _info.name
_axis = _info.axis
if _name in params:
self.assertTrue((_name, _axis) in _expected)
for _name, _axis in _expected:
if _name in params:
self.assertTrue((_name, _axis, []) in _collected)
if __name__ == '__main__':
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
sys.path.append("../")
import unittest
import numpy as np
......@@ -22,6 +23,7 @@ from static_case import StaticCase
from layers import conv_bn_layer
import random
from paddleslim.core import GraphWrapper
from paddleslim.prune.prune_worker import *
class TestPrune(StaticCase):
......@@ -35,53 +37,54 @@ class TestPrune(StaticCase):
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu')
conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu')
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6')
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
flag = fluid.layers.fill_constant([1], value=1, dtype='int32')
rand_flag = paddle.randint(2, dtype='int32')
cond = fluid.layers.less_than(x=flag, y=rand_flag)
cond_output = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=False,
name='cond_output')
def cond_block1():
cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1")
fluid.layers.assign(input=cond_conv, output=cond_output)
def cond_block2():
cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1")
cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2")
fluid.layers.assign(input=cond_conv2, output=cond_output)
fluid.layers.cond(cond, cond_block1, cond_block2)
sum3 = fluid.layers.sum([sum2, cond_output])
conv6 = conv_bn_layer(sum3, 8, 3, "conv6")
sub1 = conv6 - sum3
mult = sub1 * sub1
conv7 = conv_bn_layer(
mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False)
floored = fluid.layers.floor(conv7)
scaled = fluid.layers.scale(floored)
concated = fluid.layers.concat([scaled, mult], axis=1)
conv8 = conv_bn_layer(concated, 8, 3, "conv8")
predict = fluid.layers.fc(input=conv8, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
avg_cost = fluid.layers.mean(cost)
adam_optimizer.minimize(avg_cost)
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu')
conv2 = conv_bn_layer(conv1, 8, 3, "conv2", act='leaky_relu')
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3", act='relu6')
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
flag = fluid.layers.fill_constant([1], value=1, dtype='int32')
rand_flag = paddle.randint(2, dtype='int32')
cond = fluid.layers.less_than(x=flag, y=rand_flag)
cond_output = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=False,
name='cond_output')
def cond_block1():
cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1")
fluid.layers.assign(input=cond_conv, output=cond_output)
def cond_block2():
cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1")
cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2")
fluid.layers.assign(input=cond_conv2, output=cond_output)
fluid.layers.cond(cond, cond_block1, cond_block2)
sum3 = fluid.layers.sum([sum2, cond_output])
conv6 = conv_bn_layer(sum3, 8, 3, "conv6")
sub1 = conv6 - sum3
mult = sub1 * sub1
conv7 = conv_bn_layer(
mult, 8, 3, "Depthwise_Conv7", groups=8, use_cudnn=False)
floored = fluid.layers.floor(conv7)
scaled = fluid.layers.scale(floored)
concated = fluid.layers.concat([scaled, mult], axis=1)
conv8 = conv_bn_layer(concated, 8, 3, "conv8")
predict = fluid.layers.fc(input=conv8, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
avg_cost = fluid.layers.mean(cost)
adam_optimizer.minimize(avg_cost)
params = []
for param in main_program.all_parameters():
......@@ -117,5 +120,439 @@ class TestPrune(StaticCase):
fetch_list=[cost.name])
class TestUnsqueeze2(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1", act='relu')
out = paddle.unsqueeze(conv1, axis=[0])
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("unsqueeze2")
out_var = graph.var(out.name)
in_var = graph.var(conv1.name)
op = out_var.inputs()[0]
# pruning out
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(out_var, 2, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {
'conv1_weights': 0,
'conv1_bn_scale': 0,
'conv1_bn_offset': 0,
'conv1_bn_mean': 0,
'conv1_bn_variance': 0
})
# pruning in
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(in_var, 1, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {})
class TestSqueeze2(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1]
out = paddle.squeeze(conv1)
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("squeeze2")
out_var = graph.var(out.name)
in_var = graph.var(conv1.name)
op = out_var.inputs()[0]
# pruning out
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(out_var, 0, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {
'conv1_weights': 0,
'conv1_bn_scale': 0,
'conv1_bn_offset': 0,
'conv1_bn_mean': 0,
'conv1_bn_variance': 0
})
# pruning in
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(in_var, 1, [])
for var, axis, _, _ in pruned_params:
ret[var.name()] = axis
self.assertTrue(ret == {})
class TestUnsupportAndDefault(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 8, 3, "conv1", act='relu') #[1, 8, 1, 1]
# hit default pruning worker
cast1 = paddle.cast(conv1, dtype="int32")
# hit unsupported pruning worker
out = paddle.reshape(cast1, shape=[1, -1])
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("conv2d")
in_var = graph.var("conv1_weights")
op = in_var.outputs()[0]
# pruning input of conv op
pruned_params = []
ret = {}
os.environ['OPS_UNSUPPORTED'] = "reshape2"
worker = cls(op, pruned_params, {}, True)
hit_unsupported_op = False
try:
worker.prune(in_var, 0, [])
except UnsupportOpError as e:
hit_unsupported_op = True
print(e)
self.assertTrue(hit_unsupported_op)
class TestConv2d(StaticCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[1, 3, 16, 16])
conv1 = conv_bn_layer(
input, 6, 3, "conv1", groups=1, bias=True, act='relu')
graph = GraphWrapper(main_program)
cls = PRUNE_WORKER.get("conv2d")
weight_var = graph.var("conv1_weights")
in_var = graph.var("image")
op = in_var.outputs()[0]
out_var = op.outputs("Output")[0]
# pruning weights of conv op
pruned_params = []
ret = {}
worker = cls(op, pruned_params, {}, True)
worker.prune(weight_var, 0, [])
worker.prune(weight_var, 1, [])
for var, axis, _, _ in pruned_params:
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
self.assertTrue(ret == {
'conv1_weights': [0, 1],
'conv1_out.b_0': [0],
'conv1_bn_scale': [0],
'conv1_bn_offset': [0],
'conv1_bn_mean': [0],
'conv1_bn_variance': [0]
})
# pruning out of conv op
pruned_params = []
ret = {}
worker = cls(op, pruned_params, visited={}, skip_stranger=True)
worker.prune(out_var, 1, [])
for var, axis, _, _ in pruned_params:
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
self.assertTrue(ret == {'conv1_weights': [0]})
class TestPruneWorker(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.create_graph()
self.cases = []
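        # Each case is a tuple of (variable to prune, pruned axis,
        # expected {parameter name: [pruned axes]}).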
self.set_cases()
def define_layer(self, input):
pass
def set_cases(self):
pass
def create_graph(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with paddle.fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[8, 8, 16, 16])
self.define_layer(input)
self.graph = GraphWrapper(main_program)
self.in_var = self.graph.var(self.input.name)
self.out_var = self.graph.var(self.output.name)
self.op = self.in_var.outputs()[0]
def check_in_out(self):
cls = PRUNE_WORKER.get(self.op.type())
if cls is None:
cls = PRUNE_WORKER.get("default_worker")
        # Run each case: prune the given variable along the given axis and compare
        # the collected pruned parameters with the expected mapping.
for _var, _axis, _ret in self.cases:
pruned_params = []
ret = {}
worker = cls(self.op, pruned_params, visited={}, skip_stranger=True)
try:
worker.prune(_var, _axis, [])
except UnsupportOpError as e:
print(e)
continue
for var, axis, _, _ in pruned_params:
if var.name() not in ret:
ret[var.name()] = []
ret[var.name()].append(axis)
self.assertTrue(ret == _ret)
class TestConv2dTranspose(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d_transpose(
input, 6, 16, 3, name="conv1", bias_attr=False)
self.output = conv1
return conv1
def set_cases(self):
self.cases.append((self.in_var, 1, {'conv1.w_0': [0]}))
self.cases.append((self.out_var, 1, {'conv1.w_0': [1]}))
def test_prune(self):
self.check_in_out()
class TestElementwiseMul(TestPruneWorker):
def define_layer(self, input):
conv1 = paddle.static.nn.conv2d(
input, 3, 3, name="conv1", bias_attr=False)
conv2 = paddle.static.nn.conv2d(
input, 3, 3, name="conv2", bias_attr=False)
self.input = conv1
out = conv1 * conv2
conv3 = paddle.static.nn.conv2d(
out, 3, 3, name="conv3", bias_attr=False)
self.output = out
def set_cases(self):
self.cases.append((self.in_var, 1, {
'conv2.tmp_0': [1],
'conv2.w_0': [0],
'conv3.w_0': [1]
}))
self.cases.append((self.out_var, 1, {
'conv1.w_0': [0],
'conv2.tmp_0': [1],
'conv2.w_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestActivation(TestPruneWorker):
def __init__(self, methodName="test_prune",
op=paddle.nn.functional.sigmoid):
super(TestActivation, self).__init__(methodName)
self.act = op
def define_layer(self, input):
conv1 = paddle.static.nn.conv2d(
input, 3, 3, name="conv1", bias_attr=False)
self.input = conv1
tmp = self.act(conv1)
self.output = tmp
conv2 = paddle.static.nn.conv2d(
tmp, 3, 3, name="conv2", bias_attr=False)
def set_cases(self):
self.cases.append((self.in_var, 1, {'conv2.w_0': [1]}))
self.cases.append((self.out_var, 1, {
'conv1.w_0': [0],
'conv2.w_0': [1]
}))
def test_prune(self):
self.check_in_out()
suite = unittest.TestSuite()
suite.addTest(TestActivation(op=paddle.fluid.layers.resize_bilinear))
suite.addTest(TestActivation(op=paddle.fluid.layers.resize_nearest))
suite.addTest(TestActivation(op=paddle.floor))
suite.addTest(TestActivation(op=paddle.scale))
suite.addTest(
TestActivation(op=paddle.fluid.layers.nn.uniform_random_batch_size_like))
class TestDepthwiseConv2d(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestDepthwiseConv2d, self).__init__(methodName)
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input,
input.shape[1],
3,
groups=input.shape[1],
name="conv1",
bias_attr=False)
self.output = conv1
def set_cases(self):
weight_var = self.graph.var('conv1.w_0')
self.cases.append((self.in_var, 1, {'conv1.w_0': [0, 1]}))
self.cases.append((self.out_var, 1, {'conv1.w_0': [0, 1]}))
self.cases.append((weight_var, 0, {'conv1.w_0': [1]}))
def test_prune(self):
self.check_in_out()
class TestMul(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestMul, self).__init__(methodName)
def define_layer(self, input):
x = fluid.data(name="x", shape=[1, 4, 3, 3])
y = fluid.data(name="y", shape=[36, 7])
self.input = x
out = paddle.fluid.layers.mul(x, y)
self.output = out
def set_cases(self):
self.cases.append((self.in_var, 1, {'y': [0]}))
def test_prune(self):
self.check_in_out()
class TestMatmul(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestMatmul, self).__init__(methodName)
def define_layer(self, input):
x = fluid.data(name="x", shape=[6, 8])
y = fluid.data(name="y", shape=[8, 7])
self.input = x
out = paddle.matmul(x, y)
self.output = out
def set_cases(self):
self.cases.append((self.in_var, 1, {'y': [0]}))
self.cases.append((self.out_var, 0, {'x': [0]}))
self.cases.append((self.out_var, 1, {'y': [1]}))
def test_prune(self):
self.check_in_out()
class TestSplit(TestPruneWorker):
def define_layer(self, input):
self.input = input
split1 = paddle.split(input, num_or_sections=2, axis=1, name=None)
self.output = split1[0]
def set_cases(self):
self.cases.append((self.in_var, 1, {}))
self.cases.append((self.in_var, 0, {}))
self.cases.append((self.out_var, 1, {}))
self.cases.append((self.out_var, 0, {}))
def test_prune(self):
self.check_in_out()
class TestMomentum(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.output = conv1
out = paddle.mean(conv1)
opt = paddle.optimizer.Momentum()
opt.minimize(out)
def set_cases(self):
weight_var = self.graph.var('conv1.w_0')
self.cases.append((weight_var, 0, {
'conv1.w_0': [0],
'conv1.w_0_velocity_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestAdam(TestPruneWorker):
def define_layer(self, input):
self.input = input
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.output = conv1
out = paddle.mean(conv1)
opt = paddle.optimizer.Adam()
opt.minimize(out)
def set_cases(self):
weight_var = self.graph.var('conv1.w_0')
self.cases.append((weight_var, 0, {
'conv1.w_0': [0],
'conv1.w_0_moment1_0': [0],
'conv1.w_0_moment2_0': [0]
}))
def test_prune(self):
self.check_in_out()
class TestAffineChannel(TestPruneWorker):
def __init__(self, methodName="test_prune"):
super(TestAffineChannel, self).__init__(methodName)
def define_layer(self, input):
conv1 = paddle.static.nn.conv2d(
input, 3, 8, name="conv1", bias_attr=False)
self.input = conv1
scale = fluid.data(name="scale", shape=[conv1.shape[1]])
bias = fluid.data(name="bias", shape=[conv1.shape[1]])
out = paddle.fluid.layers.affine_channel(conv1, scale=scale, bias=bias)
self.output = out
def set_cases(self):
self.cases.append((self.in_var, 1, {'scale': [0], 'bias': [0]}))
self.cases.append((self.out_var, 1, {
'conv1.w_0': [0],
'scale': [0],
'bias': [0]
}))
def test_prune(self):
self.check_in_out()
if __name__ == '__main__':
unittest.main()
......@@ -107,6 +107,7 @@ class TestSensitivity(StaticCase):
sensitivities_file="./sensitivities_file_2",
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
self.assertTrue(params_sens == origin_sens)
self.assertTrue(sens == origin_sens)
......