From a075e695359425808dec1ae652e33cc591d5c069 Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 15 Jun 2020 16:01:17 +0800 Subject: [PATCH] Fix unittest of flops and pruning walker (#351) --- tests/test_flops.py | 2 +- tests/test_prune_walker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_flops.py b/tests/test_flops.py index cd16b861..9d50ebc5 100644 --- a/tests/test_flops.py +++ b/tests/test_flops.py @@ -33,7 +33,7 @@ class TestPrune(unittest.TestCase): sum2 = conv4 + sum1 conv5 = conv_bn_layer(sum2, 8, 3, "conv5") conv6 = conv_bn_layer(conv5, 8, 3, "conv6") - self.assertTrue(1597440 == flops(main_program)) + self.assertTrue(792576 == flops(main_program)) if __name__ == '__main__': diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py index b80f6903..6db1155c 100644 --- a/tests/test_prune_walker.py +++ b/tests/test_prune_walker.py @@ -57,7 +57,7 @@ class TestPrune(unittest.TestCase): conv_op = graph.var("conv4_weights").outputs()[0] walker = conv2d_walker(conv_op, []) walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[]) - print walker.pruned_params + print(walker.pruned_params) if __name__ == '__main__': -- GitLab