未验证 提交 9095d4f7 编写于 作者: L lijianshe02 提交者: GitHub

add slim pruning algorithm and ralated unitest test=develop (#119)

* add slim pruning implementation and related unitest test=develop
上级 a784e4fe
...@@ -8,6 +8,7 @@ import math ...@@ -8,6 +8,7 @@ import math
import time import time
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
sys.path.append("../../")
from paddleslim.prune import Pruner, save_model from paddleslim.prune import Pruner, save_model
from paddleslim.common import get_logger from paddleslim.common import get_logger
from paddleslim.analysis import flops from paddleslim.analysis import flops
...@@ -37,6 +38,7 @@ add_arg('log_period', int, 10, "Log period in batches.") ...@@ -37,6 +38,7 @@ add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.") add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.") add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.") add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
# yapf: enable # yapf: enable
model_list = models.__all__ model_list = models.__all__
...@@ -207,7 +209,7 @@ def compress(args): ...@@ -207,7 +209,7 @@ def compress(args):
params = get_pruned_params(args, fluid.default_main_program()) params = get_pruned_params(args, fluid.default_main_program())
_logger.info("FLOPs before pruning: {}".format( _logger.info("FLOPs before pruning: {}".format(
flops(fluid.default_main_program()))) flops(fluid.default_main_program())))
pruner = Pruner() pruner = Pruner(args.criterion)
pruned_val_program, _, _ = pruner.prune( pruned_val_program, _, _ = pruner.prune(
val_program, val_program,
fluid.global_scope(), fluid.global_scope(),
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import sys
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import copy import copy
...@@ -79,8 +80,8 @@ class Pruner(): ...@@ -79,8 +80,8 @@ class Pruner():
pruned_num = int(round(param_v.shape()[0] * ratio)) pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num pruned_idx = [0] * pruned_num
else: else:
param_t = np.array(scope.find_var(param).get_tensor()) pruned_idx = self._cal_pruned_idx(
pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0) graph, scope, param, ratio, axis=0)
param = graph.var(param) param = graph.var(param)
conv_op = param.outputs()[0] conv_op = param.outputs()[0]
walker = conv2d_walker( walker = conv2d_walker(
...@@ -130,7 +131,7 @@ class Pruner(): ...@@ -130,7 +131,7 @@ class Pruner():
graph.infer_shape() graph.infer_shape()
return graph.program, param_backup, param_shape_backup return graph.program, param_backup, param_shape_backup
def _cal_pruned_idx(self, param, ratio, axis): def _cal_pruned_idx(self, graph, scope, param, ratio, axis):
""" """
Calculate the index to be pruned on axis by given pruning ratio. Calculate the index to be pruned on axis by given pruning ratio.
...@@ -145,11 +146,26 @@ class Pruner(): ...@@ -145,11 +146,26 @@ class Pruner():
Returns: Returns:
list<int>: The indexes to be pruned on axis. list<int>: The indexes to be pruned on axis.
""" """
prune_num = int(round(param.shape[axis] * ratio))
reduce_dims = [i for i in range(len(param.shape)) if i != axis]
if self.criterion == 'l1_norm': if self.criterion == 'l1_norm':
criterions = np.sum(np.abs(param), axis=tuple(reduce_dims)) param_t = np.array(scope.find_var(param).get_tensor())
prune_num = int(round(param_t.shape[axis] * ratio))
reduce_dims = [i for i in range(len(param_t.shape)) if i != axis]
criterions = np.sum(np.abs(param_t), axis=tuple(reduce_dims))
pruned_idx = criterions.argsort()[:prune_num] pruned_idx = criterions.argsort()[:prune_num]
elif self.criterion == "batch_norm_scale":
param_var = graph.var(param)
conv_op = param_var.outputs()[0]
conv_output = conv_op.outputs("Output")[0]
bn_op = conv_output.outputs()[0]
if bn_op is not None:
bn_scale_param = bn_op.inputs("Scale")[0].name()
bn_scale_np = np.array(
scope.find_var(bn_scale_param).get_tensor())
prune_num = int(round(bn_scale_np.shape[axis] * ratio))
pruned_idx = np.abs(bn_scale_np).argsort()[:prune_num]
else:
raise SystemExit(
"Can't find BatchNorm op after Conv op in Network.")
return pruned_idx return pruned_idx
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
criterion = 'batch_norm_scale'
pruner = Pruner(criterion)
main_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
}
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册