提交 215180cb 编写于 作者: W wanghaoshuang

Make unittest readable.

上级 913b1faf
......@@ -10,6 +10,13 @@ class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
......@@ -42,12 +49,12 @@ class TestPrune(unittest.TestCase):
param_shape_backup=None)
shapes = {
"conv5_weights": (8L, 4L, 3L, 3L),
"conv1_weights": (4L, 3L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L)
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
}
for param in main_program.global_block().all_parameters():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册