diff --git a/tests/test_prune.py b/tests/test_prune.py index f86bfe6d700632da33568325fd3b6f750cab716b..90affb7652ae712f47053cd6d88374d5733ea9ba 100644 --- a/tests/test_prune.py +++ b/tests/test_prune.py @@ -10,6 +10,13 @@ class TestPrune(unittest.TestCase): def test_prune(self): main_program = fluid.Program() startup_program = fluid.Program() + # X X O X O + # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6 + # | ^ | ^ + # |____________| |____________________| + # + # X: prune output channels + # O: prune input channels with fluid.program_guard(main_program, startup_program): input = fluid.data(name="image", shape=[None, 3, 16, 16]) conv1 = conv_bn_layer(input, 8, 3, "conv1") @@ -42,12 +49,12 @@ class TestPrune(unittest.TestCase): param_shape_backup=None) shapes = { - "conv5_weights": (8L, 4L, 3L, 3L), "conv1_weights": (4L, 3L, 3L, 3L), - "conv6_weights": (8L, 8L, 3L, 3L), - "conv3_weights": (8L, 4L, 3L, 3L), "conv2_weights": (4L, 4L, 3L, 3L), - "conv4_weights": (4L, 8L, 3L, 3L) + "conv3_weights": (8L, 4L, 3L, 3L), + "conv4_weights": (4L, 8L, 3L, 3L), + "conv5_weights": (8L, 4L, 3L, 3L), + "conv6_weights": (8L, 8L, 3L, 3L) } for param in main_program.global_block().all_parameters():