未验证 提交 52502c06 编写于 作者: L lijianshe02 提交者: GitHub

add optimal-threshold prune strategy implementation test=develop (#183)

上级 8d1dd54b
...@@ -172,9 +172,9 @@ class Pruner(): ...@@ -172,9 +172,9 @@ class Pruner():
dist_sum_list.append((dist_sum, out_i)) dist_sum_list.append((dist_sum, out_i))
min_gm_filters = sorted( min_gm_filters = sorted(
dist_sum_list, key=lambda x: x[0])[:prune_num] dist_sum_list, key=lambda x: x[0])[:prune_num]
pruned_idx = [x[1] for x in min_gm_filters] pruned_idx = np.array([x[1] for x in min_gm_filters])
elif self.criterion == "batch_norm_scale": elif self.criterion == "batch_norm_scale" or self.criterion == "optimal_threshold":
param_var = graph.var(param) param_var = graph.var(param)
conv_op = param_var.outputs()[0] conv_op = param_var.outputs()[0]
conv_output = conv_op.outputs("Output")[0] conv_output = conv_op.outputs("Output")[0]
...@@ -183,8 +183,28 @@ class Pruner(): ...@@ -183,8 +183,28 @@ class Pruner():
bn_scale_param = bn_op.inputs("Scale")[0].name() bn_scale_param = bn_op.inputs("Scale")[0].name()
bn_scale_np = np.array( bn_scale_np = np.array(
scope.find_var(bn_scale_param).get_tensor()) scope.find_var(bn_scale_param).get_tensor())
prune_num = int(round(bn_scale_np.shape[axis] * ratio)) if self.criterion == "batch_norm_scale":
pruned_idx = np.abs(bn_scale_np).argsort()[:prune_num] prune_num = int(round(bn_scale_np.shape[axis] * ratio))
pruned_idx = np.abs(bn_scale_np).argsort()[:prune_num]
elif self.criterion == "optimal_threshold":
def get_optimal_threshold(weight, percent=0.001):
weight[weight < 1e-18] = 1e-18
weight_sorted = np.sort(weight)
weight_square = weight_sorted**2
total_sum = weight_square.sum()
acc_sum = 0
for i in range(weight_square.size):
acc_sum += weight_square[i]
if acc_sum / total_sum > percent:
break
th = (weight_sorted[i - 1] + weight_sorted[i]
) / 2 if i > 0 else 0
return th
optimal_th = get_optimal_threshold(bn_scale_np, 0.12)
pruned_idx = np.squeeze(
np.argwhere(bn_scale_np < optimal_th))
else: else:
raise SystemExit( raise SystemExit(
"Can't find BatchNorm op after Conv op in Network.") "Can't find BatchNorm op after Conv op in Network.")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
criterion = 'optimal_threshold'
pruner = Pruner(criterion)
main_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
}
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name,
param.shape))
#self.assertTrue(param.shape == shapes[param.name])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册