Unverified commit 6924c977, authored by whs, committed by GitHub

[2.1.1]pruning mul op with 4-d inputs (#741)

Parent 32318455
@@ -53,14 +53,20 @@ class FilterPruner(Pruner):
         sen_file(str, optional): The absolute path of a file that stores computed sensitivities. If it is
                                  set, the 'FilterPruner::sensitive' function can no longer be called
                                  in the next step. Default: None.
+        opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
+        skip_leaves(bool): Whether to skip the last convolution layers.
     """

-    def __init__(self, model, inputs, sen_file=None, opt=None):
+    def __init__(self, model, inputs, sen_file=None, opt=None,
+                 skip_leaves=True):
         super(FilterPruner, self).__init__(model, inputs, opt=opt)
         self._status = Status(sen_file)
+        self.skip_leaves = skip_leaves
         # sensitive and collections are just used in filter pruning
-        self.collections = DygraphPruningCollections(model, inputs)
+        self.collections = DygraphPruningCollections(
+            model, inputs, skip_leaves=self.skip_leaves)
         # skip vars in:
         # 1. depthwise conv2d layer
         self.skip_vars = []
@@ -364,6 +370,8 @@ class FilterPruner(Pruner):
                 stride = transform['stride']
                 mask = mask.repeat(stride) if stride > 1 else mask
                 return mask
+            elif "repeat" in transform and "tile" in transform:
+                return np.tile(
+                    mask.repeat(transform["repeat"]), transform["tile"])
             else:
                 return mask
         return dst_mask
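The new branch above maps a per-channel mask onto a flattened dimension. A minimal numpy sketch of what `np.tile(mask.repeat(repeat), tile)` produces (the values are illustrative, not from the patch):

```python
import numpy as np

# Channel mask [1, 0]; each channel spans 3 contiguous elements in the
# flattened tensor (repeat=3), and the whole channel block recurs twice,
# e.g. once per leading batch index (tile=2).
mask = np.array([1, 0])
print(np.tile(mask.repeat(3), 2))
# -> [1 1 1 0 0 0 1 1 1 0 0 0]
```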
@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO)

 class FPGMFilterPruner(FilterPruner):
-    def __init__(self, model, inputs, sen_file=None, opt=None):
+    def __init__(self, model, inputs, sen_file=None, opt=None,
+                 skip_leaves=True):
         super(FPGMFilterPruner, self).__init__(
-            model, inputs, sen_file=sen_file, opt=opt)
+            model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)

     def cal_mask(self, pruned_ratio, collection):
         var_name = collection.master_name
...
@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO)

 class L1NormFilterPruner(FilterPruner):
-    def __init__(self, model, inputs, sen_file=None, opt=None):
+    def __init__(self, model, inputs, sen_file=None, opt=None,
+                 skip_leaves=True):
         super(L1NormFilterPruner, self).__init__(
-            model, inputs, sen_file=sen_file, opt=opt)
+            model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)

     def cal_mask(self, pruned_ratio, collection):
         var_name = collection.master_name
...
@@ -12,9 +12,10 @@ _logger = get_logger(__name__, logging.INFO)

 class L2NormFilterPruner(FilterPruner):
-    def __init__(self, model, inputs, sen_file=None, opt=None):
+    def __init__(self, model, inputs, sen_file=None, opt=None,
+                 skip_leaves=True):
         super(L2NormFilterPruner, self).__init__(
-            model, inputs, sen_file=sen_file, opt=opt)
+            model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)

     def cal_mask(self, pruned_ratio, collection):
         var_name = collection.master_name
...
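All three pruner subclasses now forward the same `skip_leaves` flag. A hedged usage sketch, with the import path assumed and the parameter name and call pattern taken from the tests later in this diff:

```python
import paddle
from paddleslim.dygraph import L1NormFilterPruner  # import path assumed

net = paddle.vision.models.mobilenet_v1()
# skip_leaves=True (the default) excludes the last convolution layers from
# pruning; skip_leaves=False makes them prunable as well, which the new
# TestPruningMul test below relies on.
pruner = L1NormFilterPruner(net, [1, 3, 224, 224], skip_leaves=False)
plan = pruner.prune_vars({'conv2d_0.w_0': 0.5}, 0)  # parameter name assumed
pruner.restore()
```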
@@ -208,7 +208,8 @@ class PruningPlan():
         Prune the values of a variable imperatively. It is valid only when
         pruning along one dimension.
         """
-        for name, sub_layer in model.named_sublayers():
+        for name, sub_layer in model.named_sublayers(include_self=True):
             for param in sub_layer.parameters(include_sublayers=False):
                 if param.name in self._masks:
                     for _mask in self._masks[param.name]:
@@ -218,7 +219,6 @@ class PruningPlan():
                         bool_mask = np.array(mask).astype(bool)
                         t_value = param.value().get_tensor()
                         value = np.array(t_value).astype("float32")
                         groups = _mask._op.attr('groups')
                         if dims == 1 and groups is not None and groups > 1 and len(
                                 value.shape) == 4:
@@ -262,7 +262,7 @@ class PruningPlan():
             param.clear_gradient()

     def restore(self, model, opt=None):
-        for name, sub_layer in model.named_sublayers():
+        for name, sub_layer in model.named_sublayers(include_self=True):
             for param in sub_layer.parameters(include_sublayers=False):
                 # restore optimizer accumulators from layer buffer
                 self._restore_opt(param.name, sub_layer, opt)
...
@@ -17,9 +17,10 @@ class DygraphPruningCollections(PruningCollections):
     Args:
         - model(nn.Layer): The dygraph to be parsed.
        - inputs(Variable|list|dict): The dummy inputs of the target model. They will be used in calling `model.forward(inputs)`.
+        - skip_leaves(bool): Whether to skip the last convolution layers.
     """

-    def __init__(self, model, inputs):
+    def __init__(self, model, inputs, skip_leaves=True):
         _logger.debug("Parsing model with input: {}".format(inputs))
         # The model can be in training mode, because some models contain auxiliary parameters for training.
         program = dygraph2program(model, inputs=inputs)
@@ -28,7 +29,8 @@ class DygraphPruningCollections(PruningCollections):
             _param.name for _param in model.parameters()
             if len(_param.shape) == 4
         ]
-        self._collections = self.create_pruning_collections(params, graph)
+        self._collections = self.create_pruning_collections(
+            params, graph, skip_leaves=skip_leaves)
         _logger.info("Found {} collections.".format(len(self._collections)))
         _name2values = {}
...
@@ -139,7 +139,8 @@ class PruningCollections(object):
                                    params,
                                    graph,
                                    skip_stranger=True,
-                                   skip_vars=None):
+                                   skip_vars=None,
+                                   skip_leaves=True):
        """Collect the convolution layers of a graph into groups. Layers in the same group are coupled under pruning.
        A group is a list of tuples in the format (param_name, axis), where `param_name` is the name of a parameter and `axis` is the axis to be pruned on.
@@ -164,7 +165,8 @@ class PruningCollections(object):
            params(list): A list of convolution layers' parameter names. All groups that contain any of these parameters will be collected.
            graph(paddle.static.Program|GraphWrapper): The graph used to search the groups.
            skip_stranger(bool): Whether to skip the current tensor when visiting unregistered operators that are not in OPS_UNCHANGE_SHAPE. False means visiting all unregistered operators with the default worker. Default: True.
-            skip_vars(list<str>): Names of variables that will be skipped. None means skipping all leaves in the given graph. '[]' means skipping nothing. Default: None.
+            skip_vars(list<str>): Names of variables that will be skipped. Default: None.
+            skip_leaves(bool): Whether to skip the last convolution layers.

         Returns:
             list<Group>: The groups.
@@ -173,12 +175,12 @@ class PruningCollections(object):
         if not isinstance(graph, GraphWrapper):
             graph = GraphWrapper(graph)
-        if skip_vars is None:
-            skip_vars = self._find_leaves(graph)
-            _logger.warning(
-                "Leaves {} will be skipped when parsing graph. You can set skipped variables by option 'skip_vars'.".
-                format(skip_vars))
+        skip_vars = [] if skip_vars is None else skip_vars
+        if skip_leaves:
+            leaves = self._find_leaves(graph)
+            skip_vars.extend(leaves)
+            _logger.warning(
+                "Leaves {} will be skipped when parsing graph.".format(leaves))
         visited = {}
         collections = []
         unsupported_warnings = set()
@@ -234,7 +236,7 @@ class PruningCollections(object):

 class StaticPruningCollections(PruningCollections):
-    def __init__(self, params, graph, skip_stranger=True):
+    def __init__(self, params, graph, skip_stranger=True, skip_leaves=True):
         super(StaticPruningCollections, self).__init__()
         self._collections = self.create_pruning_collections(
-            params, graph, skip_stranger=skip_stranger)
+            params, graph, skip_stranger=skip_stranger, skip_leaves=skip_leaves)
@@ -527,77 +527,6 @@ class split(PruneWorker):
         self._visit_and_search(out_var, pruned_axis, transforms)

-@PRUNE_WORKER.register
-class concat(PruneWorker):
-    def __init__(self, op, pruned_params, visited, skip_stranger):
-        super(concat, self).__init__(op, pruned_params, visited, skip_stranger)
-
-    def _prune(self, var, pruned_axis, transforms):
-        axis = self.op.attr("axis")
-        if var in self.op.outputs("Out"):
-            self._visit(var, pruned_axis)
-            start = 0
-            if axis == pruned_axis:
-                for _, in_var in enumerate(self.op.inputs("X")):
-                    idx = []
-                    transoform = {
-                        'src_start': start,
-                        'src_end': start + in_var.shape()[pruned_axis],
-                        'target_start': 0,
-                        'target_end': in_var.shape()[pruned_axis],
-                        'target_len': in_var.shape()[pruned_axis],
-                        'stride': 1
-                    }
-                    start += in_var.shape()[pruned_axis]
-                    self._visit(in_var, pruned_axis)
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis,
-                                       transforms + [transoform])
-            else:
-                for _, in_var in enumerate(self.op.inputs("X")):
-                    self._visit(in_var, pruned_axis)
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis, transforms)
-        elif var in self.op.inputs("X"):
-            self._visit(var, pruned_axis)
-            if axis == pruned_axis:
-                idx = []
-                target_start = 0
-                for v in self.op.inputs("X"):
-                    if v.name() != var.name():
-                        target_start += v.shape()[pruned_axis]
-                    else:
-                        break
-                target_end = target_start + v.shape()[pruned_axis]
-                out_var = self.op.outputs("Out")[0]
-                next_ops = out_var.outputs()
-                transform = {
-                    'src_start': 0,
-                    'src_end': var.shape()[pruned_axis],
-                    'target_start': target_start,
-                    'target_end': target_end,
-                    'target_len': out_var.shape()[pruned_axis],
-                    'stride': 1
-                }
-                self._visit(out_var, pruned_axis)
-                for op in next_ops:
-                    # The output of concat can be visited repeatedly
-                    c_visited = {}
-                    self._prune_op(
-                        op,
-                        out_var,
-                        pruned_axis,
-                        transforms + [transform],
-                        visited=c_visited)
-                # Add nodes searched from concat into global visited array.
-                self.visited.update(c_visited)
 @PRUNE_WORKER.register
 class depthwise_conv2d(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
@@ -620,21 +549,21 @@ class depthwise_conv2d(PruneWorker):
                 pruned_axis)
             # pruning number of filters
             assert (_filter.shape()[0] % _groups == 0)
-            stride = _filter.shape()[0] / _groups
+            repeat = int(_filter.shape()[0] / _groups)
             self.append_pruned_vars(_filter, 0, transforms + [{
-                "stride": stride
+                "repeat": repeat
             }])
             # kernel_number * groups will be pruned by reducing groups
             self.append_pruned_vars(_filter, 1, transforms)
             self._visit_and_search(_filter, 0, transforms + [{
-                "stride": stride
+                "repeat": repeat
             }])
             # The number of kernels is not pruned in depthwise conv2d,
             # so it is not necessary to search succeeding operators.
             # self._visit_and_search(_filter, 1, transforms)
             self._visit(_filter, 1)
             self._visit_and_search(_out, channel_axis, transforms + [{
-                "stride": stride
+                "repeat": repeat
             }])
         elif var == _filter:
             assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
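Why "stride" became "repeat" here: a depthwise filter holds `shape[0] // groups` filters per input channel, and the pruner (see the `Pruner` change further down) expands a pruned channel index i into the index block `[i*repeat, (i+1)*repeat)`. A small sketch with illustrative numbers:

```python
# Depthwise filter with shape[0] == 8 and groups == 4: two filters per channel.
repeat = 8 // 4
pruned_channels = [1, 3]
pruned_filters = [
    j for i in pruned_channels for j in range(i * repeat, (i + 1) * repeat)
]
print(pruned_filters)  # [2, 3, 6, 7]
```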
@@ -659,20 +588,98 @@ class mul(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
         super(mul, self).__init__(op, pruned_params, visited, skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
-        if var in self.op.inputs("X"):
-            assert pruned_axis == 1, "The Input of conv2d can only be pruned at axis 1, but got {}".format(
-                pruned_axis)
-            idx = []
-            feature_map_size = var.shape()[2] * var.shape()[3]
-            range_idx = np.array(range(feature_map_size))
-            for i in pruned_idx:
-                idx += list(range_idx + i * feature_map_size)
-            param_var = self.op.inputs("Y")[0]
-            self.append_pruned_vars(param_var, 0, idx)
-            for op in param_var.outputs():
-                self._prune_op(op, param_var, 0, pruned_idx)
+    def _prune(self, var, pruned_axis, trans):
+        x_num_col_dims = self.op.attr("x_num_col_dims")
+        y_num_col_dims = self.op.attr("y_num_col_dims")
+        x = self.op.inputs("X")[0]
+        y = self.op.inputs("Y")[0]
+        out = self.op.outputs("Out")[0]
+        x_shape = x.shape()
+        y_shape = y.shape()
+        if var == x:
+            if y_num_col_dims > 1 and pruned_axis >= x_num_col_dims:
+                raise UnsupportOpError(
+                    "Unsupport pruning x of mul when y_num_col_dims > 1 and pruned_axis >= x_num_col_dims"
+                )
+            tile = 1
+            repeat = 1
+            if pruned_axis < x_num_col_dims:
+                for i in range(0, pruned_axis):
+                    tile *= x_shape[i]
+                for i in range(pruned_axis + 1, x_num_col_dims):
+                    repeat *= x_shape[i]
+                self.append_pruned_vars(out, 0, trans + [{
+                    "tile": tile,
+                    "repeat": repeat
+                }])
+                self._visit_and_search(out, 0, trans + [{
+                    "tile": tile,
+                    "repeat": repeat
+                }])
+            else:
+                for i in range(x_num_col_dims, pruned_axis):
+                    tile *= x_shape[i]
+                for i in range(pruned_axis + 1, len(x_shape)):
+                    repeat *= x_shape[i]
+                self.append_pruned_vars(y, 0, trans + [{
+                    "tile": tile,
+                    "repeat": repeat
+                }])
+                self._visit_and_search(y, 0, trans + [{
+                    "tile": tile,
+                    "repeat": repeat
+                }])
+        elif var == y:
+            if (pruned_axis < y_num_col_dims) and (
+                    1 < len(x_shape) - x_num_col_dims):
+                raise UnsupportOpError(
+                    "Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims."
+                )
+            tile = 1
+            repeat = 1
+            if pruned_axis >= y_num_col_dims:
+                for i in range(y_num_col_dims, pruned_axis):
+                    tile *= y_shape[i]
+                for i in range(pruned_axis + 1, len(y_shape)):
+                    repeat *= y_shape[i]
+                self.append_pruned_vars(out, 1, trans + [{
+                    "tile": tile,
+                    "repeat": repeat
+                }])
+                self._visit_and_search(out, 1, trans + [{
+                    "tile": tile,
+                    "repeat": repeat
+                }])
+            else:
+                for i in range(0, pruned_axis):
+                    tile *= y_shape[i]
+                for i in range(pruned_axis + 1, y_num_col_dims):
+                    repeat *= y_shape[i]
+                self.append_pruned_vars(x,
+                                        len(x_shape) - 1, trans + [{
+                                            "tile": tile,
+                                            "repeat": repeat
+                                        }])
+                self._visit_and_search(x,
+                                       len(x_shape) - 1, trans + [{
+                                           "tile": tile,
+                                           "repeat": repeat
+                                       }])
+        elif var == out:
+            if (pruned_axis == 0 and x_num_col_dims != 1) or (
+                    pruned_axis == 1 and (len(y_shape) - y_num_col_dims) != 1):
+                raise UnsupportOpError(
+                    "Unsupport pruning out of mul when pruned_axis={}; x_num_col_dims: {}; y_num_col_dims: {}; y_shape: {}.".
+                    format(pruned_axis, x_num_col_dims, y_num_col_dims,
                            y_shape))
+            if pruned_axis == 0:
+                self.append_pruned_vars(x, 0, trans)
+                self._visit_and_search(x, 0, trans)
+            elif pruned_axis == 1:
+                self.append_pruned_vars(y, len(y_shape) - 1, trans)
+                self._visit_and_search(y, len(y_shape) - 1, trans)
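A worked example of the tile/repeat bookkeeping for the 4-D case the commit title refers to. With x of shape [N, C, H, W] and x_num_col_dims == 1, mul flattens x to [N, C*H*W]; pruning channel axis 1 falls into the else branch, giving tile = 1 and repeat = H*W, so y is pruned at axis 0 with channel c expanded to the rows [c*H*W, (c+1)*H*W), the same mapping the removed feature_map_size loop computed. A sketch with illustrative shapes:

```python
N, C, H, W = 1, 4, 3, 3  # x_num_col_dims = 1, pruned_axis = 1
tile = 1                 # no dims between x_num_col_dims and pruned_axis
repeat = H * W           # product of the dims after the pruned axis
c = 2                    # an illustrative pruned channel
rows = list(range(c * repeat, (c + 1) * repeat))
print(rows)  # rows 18..26 of y's axis 0
```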
 @PRUNE_WORKER.register
@@ -684,16 +691,54 @@ class matmul(PruneWorker):
         x = self.op.inputs("X")[0]
         y = self.op.inputs("Y")[0]
         out = self.op.outputs("Out")[0]
-        if var == x and pruned_axis == 1:
-            self.append_pruned_vars(y, 0, pruned_idx)
-            self._visit_and_search(y, 0, pruned_idx)
+        x_shape_len = len(x.shape())
+        y_shape_len = len(y.shape())
+        mappings = []
+        if x_shape_len == 1 and y_shape_len == 1:
+            mappings = [(0, 0, 0)]
+        elif x_shape_len == 1 and y_shape_len == 2:
+            mappings = [(0, 0, -1), (-1, 1, 0)]
+        elif x_shape_len == 2 and y_shape_len == 2:
+            mappings = [(0, -1, 0), (1, 0, -1), (-1, 1, 1)]
+        elif x_shape_len == 3 and y_shape_len == 1:
+            mappings = [(1, -1, 1), (2, 0, -1)]
+        elif x_shape_len == 2 and y_shape_len == 3:
+            mappings = [(0, -1, 1), (1, 1, -1), (-1, 2, 2)]
+        elif x_shape_len >= 3 and y_shape_len >= 3:
+            mappings = [(x_shape_len - 2, -1, x_shape_len - 2),
+                        (x_shape_len - 1, x_shape_len - 2, -1),
+                        (-1, x_shape_len - 1, x_shape_len - 1)]
+        if var == x:
+            for x_i, y_i, out_i in mappings:
+                if pruned_axis == x_i:
+                    if y_i != -1:
+                        self.append_pruned_vars(y, y_i, pruned_idx)
+                        self._visit_and_search(y, y_i, pruned_idx)
+                    if out_i != -1:
+                        #self.append_pruned_vars(out, out_i, pruned_idx)
+                        self._visit_and_search(out, out_i, pruned_idx)
+                    break
+        if var == y:
+            for x_i, y_i, out_i in mappings:
+                if pruned_axis == y_i:
+                    if x_i != -1:
+                        self.append_pruned_vars(x, x_i, pruned_idx)
+                        self._visit_and_search(x, x_i, pruned_idx)
+                    if out_i != -1:
+                        #self.append_pruned_vars(out, out_i, pruned_idx)
+                        self._visit_and_search(out, out_i, pruned_idx)
+                    break
         if var == out:
-            if pruned_axis == 0:
-                self.append_pruned_vars(x, 0, pruned_idx)
-                self._visit_and_search(x, 0, pruned_idx)
-            elif pruned_axis == 1:
-                self.append_pruned_vars(y, 1, pruned_idx)
-                self._visit_and_search(y, 1, pruned_idx)
+            for x_i, y_i, out_i in mappings:
+                if pruned_axis == out_i:
+                    if x_i != -1:
+                        self.append_pruned_vars(x, x_i, pruned_idx)
+                        self._visit_and_search(x, x_i, pruned_idx)
+                    if y_i != -1:
+                        self.append_pruned_vars(y, y_i, pruned_idx)
+                        self._visit_and_search(y, y_i, pruned_idx)
+                    break
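Each mapping triple (x_i, y_i, out_i) ties together the axes of x, y and out that must be pruned jointly, with -1 marking an axis that has no counterpart. For the 2-D case this simply restates out[i, j] = sum_k x[i, k] * y[k, j]. A quick numpy check with illustrative shapes:

```python
import numpy as np

x, y = np.ones((6, 8)), np.ones((8, 7))
out = x @ y
# (0, -1, 0): x rows pair with out rows; (1, 0, -1): x columns pair with
# y rows (the contracted axis); (-1, 1, 1): y columns pair with out columns.
assert out.shape == (x.shape[0], y.shape[1])
assert x.shape[1] == y.shape[0]
```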
 @PRUNE_WORKER.register
@@ -859,3 +904,20 @@ class unsqueeze2(PruneWorker):
             squeeze_num += 1
         pruned_axis -= squeeze_num
         self._visit_and_search(in_var, pruned_axis, transforms)
+
+
+@PRUNE_WORKER.register
+class average_accumulates(PruneWorker):
+    def __init__(self, op, pruned_params, visited, skip_stranger):
+        super(average_accumulates, self).__init__(op, pruned_params, visited,
+                                                  skip_stranger)
+
+    def _prune(self, var, pruned_axis, transforms):
+        in_var = self.op.inputs("param")[0]
+        out_var_1 = self.op.outputs("out_sum_1")[0]
+        out_var_2 = self.op.outputs("out_sum_2")[0]
+        out_var_3 = self.op.outputs("out_sum_3")[0]
+        if in_var == var:
+            self.append_pruned_vars(out_var_1, pruned_axis, transforms)
+            self.append_pruned_vars(out_var_2, pruned_axis, transforms)
+            self.append_pruned_vars(out_var_3, pruned_axis, transforms)
@@ -169,20 +169,25 @@ class Pruner():
         for name, axis, pruned_idx, transforms in items:
             src = pruned_idx
             for trans in transforms:
-                if 'src_start' not in trans:
-                    continue
-                src_start = trans['src_start']
-                src_end = trans['src_end']
-                src_len = src_end - src_start
-                target_start = trans['target_start']
-                target_end = trans['target_end']
-                starts = np.array(range(target_start, target_end, src_len))
                 target = []
-                for idx in src:
-                    if idx >= src_start and idx < src_end:
-                        idx -= src_start
-                        target.extend(list(idx + starts))
+                if 'src_start' in trans:
+                    src_start = trans['src_start']
+                    src_end = trans['src_end']
+                    src_len = src_end - src_start
+                    target_start = trans['target_start']
+                    target_end = trans['target_end']
+                    starts = np.array(range(target_start, target_end, src_len))
+                    for idx in src:
+                        if idx >= src_start and idx < src_end:
+                            idx -= src_start
+                            target.extend(list(idx + starts))
+                elif "repeat" in trans:
+                    repeat = trans['repeat']
+                    for idx in src:
+                        idx = idx * repeat
+                        target.extend(range(idx, idx + repeat))
                 src = target
             ret.append((name, axis, src))
         return ret
...
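The restructured loop applies each transform to the running index list: the 'src_start' branch relocates indices between concatenated or split ranges, while the new 'repeat' branch expands every index into a contiguous block (sketched earlier for depthwise conv). A sketch of the 'src_start' branch with illustrative inputs:

```python
import numpy as np

# Indices in [src_start, src_end) of the source range are shifted onto each
# block start of the target range.
trans = {'src_start': 0, 'src_end': 4, 'target_start': 8, 'target_end': 12}
src = [1, 3]
src_len = trans['src_end'] - trans['src_start']
starts = np.array(range(trans['target_start'], trans['target_end'], src_len))
target = []
for idx in src:
    if trans['src_start'] <= idx < trans['src_end']:
        target.extend(list(idx - trans['src_start'] + starts))
print(target)  # [9, 11]
```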
@@ -133,44 +133,68 @@ class TestPruningGroupConv2d(unittest.TestCase):
                 for param in net.parameters():
                     if param.name not in shapes:
                         shapes[param.name] = param.shape
-                    assert (shapes[param.name] == param.shape)
+                    self.assertTrue(shapes[param.name] == param.shape)
                 pruner.restore()

-#class TestStrideTransform(unittest.TestCase):
-#    def __init__(self, methodName='runTest'):
-#        super(TestStrideTransform, self).__init__(methodName)
-#
-#    def runTest(self):
-#        with fluid.unique_name.guard():
-#
-#            net = paddle.vision.models.mobilenet_v1()
-#            ratios = {}
-#            for param in net.parameters():
-#                if len(param.shape) == 4:
-#                    ratios[param.name] = 0.5
-#            pruners = []
-#            pruner = L1NormFilterPruner(net, [1, 3, 128, 128])
-#            pruners.append(pruner)
-#            pruner = FPGMFilterPruner(net, [1, 3, 128, 128])
-#            pruners.append(pruner)
-#            pruner = L2NormFilterPruner(net, [1, 3, 128, 128])
-#            pruners.append(pruner)
-#
-#            shapes = {}
-#            for pruner in pruners:
-#                plan = pruner.prune_vars(ratios, 0)
-#                for param in net.parameters():
-#                    if param.name not in shapes:
-#                        shapes[param.name] = param.shape
-#                    assert(shapes[param.name] == param.shape)
-#                pruner.restore()
+from paddle.fluid import ParamAttr
+
+
+class MulNet(paddle.nn.Layer):
+    """
+    [3, 36] X conv(x)
+    """
+
+    def __init__(self):
+        super(MulNet, self).__init__()
+        self.conv_a = paddle.nn.Conv2D(6, 6, 1)
+        self.b = self.create_parameter(shape=[3, 36], attr=ParamAttr(name="b"))
+
+    def forward(self, x):
+        conv_a = self.conv_a(x)
+        return paddle.fluid.layers.mul(self.b,
+                                       conv_a,
+                                       x_num_col_dims=1,
+                                       y_num_col_dims=3)
+
+
+class TestPruningMul(unittest.TestCase):
+    def __init__(self, methodName='runTest'):
+        super(TestPruningMul, self).__init__(methodName)
+
+    def runTest(self):
+        with fluid.unique_name.guard():
+            net = MulNet()
+            ratios = {}
+            ratios['conv2d_0.w_0'] = 0.5
+            pruners = []
+            pruner = L1NormFilterPruner(net, [2, 6, 3, 3], skip_leaves=False)
+            pruners.append(pruner)
+            pruner = FPGMFilterPruner(net, [2, 6, 3, 3], skip_leaves=False)
+            pruners.append(pruner)
+            pruner = L2NormFilterPruner(net, [2, 6, 3, 3], skip_leaves=False)
+            pruners.append(pruner)
+
+            shapes = {
+                'b': [3, 18],
+                'conv2d_0.w_0': [3, 6, 1, 1],
+                'conv2d_0.b_0': [3]
+            }
+            for pruner in pruners:
+                plan = pruner.prune_vars(ratios, 0)
+                for param in net.parameters():
+                    if param.name not in shapes:
+                        shapes[param.name] = param.shape
+                    self.assertTrue(shapes[param.name] == param.shape)
+                pruner.restore()
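Why `b` shrinks from [3, 36] to [3, 18] in TestPruningMul: `mul(b, conv_a, x_num_col_dims=1, y_num_col_dims=3)` flattens conv_a's first three dims into b's column count, so halving the conv's channels halves that product. The arithmetic, using the [2, 6, 3, 3] dummy input:

```python
batch, channels, h = 2, 6, 3
assert batch * channels * h == 36         # b's original column count
assert batch * (channels // 2) * h == 18  # after pruning 6 -> 3 channels
```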
 def add_cases(suite):
-    # suite.addTest(TestStatus())
+    suite.addTest(TestStatus())
     suite.addTest(TestFilterPruner(param_names=["conv2d_0.w_0"]))
     suite.addTest(TestPruningGroupConv2d())
+    suite.addTest(TestPruningMul())

 def load_tests(loader, standard_tests, pattern):
...
@@ -15,6 +15,7 @@ import sys
 sys.path.append("../")
 import unittest
 from static_case import StaticCase
+import paddle
 import paddle.fluid as fluid
 from paddleslim.prune import Pruner
 from static_case import StaticCase
@@ -103,5 +104,85 @@ class TestPrune(StaticCase):
             self.assertTrue(shapes[param.name] == param.shape)

+
+class TestSplit(StaticCase):
+    def test_split(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(input, 4, 3, "conv2")
+            split_0, split_1 = paddle.split(conv1, 2, axis=1)
+            add = split_0 + conv2
+            out = conv_bn_layer(add, 4, 3, "conv3")
+            out1 = conv_bn_layer(split_1, 4, 4, "conv4")
+
+        shapes = {}
+        for param in main_program.global_block().all_parameters():
+            shapes[param.name] = param.shape
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        exe.run(startup_program, scope=scope)
+        pruner = Pruner()
+        # test backward search of split
+        pruned_program, _, _ = pruner.prune(
+            main_program,
+            scope,
+            params=["conv2_weights"],
+            ratios=[0.5],
+            place=place,
+            lazy=False,
+            only_graph=True,
+            param_backup=None,
+            param_shape_backup=None)
+        shapes = {
+            "conv1_weights": (6, 3, 3, 3),
+            "conv2_weights": (2, 3, 3, 3),
+            "conv3_weights": (4, 2, 3, 3),
+            "conv4_weights": (4, 4, 3, 3),
+        }
+        for param in pruned_program.global_block().all_parameters():
+            if "weights" in param.name and "conv2d" in param.name:
+                self.assertTrue(shapes[param.name] == param.shape)
+
+
+class TestMul(StaticCase):
+    def test_mul(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            fc_0 = paddle.fluid.layers.fc(conv1, size=10)
+            fc_1 = paddle.fluid.layers.fc(fc_0, size=10)
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        exe.run(startup_program, scope=scope)
+        pruner = Pruner()
+        # test backward search of mul
+        pruned_program, _, _ = pruner.prune(
+            main_program,
+            scope,
+            params=["conv1_weights"],
+            ratios=[0.5],
+            place=place,
+            lazy=False,
+            only_graph=True,
+            param_backup=None,
+            param_shape_backup=None)
+        shapes = {
+            "conv1_weights": (4, 3, 3, 3),
+            "fc_0.w_0": (1024, 10),
+            "fc_1.w_0": (10, 10)
+        }
+        for param in pruned_program.global_block().all_parameters():
+            if param.name in shapes.keys():
+                self.assertTrue(shapes[param.name] == param.shape)
+
 if __name__ == '__main__':
     unittest.main()
...
@@ -324,6 +324,7 @@ class TestPruneWorker(unittest.TestCase):
                 if var.name() not in ret:
                     ret[var.name()] = []
                 ret[var.name()].append(axis)
+        print(f"expected: {_ret}; but got {ret}")
         self.assertTrue(ret == _ret)
@@ -372,38 +373,44 @@ class TestElementwiseMul(TestPruneWorker):

 class TestActivation(TestPruneWorker):
-    def __init__(self, methodName="test_prune",
-                 op=paddle.nn.functional.sigmoid):
+    def __init__(self,
+                 methodName="check",
+                 op=paddle.nn.functional.sigmoid,
+                 **kwargs):
         super(TestActivation, self).__init__(methodName)
         self.act = op
+        self.kwargs = kwargs

     def define_layer(self, input):
         conv1 = paddle.static.nn.conv2d(
             input, 3, 3, name="conv1", bias_attr=False)
         self.input = conv1
-        tmp = self.act(conv1)
+        tmp = self.act(conv1, **self.kwargs)
         self.output = tmp
         conv2 = paddle.static.nn.conv2d(
             tmp, 3, 3, name="conv2", bias_attr=False)

     def set_cases(self):
         self.cases.append((self.in_var, 1, {'conv2.w_0': [1]}))
-        self.cases.append((self.out_var, 1, {
-            'conv1.w_0': [0],
-            'conv2.w_0': [1]
-        }))
+        self.cases.append((self.out_var, 1, {'conv1.w_0': [0], }))

-    def test_prune(self):
+    def check(self):
         self.check_in_out()

-suite = unittest.TestSuite()
-suite.addTest(TestActivation(op=paddle.fluid.layers.resize_bilinear))
-suite.addTest(TestActivation(op=paddle.fluid.layers.resize_nearest))
-suite.addTest(TestActivation(op=paddle.floor))
-suite.addTest(TestActivation(op=paddle.scale))
-suite.addTest(
-    TestActivation(op=paddle.fluid.layers.nn.uniform_random_batch_size_like))
+act_suite = unittest.TestSuite()
+act_suite.addTest(
+    TestActivation(
+        op=paddle.fluid.layers.resize_bilinear, scale=2.))
+act_suite.addTest(
+    TestActivation(
+        op=paddle.fluid.layers.resize_nearest, scale=2.))
+act_suite.addTest(TestActivation(op=paddle.floor))
+act_suite.addTest(TestActivation(op=paddle.scale))
+act_suite.addTest(
+    TestActivation(
+        op=paddle.fluid.layers.nn.uniform_random_batch_size_like,
+        shape=[8, 8, 16, 16]))

 class TestDepthwiseConv2d(TestPruneWorker):
@@ -432,43 +439,161 @@ class TestDepthwiseConv2d(TestPruneWorker):

 class TestMul(TestPruneWorker):
-    def __init__(self, methodName="test_prune"):
+    def __init__(self,
+                 methodName="check",
+                 x_num_col_dims=1,
+                 y_num_col_dims=1,
+                 ret=[]):
         super(TestMul, self).__init__(methodName)
+        self.x_num_col_dims = x_num_col_dims
+        self.y_num_col_dims = y_num_col_dims
+        self.ret = ret

     def define_layer(self, input):
-        x = fluid.data(name="x", shape=[1, 4, 3, 3])
-        y = fluid.data(name="y", shape=[36, 7])
+        x = fluid.data(name="x", shape=[1, 1, 1, 1])
+        y = fluid.data(name="y", shape=[1, 1, 1, 1])
         self.input = x
-        out = paddle.fluid.layers.mul(x, y)
+        self.y = y
+        out = paddle.fluid.layers.mul(x,
+                                      y,
+                                      x_num_col_dims=self.x_num_col_dims,
+                                      y_num_col_dims=self.y_num_col_dims)
         self.output = out

     def set_cases(self):
-        self.cases.append((self.in_var, 1, {'y': [0]}))
+        y = self.graph.var(self.y.name)
+        x = self.in_var
+        out = self.out_var
+        self.cases.append((x, 0, self.ret[0]))
+        self.cases.append((x, 1, self.ret[1]))
+        self.cases.append((x, 2, self.ret[2]))
+        self.cases.append((x, 3, self.ret[3]))
+        self.cases.append((y, 0, self.ret[4]))
+        self.cases.append((y, 1, self.ret[5]))
+        self.cases.append((y, 2, self.ret[6]))
+        self.cases.append((y, 3, self.ret[7]))
+        self.cases.append((out, 0, self.ret[8]))
+        self.cases.append((out, 1, self.ret[9]))

-    def test_prune(self):
+    def check(self):
         self.check_in_out()

+mul_suite = unittest.TestSuite()
+ret = [{
+    'mul_0.tmp_0': [0]
+}] + [{
+    'y': [0]
+}] * 3 + [{}] + [{
+    'mul_0.tmp_0': [1]
+}] * 3 + [{
+    'x': [0]
+}, {}]
+mul_suite.addTest(TestMul(x_num_col_dims=1, y_num_col_dims=1, ret=ret))
+
+ret = [{
+    'mul_0.tmp_0': [0]
+}] * 2 + [{}] * 4 + [{
+    'mul_0.tmp_0': [1]
+}] * 2 + [{}] * 2
+mul_suite.addTest(TestMul(x_num_col_dims=2, y_num_col_dims=2, ret=ret))
+
+ret = [{
+    'mul_0.tmp_0': [0]
+}] * 3 + [{}] + [{
+    'x': [3]
+}] * 3 + [{
+    'mul_0.tmp_0': [1]
+}] + [{}, {
+    'y': [3]
+}]
+mul_suite.addTest(TestMul(x_num_col_dims=3, y_num_col_dims=3, ret=ret))
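The ten entries of each `ret` list line up positionally with the ten cases registered in `TestMul.set_cases`; an empty dict means that pruning that axis is expected to propagate to no other variable. The ordering, spelled out as a sketch:

```python
# ret[0:4] -> pruning x at axes 0..3, ret[4:8] -> y at axes 0..3,
# ret[8:10] -> out at axes 0..1.
case_order = ([("x", a) for a in range(4)] + [("y", a) for a in range(4)] +
              [("out", a) for a in range(2)])
assert len(case_order) == 10
```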
 class TestMatmul(TestPruneWorker):
     def __init__(self, methodName="test_prune"):
         super(TestMatmul, self).__init__(methodName)
+        self.x_shape = [6, 8]
+        self.y_shape = [8, 7]

     def define_layer(self, input):
-        x = fluid.data(name="x", shape=[6, 8])
-        y = fluid.data(name="y", shape=[8, 7])
+        x = fluid.data(name="x", shape=self.x_shape)
+        y = fluid.data(name="y", shape=self.y_shape)
         self.input = x
+        self.y = y
         out = paddle.matmul(x, y)
         self.output = out

     def set_cases(self):
+        self.y_var = self.graph.var(self.y.name)
         self.cases.append((self.in_var, 1, {'y': [0]}))
-        self.cases.append((self.out_var, 0, {'x': [0]}))
+        self.cases.append((self.y_var, 0, {'x': [1]}))
         self.cases.append((self.out_var, 1, {'y': [1]}))

     def test_prune(self):
         self.check_in_out()

+
+class TestMatmulCase2(TestMatmul):
+    def __init__(self, methodName="test_prune"):
+        super(TestMatmulCase2, self).__init__(methodName)
+        self.x_shape = [8]
+        self.y_shape = [7]
+
+    def set_cases(self):
+        self.cases.append((self.in_var, 0, {'y': [0]}))
+        self.cases.append((self.out_var, 0, {'x': [0], 'y': [0]}))
+
+
+class TestMatmulCase3(TestMatmul):
+    def __init__(self, methodName="test_prune"):
+        super(TestMatmulCase3, self).__init__(methodName)
+        self.x_shape = [7]
+        self.y_shape = [7, 8]
+
+    def set_cases(self):
+        self.cases.append((self.in_var, 0, {'y': [0]}))
+        self.cases.append((self.out_var, 0, {'y': [1]}))
+
+
+class TestMatmulCase4(TestMatmul):
+    def __init__(self, methodName="test_prune"):
+        super(TestMatmulCase4, self).__init__(methodName)
+        self.x_shape = [8, 7, 7]
+        self.y_shape = [7]
+
+    def set_cases(self):
+        self.cases.append((self.in_var, 1, {}))
+        self.cases.append((self.in_var, 2, {'y': [0]}))
+        self.cases.append((self.out_var, 1, {'x': [1]}))
+
+
+class TestMatmulCase5(TestMatmul):
+    def __init__(self, methodName="test_prune"):
+        super(TestMatmulCase5, self).__init__(methodName)
+        self.x_shape = [7, 7]
+        self.y_shape = [7, 8, 9]
+
+    def set_cases(self):
+        self.cases.append((self.in_var, 0, {}))
+        self.cases.append((self.in_var, 1, {'y': [1]}))
+        self.cases.append((self.out_var, 1, {'x': [0]}))
+        self.cases.append((self.out_var, 2, {'y': [2]}))
+
+
+class TestMatmulCase6(TestMatmul):
+    def __init__(self, methodName="test_prune"):
+        super(TestMatmulCase6, self).__init__(methodName)
+        self.x_shape = [7, 7, 7]
+        self.y_shape = [7, 7, 9]
+
+    def set_cases(self):
+        self.cases.append((self.in_var, 1, {}))
+        self.cases.append((self.in_var, 2, {'y': [1]}))
+        self.cases.append((self.out_var, 1, {'x': [1]}))
+        self.cases.append((self.out_var, 2, {'y': [2]}))

 class TestSplit(TestPruneWorker):
     def define_layer(self, input):
         self.input = input
@@ -528,6 +653,33 @@ class TestAdam(TestPruneWorker):
         self.check_in_out()

+
+class TestAverageAccumulates(TestPruneWorker):
+    def define_layer(self, input):
+        self.input = input
+        conv1 = paddle.static.nn.conv2d(
+            input, 3, 8, name="conv1", bias_attr=False)
+        self.output = conv1
+        out = paddle.mean(conv1)
+        opt = paddle.optimizer.Adam()
+        opt.minimize(out)
+        model_average = fluid.optimizer.ModelAverage(
+            0.15, min_average_window=10000, max_average_window=12500)
+
+    def set_cases(self):
+        weight_var = self.graph.var('conv1.w_0')
+        self.cases.append((weight_var, 0, {
+            'conv1.w_0': [0],
+            'conv1.w_0_moment1_0': [0],
+            'conv1.w_0_moment2_0': [0],
+            'conv1.w_0_sum_1_0': [0],
+            'conv1.w_0_sum_2_0': [0],
+            'conv1.w_0_sum_3_0': [0]
+        }))
+
+    def test_prune(self):
+        self.check_in_out()

 class TestAffineChannel(TestPruneWorker):
     def __init__(self, methodName="test_prune"):
         super(TestAffineChannel, self).__init__(methodName)
@@ -555,4 +707,7 @@ class TestAffineChannel(TestPruneWorker):

 if __name__ == '__main__':
+    runner = unittest.TextTestRunner(verbosity=2)
+    runner.run(mul_suite)
+    runner.run(act_suite)
     unittest.main()