提交 48744e8b 编写于 作者: W wanghaoshuang

Merge branch 'criterion' into pytorch

......@@ -25,7 +25,13 @@ from .prune_walker import *
from ..prune import prune_walker
from .prune_io import *
from ..prune import prune_io
from .group_param import *
from ..prune import group_param
from .criterion import *
from ..prune import criterion
from .importance_sort import *
from ..prune import importance_sort
__all__ = []
__all__ += pruner.__all__
......@@ -34,3 +40,6 @@ __all__ += sensitive_pruner.__all__
__all__ += sensitive.__all__
__all__ += prune_walker.__all__
__all__ += prune_io.__all__
__all__ += group_param.__all__
__all__ += criterion.__all__
__all__ += importance_sort.__all__
"""Define some functions to compute the importance of structure to be pruned.
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from ..common import get_logger
__all__ = ["l1_norm"]
_logger = get_logger(__name__, level=logging.INFO)
def l1_norm(group):
"""Compute l1-norm scores of parameter on given axis.
This function return a list of parameters' l1-norm scores on given axis.
Each element of list is a tuple with format (name, axis, score) in which 'name' is parameter's name
and 'axis' is the axis reducing on and `score` is a np.array storing the l1-norm of strucure on `axis`.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, values, axis)' in which `name` is the parameter's name and
and `values` is the values of parameter and `axis` is the axis reducing on pruning on.
Returns:
list: A list of tuple storing l1-norm on given axis.
"""
scores = []
for name, value, axis in group:
reduce_dims = [i for i in range(len(value.shape)) if i != axis]
score = np.sum(np.abs(value), axis=tuple(reduce_dims))
scores.append((name, axis, score))
return scores
"""Define some functions to collect ralated parameters into groups."""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core import GraphWrapper
from .prune_walker import conv2d as conv2d_walker
__all__ = ["collect_convs"]
def collect_convs(params, graph):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
.. code-block:: text
conv1->conv2->conv3->conv4
As shown above, the demo has 4 convolution layers. And the shape of convolution's parameter is `[out_channel, in_channel, filter_size, filter_size]`. If parameter of `conv1` was pruned on axis 0, then the parameter of `conv2` should be pruned on axis 1. So the `conv1` and `conv2` is a group that can be represented as:
.. code-block:: python
[("conv1", 0), ("conv2", 1)]
If `params` is `["conv1", "conv2"]`, then the returned groups is:
.. code-block:: python
[[("conv1", 0), ("conv2", 1)],
[("conv2", 0), ("conv3", 1)]]
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups.
Returns:
list<list<tuple>>: The groups.
"""
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
groups = []
for param in params:
visited = {}
pruned_params = []
param = graph.var(param)
conv_op = param.outputs()[0]
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[])
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
repeat_group = False
simple_group = []
for param, axis, _ in group:
param = param.name()
if axis == 0:
if param in visited:
repeat_group = True
else:
visited.add(param)
simple_group.append((param, axis))
if not repeat_group:
uniq_groups.append(simple_group)
return uniq_groups
"""Define some functions to sort substructures of parameter by importance.
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from ..core import GraphWrapper
from ..common import get_logger
__all__ = ["channel_score_sort", "batch_norm_scale_sort"]
def channel_score_sort(group, graph):
"""Sort channels of convolution by importance.
This function return a list of parameters' sorted indexes on given axis.
Each element of list is a tuple with format (name, axis, indexes) in which 'name' is parameter's name
and 'axis' is the axis pruning on and `indexes` is sorted indexes.
The sorted indexes is computed by below steps:
step1: Find the first convolution layer in given group.
step2: Get the scores of first convolution's channels.
step3: Get sorted indexes by calling scores.argsort().
step4: All the parameters in group share the same sorted indexes computed in step3.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and
`axis` is the axis pruning on and `score` is a np.array storing the importance of strucure
on `axis`. Show as below:
.. code-block: text
[("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so
`[0.7, 0.5, 0.6]` are the importance sores of each output channel in "conv1_weights"
while axis is 0.
graph(GraphWrapper): The graph is an auxiliary for sorting. It won't be used in this function.
Returns:
list: sorted indexes
"""
assert (isinstance(graph, GraphWrapper))
name, axis, score = group[
0] # sort channels by the first convolution's score
sorted_idx = score.argsort()
idxs = []
for name, axis, score in group:
idxs.append((name, axis, sorted_idx))
return idxs
def batch_norm_scale_sort(group, graph):
"""Sort channels of convolution by scales in batch norm layer.
This function return a list of parameters' sorted indexes on given axis.
Each element of list is a tuple with format (name, axis, indexes) in which 'name' is parameter's name
and 'axis' is the axis pruning on and `indexes` is sorted indexes.
The sorted indexes is computed by below steps:
step1: Find the batch norm layer after the first convolution in given group.
step2: Get the scales of the batch norm layer.
step3: Get sorted indexes by calling `scales.argsort()`.
step4: All the parameters in group share the same sorted indexes computed in step3.
Args:
group(list): A group of parameters. The first parameter of the group is convolution layer's weight
while the others are parameters affected by pruning the first one. Each parameter in group
is represented as tuple '(name, axis, score)' in which `name` is the parameter's name and
`axis` is the axis pruning on and `score` is a np.array storing the importance of strucure
on `axis`. Show as below:
.. code-block: text
[("conv1_weights", 0, [0.7, 0.5, 0.6]), ("conv1_bn.scale", 0, [0.1, 0.2, 0.4])]
The shape of "conv1_weights" is `[out_channel, in_channel, filter_size, filter_size]`, so
`[0.7, 0.5, 0.6]` are the importance sores of each output channel in "conv1_weights"
while axis is 0.
graph(GraphWrapper): The graph is an auxiliary for sorting. It is used to find
the batch norm layer after given convolution layer.
Returns:
list: sorted indexes
"""
assert (isinstance(graph, GraphWrapper))
# step1: Get first convolution
conv_weight, axis, score = group[0]
param_var = graph.var(conv_weight)
conv_op = param_var.outputs()[0]
# step2: Get bn layer after first convolution
conv_output = conv_op.outputs("Output")[0]
bn_op = conv_output.outputs()[0]
if bn_op is not None:
bn_scale_param = bn_op.inputs("Scale")[0].name()
else:
raise SystemExit("Can't find BatchNorm op after Conv op in Network.")
# steps3: Find score of bn and compute sorted indexes
sorted_idx = None
for name, axis, score in group:
if name == bn_scale_param:
sorted_idx = score.argsort()
break
# step4: Share the sorted indexes with all the parameter in group
idxs = []
if sorted_idx is not None:
for name, axis, score in group:
idxs.append((name, axis, sorted_idx))
return idxs
......@@ -17,13 +17,10 @@ import sys
import numpy as np
import paddle.fluid as fluid
import copy
from ..core import GraphWrapper
try:
from ..core import DyGraph
except Exception as e:
pass
from .prune_walker import conv2d as conv2d_walker
from .dy_prune_walker import Conv2d as dy_conv2d_walker
from ..core import VarWrapper, OpWrapper, GraphWrapper
from .group_param import collect_convs
from .criterion import l1_norm
from .importance_sort import channel_score_sort, batch_norm_scale
from ..common import get_logger
__all__ = ["Pruner"]
......@@ -35,12 +32,21 @@ class Pruner():
"""The pruner used to prune channels of convolution.
Args:
criterion(str): the criterion used to sort channels for pruning. It only supports 'l1_norm' currently.
criterion(str|function): the criterion used to sort channels for pruning.
channel_sortor(str|function):
"""
def __init__(self, criterion="l1_norm"):
def __init__(self, criterion="l1_norm", channel_sortor="channel_score"):
self.criterion = criterion
self.channel_sortor = channel_sortor
if criterion == "l1_norm":
self.criterion = l1_norm
if channel_sortor == "channel_score":
self.channel_sortor = channel_score_sort
elif channel_sortor == "batch_norm_scale":
self.channel_sortor = batch_norm_scale_sort
def prune(self,
graph,
......@@ -86,26 +92,40 @@ class Pruner():
visited = {}
pruned_params = []
for param, ratio in zip(params, ratios):
group = collect_convs([param], graph)[0] # [(name, axis)]
if only_graph:
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
for name, aixs in group:
pruned_params.append((name, axis, pruned_idx))
else:
pruned_idx = self._cal_pruned_idx(
graph, scope, param, ratio, axis=0)
param = graph.var(param)
conv_op = param.outputs()[0]
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx)
group_values = []
for name, axis in group:
values = np.array(scope.find_var(name).get_tensor())
group_values.append((name, values, axis))
scores = self.criterion(
group_with_values) # [(name, axis, score)]
group_idx = self.channel_sortor(
scores, graph=graph) # [(name, axis, soted_idx)]
for param, pruned_axis, pruned_idx in group_idx:
pruned_num = len(pruned_idx) * ratio
pruned_params.append((
param, pruned_axis,
pruned_idx[:pruned_num])) # [(name, axis, pruned_idx)]
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
if param.name() not in merge_pruned_params:
merge_pruned_params[param.name()] = {}
if pruned_axis not in merge_pruned_params[param.name()]:
merge_pruned_params[param.name()][pruned_axis] = []
merge_pruned_params[param.name()][pruned_axis].append(pruned_idx)
if param not in merge_pruned_params:
merge_pruned_params[param] = {}
if pruned_axis not in merge_pruned_params[param]:
merge_pruned_params[param][pruned_axis] = []
merge_pruned_params[param][pruned_axis].append(pruned_idx)
for param_name in merge_pruned_params:
for pruned_axis in merge_pruned_params[param_name]:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from layers import conv_bn_layer
from paddleslim.prune import collect_convs
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
self.assertTrue(len(groups) == 2)
self.assertTrue(len(groups[0]) == 18)
self.assertTrue(len(groups[1]) == 6)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册