提交 215180cb 编写于 作者: W wanghaoshuang

Make unittest readable.

上级 913b1faf
...@@ -10,6 +10,13 @@ class TestPrune(unittest.TestCase): ...@@ -10,6 +10,13 @@ class TestPrune(unittest.TestCase):
def test_prune(self): def test_prune(self):
main_program = fluid.Program() main_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16]) input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1") conv1 = conv_bn_layer(input, 8, 3, "conv1")
...@@ -42,12 +49,12 @@ class TestPrune(unittest.TestCase): ...@@ -42,12 +49,12 @@ class TestPrune(unittest.TestCase):
param_shape_backup=None) param_shape_backup=None)
shapes = { shapes = {
"conv5_weights": (8L, 4L, 3L, 3L),
"conv1_weights": (4L, 3L, 3L, 3L), "conv1_weights": (4L, 3L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L), "conv2_weights": (4L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L) "conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
} }
for param in main_program.global_block().all_parameters(): for param in main_program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册