Unverified commit 05d8846b authored by W whs, committed by GitHub

Refine pruning code for detection (#554)

Parent c83eedc9
......@@ -14,7 +14,7 @@
import paddle
import numpy as np
import paddle.jit as jit
from ..core import GraphWrapper
from ..core import GraphWrapper, dygraph2program
__all__ = ["flops", "dygraph_flops"]
......@@ -83,11 +83,27 @@ def _graph_flops(graph, only_conv=True, detail=False):
return flops
def dygraph_flops(model, input_shape, only_conv=False, detail=False):
def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
"""
Compute the FLOPs of nn.Layer.
Args:
model(nn.Layer): The target model.
inputs(list): The dummy inputs used for 'model.forward'. It can be:
1. list<int>|tuple<int>: 'model.forward' accepts
only one variable as argument and the shape of
the variable is 'inputs'.
2. list<list<int>>: 'model.forward' accepts multiple
variables as arguments and the shapes of the variables are given by 'inputs'.
3. others: 'inputs' will be used as the argument list when calling
'model.forward(*inputs)'.
dtypes(str|list<str>): Only used when 'inputs' is a shape or a list of shapes; it gives the
data type of each input. None means all the inputs are 'float32'.
Default: None.
only_conv(bool): If True, only count the multiply-adds in convolution and FC layers.
Default: False.
detail(bool): Whether to return detail of each convolution layer.
"""
data = np.ones(tuple(input_shape)).astype("float32")
in_var = paddle.to_tensor(data)
_, traced = paddle.jit.TracedLayer.trace(model, [in_var])
program = traced.program
program = dygraph2program(model, inputs)
graph = GraphWrapper(program)
return _graph_flops(graph, only_conv=only_conv, detail=detail)
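
As a usage sketch of the updated interface (the model, shapes, and dtypes below are illustrative, not part of the diff), the three accepted forms of 'inputs' can be exercised like this:

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.analysis import dygraph_flops

model = mobilenet_v1()

# 1. A single shape: 'model.forward' receives one float32 tensor of this shape.
flops_shape = dygraph_flops(model, [1, 3, 224, 224])

# 2. A list of shapes (one per forward argument), optionally with explicit dtypes.
flops_shapes = dygraph_flops(model, [[1, 3, 224, 224]], dtypes=["float32"])

# 3. Ready-made tensors, passed through as 'model.forward(*inputs)'.
x = paddle.rand([1, 3, 224, 224])
flops_tensors = dygraph_flops(model, [x])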
......@@ -12,7 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper
from .registry import Registry
from ..core import graph_wrapper
from .graph_wrapper import *
from ..core import registry
from .registry import *
from ..core import dygraph
from .dygraph import *
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry']
__all__ = []
__all__ += graph_wrapper.__all__
__all__ += registry.__all__
__all__ += dygraph.__all__
import paddle
import collections
import logging
import numpy as np
from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard
from paddle.fluid.dygraph.base import program_desc_tracing_guard
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from ..common import get_logger
__all__ = ["dygraph2program"]
_logger = get_logger(__name__, level=logging.INFO)
def _is_shape(values):
if not isinstance(values, (list, tuple)):
return False
for v in values:
if not isinstance(v, int):
return False
return True
def _is_shapes(values):
if not isinstance(values, (list, tuple)):
return False
for v in values:
if not _is_shape(v):
return False
return True
def _create_tensors(shapes, dtypes=None):
if dtypes is not None:
assert len(shapes) == len(
dtypes
), "Length of shapes and dtypes must be same. But get len(shapes): {}; len(dtypes): {}; shapes: {}; dtypes: {}".format(
len(shapes), len(dtypes), shapes, dtypes)
else:
dtypes = len(shapes) * ['float32']
tensors = []
for shape, dtype in zip(shapes, dtypes):
data = np.ones(tuple(shape)).astype(dtype)
tensors.append(paddle.to_tensor(data))
return tensors
def extract_vars(inputs):
"""
Extract a list of variables from inputs.
Args:
inputs(Variable | list<Object> | dict):
"""
vars = []
if isinstance(inputs, Variable):
vars = [inputs]
elif isinstance(inputs, dict):
for _key, _value in inputs.items():
if isinstance(_value, Variable):
vars.append(_value)
else:
_logger.warn(
f"A Variable is expected, but got an element of type {type(_value)} from the dict 'inputs'. The key of this element is {_key}."
)
elif isinstance(inputs, (tuple, list)):
for _value in inputs:
vars.extend(extract_vars(_value))
if len(vars) == 0:
_logger.warn("No variables were extracted from 'inputs'.")
return vars
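
A small sketch of what 'extract_vars' returns for nested inputs; the import path assumes this helper lives in paddleslim/core/dygraph.py as introduced by this commit, and the shapes are arbitrary:

import paddle
from paddleslim.core.dygraph import extract_vars

x = paddle.rand([1, 3, 8, 8])
y = paddle.rand([1, 3, 8, 8])

# Dict values that are not variables are skipped with a warning.
vars_from_dict = extract_vars({"x": x, "note": "ignored"})   # -> [x]
# Nested lists and tuples are flattened recursively.
vars_from_list = extract_vars([x, (y,)])                      # -> [x, y]
assert len(vars_from_dict) == 1 and len(vars_from_list) == 2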
def to_variables(inputs):
"""
Recursively find variables and np.ndarray values in 'inputs' and convert them to variables.
"""
if isinstance(inputs, Variable) or isinstance(inputs, np.ndarray):
return paddle.fluid.dygraph.to_variable(inputs)
elif isinstance(inputs, dict):
ret = {}
for _key in inputs:
ret[_key] = to_variables(inputs[_key])
return ret
elif isinstance(inputs, list):
ret = []
for _value in inputs:
ret.append(to_variables(_value))
return ret
@dygraph_only
def dygraph2program(layer,
inputs,
feed_prefix='feed_',
fetch_prefix='fetch_',
tmp_prefix='t_',
extract_inputs_fn=None,
extract_outputs_fn=None,
dtypes=None):
assert isinstance(layer, Layer)
extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
tracer = _dygraph_tracer()._get_program_desc_tracer()
with program_desc_tracing_guard(True):
if _is_shape(inputs):
shapes = [inputs]
inputs = _create_tensors(shapes, dtypes=dtypes)
input_var_list = inputs
elif _is_shapes(inputs):
inputs = _create_tensors(inputs, dtypes=dtypes)
input_var_list = inputs
else:
inputs = to_variables(inputs)
input_var_list = extract_inputs_fn(inputs)
original_outputs = layer(*inputs)
# 'original_outputs' may be a dict, so we convert it to a list of variables.
# Note that no new variables should be created in 'extract_vars'.
out_var_list = extract_outputs_fn(original_outputs)
program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
input_var_list, feed_prefix, out_var_list, fetch_prefix, tmp_prefix)
tracer.reset()
with _dygraph_guard(None):
program = Program()
program.desc = program_desc
program.blocks = [Block(program, 0)]
program._sync_with_cpp()
return program
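
A minimal sketch of converting a dygraph Layer to a static Program with this helper and inspecting it through GraphWrapper; the layer and shape are illustrative:

import paddle
from paddleslim.core import GraphWrapper, dygraph2program

net = paddle.nn.Sequential(
    paddle.nn.Conv2D(3, 8, 3),
    paddle.nn.ReLU())

# 'inputs' may be a shape, a list of shapes, or real tensors (see dygraph_flops above).
program = dygraph2program(net, inputs=[1, 3, 32, 32])
graph = GraphWrapper(program)
print([op.type() for op in graph.ops()])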
......@@ -397,9 +397,3 @@ class GraphWrapper(object):
# Infer the remain ops in topological order.
for op in head_op:
recursive_infer(op, infer=True)
def update_groups_of_conv(self):
for op in self.ops():
if 'conv2d' in op.type() and op.attr('groups') >= op.inputs(
'Filter')[0].shape()[0]:
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
......@@ -47,24 +47,34 @@ class FilterPruner(Pruner):
Args:
model(paddle.nn.Layer): The target model to be pruned.
input_shape(list<int>): The input shape of model. It is used to trace the graph of the model.
inputs(Variable|list|dict): The dummy inputs of the model. They will be used when calling 'model.forward(inputs)'.
sen_file(str, optional): The absolute path of a file that stores precomputed sensitivities. If it is
set correctly, the cached sensitivities are used and 'FilterPruner::sensitive' does not
need to be called again. Default: None.
"""
def __init__(self, model, input_shape, sen_file=None):
super(FilterPruner, self).__init__(model, input_shape)
def __init__(self, model, inputs, sen_file=None):
super(FilterPruner, self).__init__(model, inputs)
self._status = Status(sen_file)
# sensitive and var_group are just used in filter pruning
self.var_group = VarGroup(model, input_shape)
self.var_group = VarGroup(model, inputs)
# skip vars in:
# 1. depthwise conv2d layer
self.skip_vars = []
for sub_layer in model.sublayers():
if isinstance(
sub_layer,
paddle.nn.layer.conv.Conv2D) and sub_layer._groups > 1:
for param in sub_layer.parameters():
self.skip_vars.append(param.name)
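
As a hedged usage sketch of the new 'inputs' argument, using the L1NormFilterPruner subclass updated later in this commit (the import path and shape are assumptions, not taken from the diff):

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.dygraph import L1NormFilterPruner

model = mobilenet_v1()
# 'inputs' replaces the old 'input_shape' argument; a plain shape is still accepted.
pruner = L1NormFilterPruner(model, inputs=[1, 3, 224, 224])
# Parameters of depthwise conv2d layers (groups > 1) are collected into
# 'skip_vars' and will not be pruned directly.
print(pruner.skip_vars)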
def sensitive(self,
eval_func=None,
sen_file=None,
target_vars=None,
skip_vars=None):
skip_vars=[]):
"""
Compute or get the sensitivities of the model in the current pruner. Cached sensitivities are returned when all the arguments are "None".
......@@ -88,7 +98,7 @@ class FilterPruner(Pruner):
eval_func(function, optional): The function used to evaluate the model in the current pruner. It should take no arguments and return a score of type "float32". Default: None.
sen_file(str, optional): The absolute path of file to save sensitivities into local filesystem. Default: None.
target_vars(list, optional): The names of tensors whose sensitivities will be computed. "None" means the sensitivities of all weights in convolution layers will be computed. Default: None.
skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. "None" means skip nothing. Default: None.
skip_vars(list, optional): The names of tensors whose sensitivity won't be computed. Default: [].
Returns:
dict: A dict storing sensitivities.
......@@ -102,6 +112,7 @@ class FilterPruner(Pruner):
if not self._status.is_ckp:
return self._status
skip_vars.extend(self.skip_vars)
self._cal_sensitive(
self.model,
eval_func,
......@@ -186,9 +197,9 @@ class FilterPruner(Pruner):
Returns:
tuple: A tuple with format ``(ratios, pruned_flops)`` . "ratios" is a dict whose key is name of tensor and value is ratio to be pruned. "pruned_flops" is the ratio of total pruned FLOPs in the model.
"""
base_flops = flops(self.model, self.input_shape)
base_flops = flops(self.model, self.inputs)
_logger.info("Base FLOPs: {}".format(base_flops))
_logger.debug("Base FLOPs: {}".format(base_flops))
low = 0.
up = 1.0
history = set()
......@@ -200,8 +211,7 @@ class FilterPruner(Pruner):
if align is not None:
ratios = self._round_to(ratios, dims=dims, factor=align)
plan = self.prune_vars(ratios, axis=dims)
_logger.debug("pruning plan: {}".format(plan))
c_flops = flops(self.model, self.input_shape)
c_flops = flops(self.model, self.inputs)
_logger.debug("FLOPs after pruning: {}".format(c_flops))
c_pruned_flops = (base_flops - c_flops) / base_flops
plan.restore(self.model)
......@@ -304,7 +314,11 @@ class FilterPruner(Pruner):
plan: An instance of PruningPlan that can be applied on model by calling 'plan.apply(model)'.
"""
if var_name in self.skip_vars:
_logger.warn(
f"{var_name} is skipped because it is not supported to be pruned directly."
)
return
if isinstance(pruned_dims, int):
pruned_dims = [pruned_dims]
group = self.var_group.find_group(var_name, pruned_dims)
......@@ -315,29 +329,52 @@ class FilterPruner(Pruner):
for param in sub_layer.parameters(include_sublayers=False):
if param.name in group:
group_dict[param.name] = group[param.name]
group_dict[param.name].update({
'layer': sub_layer,
'var': param,
'value': np.array(param.value().get_tensor())
})
# Variables can be pruned on multiple axes.
for _item in group_dict[param.name]:
_item.update({
'layer': sub_layer,
'var': param,
'value': np.array(param.value().get_tensor())
})
_logger.debug(f"set value of {param.name} into group")
mask = self.cal_mask(var_name, pruned_ratio, group_dict)
for _name in group_dict:
dims = group_dict[_name]['pruned_dims']
stride = group_dict[_name]['stride']
var_shape = group_dict[_name]['var'].shape
if isinstance(dims, int):
dims = [dims]
current_mask = mask.repeat(stride[0]) if stride[0] > 1 else mask
assert len(current_mask) == var_shape[dims[
0]], "The length of current_mask must be equal to the size of dimension to be pruned on."
plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
# Variables can be pruned on multiple axes.
for _item in group_dict[_name]:
dims = _item['pruned_dims']
transforms = _item['transforms']
var_shape = _item['var'].shape
if isinstance(dims, int):
dims = [dims]
for trans in transforms:
mask = self._transform_mask(mask, trans)
current_mask = mask
assert len(current_mask) == var_shape[dims[
0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
if apply == "lazy":
plan.apply(self.model, lazy=True)
elif apply == "impretive":
plan.apply(self.model, lazy=False)
return plan
def _transform_mask(self, mask, transform):
src_start = transform['src_start']
src_end = transform['src_end']
target_start = transform['target_start']
target_end = transform['target_end']
target_len = transform['target_len']
stride = transform['stride']
mask = mask[src_start:src_end]
mask = mask.repeat(stride) if stride > 1 else mask
dst_mask = np.ones([target_len])
# for depthwise conv2d with:
# input shape: (1, 4, 32, 32)
# filter shape: (32, 1, 3, 3)
# groups: 4
# if we prune the input channels by 50% (from 4 to 2), the output channels should also be pruned by 50% (from 32 to 16), because each input channel corresponds to 32 / 4 = 8 filters.
expand = int((target_end - target_start) / len(mask))
dst_mask[target_start:target_end] = list(mask) * expand
return dst_mask
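
A hand-worked sketch of the transform mechanics for the depthwise case described in the comment above; the values are illustrative and simply follow the code path of '_transform_mask':

import numpy as np

mask = np.array([1, 1, 0, 0])  # keep input channels 0 and 1, prune 2 and 3
transform = {
    'src_start': 0, 'src_end': 4,          # slice of the source mask to use
    'target_start': 0, 'target_end': 32,   # region of the filter mask to fill
    'target_len': 32,                      # number of filters (output channels)
    'stride': 1,
}

src = mask[transform['src_start']:transform['src_end']]
dst = np.ones([transform['target_len']])
expand = int((transform['target_end'] - transform['target_start']) / len(src))
dst[transform['target_start']:transform['target_end']] = list(src) * expand
# 'dst' tiles the 4-element channel mask 8 times across the 32 filters,
# which is exactly the result '_transform_mask' produces for this transform.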
......@@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO)
class FPGMFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None):
super(FPGMFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file)
def __init__(self, model, inputs, sen_file=None):
super(FPGMFilterPruner, self).__init__(model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value']
pruned_dims = group[var_name]['pruned_dims']
for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
assert (pruned_dims == [0])
dist_sum_list = []
......
......@@ -12,13 +12,15 @@ _logger = get_logger(__name__, logging.INFO)
class L1NormFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None):
def __init__(self, model, inputs, sen_file=None):
super(L1NormFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file)
model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value']
pruned_dims = group[var_name]['pruned_dims']
for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims
]
......
......@@ -12,13 +12,16 @@ _logger = get_logger(__name__, logging.INFO)
class L2NormFilterPruner(FilterPruner):
def __init__(self, model, input_shape, sen_file=None):
def __init__(self, model, inputs, sen_file=None):
super(L2NormFilterPruner, self).__init__(
model, input_shape, sen_file=sen_file)
model, inputs, sen_file=sen_file)
def cal_mask(self, var_name, pruned_ratio, group):
value = group[var_name]['value']
pruned_dims = group[var_name]['pruned_dims']
# find information of pruning on output channels
for _item in group[var_name]:
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
reduce_dims = [
i for i in range(len(value.shape)) if i not in pruned_dims
]
......
......@@ -19,9 +19,9 @@ class Pruner(object):
"""
def __init__(self, model, input_shape):
def __init__(self, model, inputs):
self.model = model
self.input_shape = input_shape
self.inputs = inputs
self._var_shapes = {}
for var in model.parameters():
self._var_shapes[var.name] = var.shape
......@@ -53,5 +53,5 @@ class Pruner(object):
global_plan.apply(self.model, lazy=True)
elif apply == "impretive":
global_plan.apply(self.model, lazy=False)
self.plan = global_plan
return global_plan
......@@ -28,7 +28,7 @@ class PruningMask():
if self._mask is not None:
assert len(self._mask.shape) == len(
value
), "The length of value must be same with shape of mask in current PruningMask instance."
), "The length of 'value' must be the same as the length of the mask's shape in the current PruningMask instance."
self._dims = list(value)
@property
......@@ -37,11 +37,6 @@ class PruningMask():
@mask.setter
def mask(self, value):
assert (isinstance(value, PruningMask))
if self._dims is not None:
assert len(self._mask.shape) == len(
value
), "The length of value must be same with shape of mask in current PruningMask instance."
self._mask = value
def __str__(self):
......@@ -71,13 +66,21 @@ class PruningPlan():
self._pruned_flops = value
def add(self, var_name, pruning_mask):
assert (isinstance(pruning_mask, PruningMask))
if var_name not in self._masks:
self._masks[var_name] = []
self._masks[var_name].append(pruning_mask)
if var_name not in self._dims:
self._dims[var_name] = []
self._dims[var_name].append(pruning_mask.dims)
if pruning_mask.dims in self._dims[var_name]:
for _mask in self._masks[var_name]:
if pruning_mask.dims == _mask.dims:
_mask.mask = list(
np.array(_mask.mask) | np.array(pruning_mask.mask))
else:
self._masks[var_name].append(pruning_mask)
self._dims[var_name].append(pruning_mask.dims)
@property
def masks(self):
......@@ -87,8 +90,7 @@ class PruningPlan():
assert (isinstance(plan, PruningPlan))
for var_name in plan.masks:
for mask in plan.masks[var_name]:
if not self.contains(var_name, mask.dims):
self.add(var_name, mask)
self.add(var_name, mask)
def contains(self, var_name, dims=None):
return (var_name in self._dims) and (dims is None or
......@@ -172,7 +174,6 @@ class PruningPlan():
bool_mask = mask.astype(bool)
pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims[0], value)
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
......@@ -184,14 +185,19 @@ class PruningPlan():
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_value, place)
if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D):
if sub_layer._groups > 1 and pruned_value.shape[
1] == 1: # depthwise conv2d
_logger.debug(
"Update groups of depthwise conv2d from {} to {}".
format(sub_layer._groups,
pruned_value.shape[0]))
sub_layer._groups = pruned_value.shape[0]
if isinstance(
sub_layer, paddle.nn.layer.conv.Conv2D
) and sub_layer._groups > 1 and len(param.shape) == 4:
assert param.shape[
1] == 1, "Only depthwise conv2d is supported when groups > 1."
new_groups = int(bool_mask.sum() *
sub_layer._groups / len(bool_mask))
_logger.debug(
"Update groups of depthwise conv2d from {} to {}".
format(sub_layer._groups, new_groups))
sub_layer._origin_groups = sub_layer._groups
sub_layer._groups = new_groups
# for training
if param.trainable:
param.clear_gradient()
......@@ -218,11 +224,6 @@ class PruningPlan():
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(np.array(t_backup).astype("float32"), place)
if isinstance(sub_layer, paddle.nn.layer.conv.Conv2D):
if sub_layer._groups > 1:
_logger.debug(
"Update groups of conv from {} to {}".format(
sub_layer._groups, t_value.shape()[0]))
sub_layer._groups = t_value.shape()[0]
if "_origin_groups" in sub_layer.__dict__:
sub_layer._groups = sub_layer._origin_groups
del sub_layer._buffers[backup_name]
......@@ -2,7 +2,7 @@ import numpy as np
import logging
import paddle
from paddle.fluid.dygraph import TracedLayer
from ..core import GraphWrapper
from ..core import GraphWrapper, dygraph2program
from ..prune import collect_convs
from ..common import get_logger
......@@ -12,33 +12,43 @@ _logger = get_logger(__name__, level=logging.INFO)
class VarGroup():
def __init__(self, model, input_shape):
"""
A tool used to parse a dygraph model and store the relationships between its variables.
Args:
- model(nn.Layer): The dygraph model to be parsed.
- inputs(Variable|list|dict): The dummy inputs of the target model. They will be used when calling `model.forward(inputs)`.
"""
def __init__(self, model, inputs):
self.groups = []
self._parse_model(model, input_shape)
self._parse_model(model, inputs)
def _to_dict(self, group):
ret = {}
for _name, _axis, _stride in group:
for _name, _axis, _transforms in group:
if isinstance(_axis, int):
_axis = [_axis] # TODO: fix
ret[_name] = {'pruned_dims': _axis, 'stride': _stride}
_axis = [_axis]
if _name not in ret:
ret[_name] = []
# Variables can be pruned on multiple axes.
ret[_name].append({'pruned_dims': _axis, 'transforms': _transforms})
return ret
def find_group(self, var_name, axis):
for group in self.groups:
for _name, _axis, _stride in group:
if isinstance(_axis, int):
_axis = [_axis] # TODO: fix
_axis = [_axis]
if _name == var_name and _axis == axis:
return self._to_dict(group)
def _parse_model(self, model, input_shape):
_logger.debug("Parsing model with input: {}".format(input_shape))
data = np.ones(tuple(input_shape)).astype("float32")
in_var = paddle.to_tensor(data)
def _parse_model(self, model, inputs):
_logger.debug("Parsing model with input: {}".format(inputs))
model.eval()
out_dygraph, static_layer = TracedLayer.trace(model, inputs=[in_var])
graph = GraphWrapper(static_layer.program)
program = dygraph2program(model, inputs=inputs)
graph = GraphWrapper(program)
visited = {}
for name, param in model.named_parameters():
......
......@@ -79,7 +79,7 @@ def collect_convs(params, graph, visited={}):
pruned_params=pruned_params,
visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0])
walker.prune(param, pruned_axis=0, pruned_idx=[])
groups.append(pruned_params)
visited = set()
uniq_groups = []
......@@ -96,5 +96,4 @@ def collect_convs(params, graph, visited={}):
simple_group.append((param, axis, pruned_idx))
if not repeat_group:
uniq_groups.append(simple_group)
return uniq_groups
......@@ -59,11 +59,9 @@ def default_idx_selector(group, ratio):
pruned_num = int(round(len(sorted_idx) * ratio))
pruned_idx = sorted_idx[:pruned_num]
idxs = []
for name, axis, score, offsets in group:
r_idx = [i + offsets[0] for i in pruned_idx]
idxs.append((name, axis, r_idx))
for name, axis, score, transforms in group:
idxs.append((name, axis, pruned_idx, transforms))
return idxs
......@@ -112,6 +110,6 @@ def optimal_threshold(group, ratio):
pruned_idx = np.squeeze(np.argwhere(score < th))
idxs = []
for name, axis, score, _ in group:
idxs.append((name, axis, pruned_idx))
for name, axis, score, transforms in group:
idxs.append((name, axis, pruned_idx, transforms))
return idxs
......@@ -10,6 +10,7 @@ __all__ = ["save_model", "load_model"]
_logger = get_logger(__name__, level=logging.INFO)
_SHAPES_FILE = "__shapes__"
_GROUPS_FILE = "__groups__"
def save_model(exe, graph, dirname):
......@@ -39,6 +40,17 @@ def save_model(exe, graph, dirname):
json.dump(shapes, f)
_logger.info("Save shapes of weights into {}".format(SHAPES_FILE))
groups = {}
for op in graph.ops():
if 'conv2d' in op.type():
filter_name = op.inputs('Filter')[0].name()
groups[filter_name] = op.attr('groups')
GROUPS_FILE = os.path.join(dirname, _GROUPS_FILE)
with open(GROUPS_FILE, "w") as f:
json.dump(groups, f)
_logger.info("Save groups of conv2d into {}".format(GROUPS_FILE))
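
The groups file written above is plain JSON mapping each conv2d filter name to its 'groups' attribute; a sketch of reading it back (the path and names are illustrative):

import json
import os

dirname = "./pruned_model"  # wherever save_model wrote the checkpoint
with open(os.path.join(dirname, "__groups__"), "r") as f:
    groups = json.load(f)
# e.g. {"conv1_weights": 1, "depthwise_conv_weights": 32}  (names illustrative)
print(groups)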
def load_model(exe, graph, dirname):
"""
......@@ -53,7 +65,6 @@ def load_model(exe, graph, dirname):
paddle.static.Program) else graph
SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
with open(SHAPES_FILE, "r") as f:
shapes = json.load(f)
for param_name, shape in shapes.items():
......@@ -62,9 +73,17 @@ def load_model(exe, graph, dirname):
param.set_shape(shape)
else:
_logger.info('{} is not loaded'.format(param_name))
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
GROUPS_FILE = os.path.join(dirname, _GROUPS_FILE)
with open(GROUPS_FILE, "r") as f:
groups = json.load(f)
for op in graph.ops():
if 'conv2d' in op.type():
filter_name = op.inputs('Filter')[0].name()
op.set_attr('groups', groups[filter_name])
_logger.info("Load groups of conv2d from {}".format(GROUPS_FILE))
paddle.static.load(program=graph.program, model_path=dirname, executor=exe)
graph.update_groups_of_conv()
graph.infer_shape()
_logger.info("Load weights from {}".format(dirname))
......@@ -65,6 +65,15 @@ class PruneWorker(object):
self.visited[pruned_axis][key] = True
return True
def _visit_and_search(self, var, axis, transforms):
self._visit(var, axis)
pre_ops = var.inputs()
for op in pre_ops:
self._prune_op(op, var, axis, transforms)
next_ops = var.outputs()
for op in next_ops:
self._prune_op(op, var, axis, transforms)
def _prune(self, var, pruned_axis, pruned_idx):
raise NotImplementedError('Abstract method.')
......@@ -85,6 +94,9 @@ class PruneWorker(object):
cls = PRUNE_WORKER.get("default_walker")
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
_logger.debug(
f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
)
walker = cls(op, pruned_params=self.pruned_params, visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx)
......@@ -170,11 +182,6 @@ class conv2d(PruneWorker):
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
@PRUNE_WORKER.register
class conv2d_transpose(PruneWorker):
......@@ -250,6 +257,12 @@ class batch_norm(PruneWorker):
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class sync_batch_norm(batch_norm):
def __init__(self, op, pruned_params, visited):
super(sync_batch_norm, self).__init__(op, pruned_params, visited)
class elementwise_op(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(elementwise_op, self).__init__(op, pruned_params, visited)
......@@ -269,9 +282,12 @@ class elementwise_op(PruneWorker):
in_var = self.op.inputs(name)[0]
if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
continue
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, actual_axis, pruned_idx)
# for bias
if name == "Y" and actual_axis >= 0 and not (
len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
self.pruned_params.append((in_var, actual_axis, pruned_idx))
self._visit_and_search(in_var, actual_axis, pruned_idx)
else:
if var in self.op.inputs("X"):
......@@ -287,24 +303,17 @@ class elementwise_op(PruneWorker):
in_var.shape()[0] == 1):
self.pruned_params.append(
(in_var, y_pruned_axis, pruned_idx))
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, y_pruned_axis, pruned_idx)
self._visit_and_search(in_var, y_pruned_axis, pruned_idx)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
if len(in_var.shape()) != len(var.shape()):
assert (len(var.shape()) < len(in_var.shape()))
pruned_axis = pruned_axis + axis
if pruned_axis <= len(in_var.shape()):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -447,53 +456,70 @@ class concat(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(concat, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
idx = []
def _prune(self, var, pruned_axis, transforms):
axis = self.op.attr("axis")
if var in self.op.outputs("Out"):
self._visit(var, pruned_axis)
start = 0
if axis == pruned_axis:
for _, in_var in enumerate(self.op.inputs("X")):
idx = []
for i in pruned_idx:
r_idx = i - start
if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0:
idx.append(r_idx)
transform = {
'src_start': start,
'src_end': start + in_var.shape()[pruned_axis],
'target_start': 0,
'target_end': in_var.shape()[pruned_axis],
'target_len': in_var.shape()[pruned_axis],
'stride': 1
}
start += in_var.shape()[pruned_axis]
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, idx)
idx = pruned_idx[:]
self._prune_op(op, in_var, pruned_axis,
transforms + [transform])
else:
for _, in_var in enumerate(self.op.inputs("X")):
self._visit(in_var, pruned_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._prune_op(op, in_var, pruned_axis, transforms)
elif var in self.op.inputs("X"):
self._visit(var, pruned_axis)
if axis == pruned_axis:
idx = []
start = 0
target_start = 0
for v in self.op.inputs("X"):
if v.name() == var.name():
idx = [i + start for i in pruned_idx]
if v.name() != var.name():
target_start += v.shape()[pruned_axis]
else:
start += v.shape()[pruned_axis]
break
target_end = target_start + v.shape()[pruned_axis]
out_var = self.op.outputs("Out")[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, idx, visited={})
else:
for v in self.op.inputs("X"):
for op in v.inputs():
self._prune_op(op, v, pruned_axis, pruned_idx)
out_var = self.op.outputs("Out")[0]
transform = {
'src_start': 0,
'src_end': var.shape()[pruned_axis],
'target_start': target_start,
'target_end': target_end,
'target_len': out_var.shape()[pruned_axis],
'stride': 1
}
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
# The output of concat can be visited repeatedly
c_visited = {}
self._prune_op(
op,
out_var,
pruned_axis,
transforms + [transform],
visited=c_visited)
# Add nodes searched from concat into global visited array.
self.visited.update(c_visited)
@PRUNE_WORKER.register
......@@ -501,8 +527,14 @@ class depthwise_conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
assert var not in self.op.inputs(
"Filter"), "Pruning the filter of depthwise conv2d directly is not supported."
assert var not in self.op.outputs(
"Output"
), "Pruning the output of depthwise conv2d directly is not supported."
data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
channel_axis = 1
if data_format == "NHWC":
channel_axis = 3
......@@ -510,60 +542,28 @@ class depthwise_conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The input of depthwise conv2d can only be pruned at the channel axis, but got {}".format(
pruned_axis)
groups = var.shape()[channel_axis]
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
transform = {
'src_start': 0,
'src_end': var.shape()[pruned_axis],
'target_start': 0,
'target_end': filter_var.shape()[0],
'target_len': filter_var.shape()[0],
'stride': 1
}
self.pruned_params.append((filter_var, 0, transforms + [transform]))
self._visit(filter_var, 0)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
self._prune_op(op, filter_var, 0, transforms + [transform])
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0]
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
self.pruned_params.append((var, 0, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, 0, pruned_idx)
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis
filter_var = self.op.inputs("Filter")[0]
self.pruned_params.append((filter_var, 0, pruned_idx))
self._visit(filter_var, 0)
for op in filter_var.outputs():
self._prune_op(op, filter_var, 0, pruned_idx)
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias")[0], channel_axis, pruned_idx))
in_var = self.op.inputs("Input")[0]
self._visit(in_var, channel_axis)
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, channel_axis, pruned_idx)
output_var = self.op.outputs("Output")[0]
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
self._prune_op(op, output_var, channel_axis,
transforms + [transform])
@PRUNE_WORKER.register
......@@ -679,7 +679,7 @@ class flatten_contiguous_range(PruneWorker):
super(flatten_contiguous_range, self).__init__(op, pruned_params,
visited)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
start_axis = self.op.attr("start_axis")
stop_axis = self.op.attr("stop_axis")
if var in self.op.inputs("X"):
......@@ -687,7 +687,6 @@ class flatten_contiguous_range(PruneWorker):
in_var = self.op.inputs("X")[0]
stride = 1
out_pruned_axis = pruned_axis
out_pruned_idx = pruned_idx
if pruned_axis >= start_axis and pruned_axis <= stop_axis:
out_pruned_axis = start_axis
for i in range(pruned_axis + 1, stop_axis + 1):
......@@ -697,7 +696,8 @@ class flatten_contiguous_range(PruneWorker):
self._visit(in_var, pruned_axis)
self._visit(out_var, out_pruned_axis)
transform = {'stride': stride}
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, out_pruned_axis, [stride])
self._prune_op(op, out_var, out_pruned_axis,
transforms + [transform])
......@@ -85,8 +85,8 @@ class Pruner():
param_backup = {} if param_backup else None
param_shape_backup = {} if param_shape_backup else None
visited = {}
pruned_params = []
visited = {}
for param, ratio in zip(params, ratios):
_logger.info("pruning: {}".format(param))
if graph.var(param) is None:
......@@ -98,28 +98,19 @@ class Pruner():
visited)[0] # [(name, axis, pruned_idx)]
if group is None or len(group) == 0:
continue
if only_graph and self.idx_selector.__name__ == "default_idx_selector":
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
for name, axis, _ in group:
pruned_params.append((name, axis, pruned_idx))
else:
assert ((not self.pruned_weights),
"The weights have been pruned once.")
group_values = []
for name, axis, pruned_idx in group:
var = scope.find_var(name)
if var is not None:
values = np.array(var.get_tensor())
group_values.append((name, values, axis, pruned_idx))
scores = self.criterion(
group_values, graph) # [(name, axis, score, pruned_idx)]
pruned_params.extend(self.idx_selector(scores, ratio))
assert not self.pruned_weights, "The weights have been pruned once."
group_values = []
for name, axis, pruned_idx in group:
var = scope.find_var(name)
if var is not None:
values = np.array(var.get_tensor())
group_values.append((name, values, axis, pruned_idx))
scores = self.criterion(group_values,
graph) # [(name, axis, score, pruned_idx)]
g = self._transform(self.idx_selector(scores, ratio))
pruned_params.extend(g)
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
......@@ -128,7 +119,6 @@ class Pruner():
if pruned_axis not in merge_pruned_params[param]:
merge_pruned_params[param][pruned_axis] = []
merge_pruned_params[param][pruned_axis].append(pruned_idx)
for param_name in merge_pruned_params:
for pruned_axis in merge_pruned_params[param_name]:
pruned_idx = np.concatenate(merge_pruned_params[param_name][
......@@ -138,12 +128,26 @@ class Pruner():
_logger.debug("{}\t{}\t{}\t{}".format(
param.name(), pruned_axis,
param.shape()[pruned_axis], len(pruned_idx)))
origin_shape = copy.deepcopy(param.shape())
if param_shape_backup is not None:
origin_shape = copy.deepcopy(param.shape())
param_shape_backup[param.name()] = origin_shape
new_shape = list(param.shape())
new_shape[pruned_axis] -= len(pruned_idx)
param.set_shape(new_shape)
# update groups of depthwise conv2d
for op in param.outputs():
if op.type() in ["conv2d", "depthwise_conv2d"
] and op.attr("groups") > 1:
assert origin_shape[
1] == 1, "Only depthwise conv2d is supported when groups > 1."
new_groups = int(
op.attr("groups") * new_shape[pruned_axis] /
origin_shape[pruned_axis])
_logger.debug(
f"change groups of conv({param.name()}) from {op.attr('groups')} to {new_groups}; origin_shape: {origin_shape}; new_shape: {new_shape}"
)
op.set_attr("groups", new_groups)
if not only_graph:
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
......@@ -156,16 +160,35 @@ class Pruner():
pruned_idx,
pruned_axis=pruned_axis,
lazy=lazy)
param_t.set(pruned_param, place)
except IndexError as e:
_logger.error("Pruning {}, but got [{}]".format(
param.name(), e))
param_t.set(pruned_param, place)
graph.update_groups_of_conv()
graph.infer_shape()
self.pruned_weights = (not only_graph)
return graph.program, param_backup, param_shape_backup
def _transform(self, group):
ret = []
for name, axis, pruned_idx, transforms in group:
src = pruned_idx
for trans in transforms:
src_start = trans['src_start']
src_end = trans['src_end']
target_start = trans['target_start']
target_end = trans['target_end']
target = []
for idx in src:
if idx >= src_start and idx < src_end:
idx -= src_start
idx += target_start
if idx < target_end:
target.append(idx)
src = target
ret.append((name, axis, src))
return ret
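
A hand-worked sketch of how '_transform' maps pruned indices through one of the concat transforms shown earlier; the values are illustrative:

# Indices pruned on a 16-channel concat output, mapped back into the second
# 8-channel input (source range [8, 16) maps to local range [0, 8)):
trans = {'src_start': 8, 'src_end': 16, 'target_start': 0, 'target_end': 8}
src = [1, 9]
target = []
for idx in src:
    if trans['src_start'] <= idx < trans['src_end']:
        idx = idx - trans['src_start'] + trans['target_start']
        if idx < trans['target_end']:
            target.append(idx)
# target == [1]: output channel 9 belongs to this input and maps to local
# channel 1; output channel 1 belongs to the first input and is dropped here.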
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
"""
Pruning a array by indexes on given axis.
......
import sys
sys.path.append("../../")
import unittest
import numpy as np
import paddle
from paddleslim.analysis import dygraph_flops as flops
from paddle.vision.models import mobilenet_v1, resnet50
from paddle.nn import Conv2D, Layer
class TestFlops(unittest.TestCase):
......@@ -17,9 +20,64 @@ class TestFlops(unittest.TestCase):
self.assertTrue(FLOPs == self._gt)
class Net1(Layer):
def __init__(self):
super(Net1, self).__init__()
self.conv1 = Conv2D(3, 2, 3)
self.conv2 = Conv2D(3, 2, 3)
def forward(self, inputs):
assert isinstance(inputs, dict)
x = inputs["x"]
y = inputs["y"]
return {"x": self.conv1(x), "y": self.conv2(y), "dummy": "dummy"}
class Net2(Net1):
def __init__(self):
super(Net2, self).__init__()
def forward(self, x, y):
return [self.conv1(x), self.conv2(y), "dummy"]
class TestFLOPsCase1(unittest.TestCase):
def runTest(self):
x_shape = (1, 3, 32, 32)
y_shape = (1, 3, 16, 16)
net = Net1()
x = np.random.uniform(-1, 1, x_shape).astype('float32')
y = np.random.uniform(-1, 1, y_shape).astype('float32')
inputs = {
"x": paddle.to_tensor(x),
"y": paddle.to_tensor(y),
"z": "test"
}
FLOPs = flops(net, [inputs])
self.assertTrue(FLOPs == 59184)
class TestFLOPsCase2(unittest.TestCase):
def runTest(self):
x_shape = (1, 3, 32, 32)
y_shape = (1, 3, 16, 16)
net = Net2()
x = np.random.uniform(-1, 1, x_shape).astype('float32')
y = np.random.uniform(-1, 1, y_shape).astype('float32')
inputs = [paddle.to_tensor(x), paddle.to_tensor(y)]
FLOPs1 = flops(net, inputs)
shapes = [x_shape, y_shape]
FLOPs2 = flops(net, shapes, dtypes=["float32", "float32"])
self.assertTrue(FLOPs1 == FLOPs2)
def add_cases(suite):
suite.addTest(TestFlops(net=mobilenet_v1, gt=11792896.0))
suite.addTest(TestFlops(net=resnet50, gt=83872768.0))
suite.addTest(TestFLOPsCase1())
suite.addTest(TestFLOPsCase2())
def load_tests(loader, standard_tests, pattern):
......
......@@ -47,6 +47,7 @@ class TestPrune(unittest.TestCase):
shapes = {}
for param in model.parameters():
shapes[param.name] = param.shape
pruner.restore()
return shapes
def static_prune(self, net, ratios):
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -23,7 +24,8 @@ def conv_bn_layer(input,
groups=1,
act=None,
bias=False,
use_cudnn=True):
use_cudnn=True,
sync_bn=False):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -37,11 +39,19 @@ def conv_bn_layer(input,
name=name + "_out",
use_cudnn=use_cudnn)
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '_output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
if sync_bn:
bn = paddle.nn.SyncBatchNorm(
num_filters,
weight_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'),
name=bn_name)
return bn(conv)
else:
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '_output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
import sys
sys.path.append("../")
import unittest
import numpy as np
from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase):
def testAdd(self):
plan = PruningPlan()
mask = PruningMask([0], [0, 0, 1], 0.33)
plan.add("a", mask)
mask = PruningMask([0], [0, 1, 0], 0.33)
plan.add("a", mask)
a_mask = plan.masks["a"]
self.assertTrue(len(a_mask) == 1)
self.assertTrue(a_mask[0].mask == [0, 1, 1])
self.assertTrue(a_mask[0].dims == [0])
if __name__ == '__main__':
unittest.main()
......@@ -41,14 +41,30 @@ class TestPrune(StaticCase):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
groups = collect_convs(
collected_groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
while [] in groups:
groups.remove([])
print(groups)
self.assertTrue(len(groups) == 2)
self.assertTrue(len(groups[0]) == 20)
self.assertTrue(len(groups[1]) == 6)
while [] in collected_groups:
collected_groups.remove([])
print(collected_groups)
params = set([
param.name for param in main_program.all_parameters()
if "weights" in param.name
])
expected_groups = [[('conv1_weights', 0), ('conv2_weights', 1),
('conv2_weights', 0), ('conv3_weights', 1),
('conv4_weights', 0), ('conv5_weights', 1)],
[('conv3_weights', 0), ('conv4_weights', 1)]]
self.assertTrue(len(collected_groups) == len(expected_groups))
for _collected, _expected in zip(collected_groups, expected_groups):
for _name, _axis, _ in _collected:
if _name in params:
self.assertTrue((_name, _axis) in _expected)
for _name, _axis in _expected:
if _name in params:
self.assertTrue((_name, _axis, []) in _collected)
if __name__ == '__main__':
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
from static_case import StaticCase
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from layers import conv_bn_layer
class TestPrune(StaticCase):
def test_concat(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X
# conv1 conv2-->concat conv3-->sum-->out
# | ^ | ^
# |____________| |____________________|
#
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(input, 8, 3, "conv2", sync_bn=True)
tmp = fluid.layers.concat([conv1, conv2], axis=1)
conv3 = conv_bn_layer(input, 16, 3, "conv3", bias=None)
out = conv3 + tmp
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
# test backward search of concat
pruned_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv3_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=True,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv3_weights": (8, 3, 3, 3),
"conv2_weights": (4, 3, 3, 3),
"conv1_weights": (4, 3, 3, 3)
}
for param in pruned_program.global_block().all_parameters():
if "weights" in param.name and "conv2d" in param.name:
self.assertTrue(shapes[param.name] == param.shape)
# test forward search of concat
pruned_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv1_weights", "conv2_weights"],
ratios=[0.5, 0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4, 3, 3, 3),
"conv1_bn_scale": (4, ),
"conv1_bn_variance": (4, ),
"conv1_bn_mean": (4, ),
"conv1_bn_offset": (4, ),
"conv2_weights": (4, 3, 3, 3),
"sync_batch_norm_0.w_0": (4, ),
"sync_batch_norm_0.w_1": (4, ),
"conv2_bn_scale": (4, ),
"conv2_bn_offset": (4, ),
"conv3_weights": (8, 3, 3, 3),
"conv3_bn_mean": (8, ),
"conv3_bn_offset": (8, ),
"conv3_bn_scale": (8, ),
"conv3_bn_variance": (8, ),
"conv3_out.b_0": (8, ),
}
for param in pruned_program.global_block().all_parameters():
if "weights" in param.name and "conv2d" in param.name:
self.assertTrue(shapes[param.name] == param.shape)
if __name__ == '__main__':
unittest.main()