未验证 提交 8b0005bd 编写于 作者: W whs 提交者: GitHub

Fix all the unittest of pruning. (#346)

上级 39ee8eb3
......@@ -58,7 +58,8 @@ def collect_convs(params, graph, visited={}):
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0])
groups.append(pruned_params)
if len(pruned_params) > 0:
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import AutoPruner
from paddleslim.analysis import flops
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruned_flops = 0.5
pruner = AutoPruner(
main_program,
scope,
place,
params=["conv4_weights"],
init_ratios=[0.5],
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=2,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner")
base_flops = flops(main_program)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
pruner.reward(1)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
pruner.reward(1)
if __name__ == '__main__':
unittest.main()
......@@ -63,12 +63,12 @@ class TestPrune(unittest.TestCase):
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
"conv1_weights": (4, 3, 3, 3),
"conv2_weights": (4, 4, 3, 3),
"conv3_weights": (8, 4, 3, 3),
"conv4_weights": (4, 8, 3, 3),
"conv5_weights": (8, 4, 3, 3),
"conv6_weights": (8, 8, 3, 3)
}
for param in main_program.global_block().all_parameters():
......
......@@ -64,12 +64,12 @@ class TestPrune(unittest.TestCase):
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
"conv1_weights": (4, 3, 3, 3),
"conv2_weights": (4, 4, 3, 3),
"conv3_weights": (8, 4, 3, 3),
"conv4_weights": (4, 8, 3, 3),
"conv5_weights": (8, 4, 3, 3),
"conv6_weights": (8, 8, 3, 3)
}
for param in main_program.global_block().all_parameters():
......
......@@ -62,12 +62,12 @@ class TestPrune(unittest.TestCase):
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
"conv1_weights": (4, 3, 3, 3),
"conv2_weights": (4, 4, 3, 3),
"conv3_weights": (8, 4, 3, 3),
"conv4_weights": (4, 8, 3, 3),
"conv5_weights": (8, 4, 3, 3),
"conv6_weights": (8, 8, 3, 3)
}
for param in main_program.global_block().all_parameters():
......
......@@ -17,7 +17,7 @@ import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddleslim.analysis import sensitivity
from paddleslim.prune import sensitivity
from layers import conv_bn_layer
......@@ -47,13 +47,12 @@ class TestSensitivity(unittest.TestCase):
val_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.test(), batch_size=128)
def eval_func(program, scope):
def eval_func(program):
feeder = fluid.DataFeeder(
feed_list=['image', 'label'], place=place, program=program)
acc_set = []
for data in val_reader():
acc_np = exe.run(program=program,
scope=scope,
feed=feeder.feed(data),
fetch_list=[acc_top1])
acc_set.append(float(acc_np[0]))
......@@ -61,8 +60,7 @@ class TestSensitivity(unittest.TestCase):
print("acc_val_mean: {}".format(acc_val_mean))
return acc_val_mean
sensitivity(eval_program,
fluid.global_scope(), place, ["conv4_weights"], eval_func,
sensitivity(eval_program, place, ["conv4_weights"], eval_func,
"./sensitivities_file")
......
......@@ -63,12 +63,12 @@ class TestPrune(unittest.TestCase):
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
"conv1_weights": (4, 3, 3, 3),
"conv2_weights": (4, 4, 3, 3),
"conv3_weights": (8, 4, 3, 3),
"conv4_weights": (4, 8, 3, 3),
"conv5_weights": (8, 4, 3, 3),
"conv6_weights": (8, 8, 3, 3)
}
for param in main_program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册