From 9fe2d24b0ca8849a817ade51ab083a0cd303360f Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 28 Oct 2020 20:08:51 +0800 Subject: [PATCH] Fix unittest for static graph (#485) * Fix unittest for static graph * Add some comments * Fix unittest --- demo/prune/eval.py | 5 +++-- tests/static_case.py | 7 ++----- tests/test_analysis_helper.py | 2 +- tests/test_l2_loss.py | 4 ++-- tests/test_rl_nas.py | 2 ++ tests/test_sa_nas.py | 2 ++ 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/demo/prune/eval.py b/demo/prune/eval.py index b2c9ea26..a93a056a 100644 --- a/demo/prune/eval.py +++ b/demo/prune/eval.py @@ -49,8 +49,8 @@ def eval(args): else: raise ValueError("{} is not supported.".format(args.data)) image_shape = [int(m) for m in image_shape.split(",")] - assert args.model in model_list, "{} is not in lists: {}".format( - args.model, model_list) + assert args.model in model_list, "{} is not in lists: {}".format(args.model, + model_list) image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') # model definition @@ -95,6 +95,7 @@ def eval(args): def main(): + paddle.enable_static() args = parser.parse_args() print_arguments(args) eval(args) diff --git a/tests/static_case.py b/tests/static_case.py index 96e8b1bb..ed322a91 100644 --- a/tests/static_case.py +++ b/tests/static_case.py @@ -3,9 +3,6 @@ import paddle class StaticCase(unittest.TestCase): - def __init__(self, name): - super(StaticCase, self).__init__() + def setUp(self): + # switch mode paddle.enable_static() - - def runTest(self): - pass diff --git a/tests/test_analysis_helper.py b/tests/test_analysis_helper.py index abce65f6..5565052d 100644 --- a/tests/test_analysis_helper.py +++ b/tests/test_analysis_helper.py @@ -70,4 +70,4 @@ class TestAnalysisHelper(StaticCase): if __name__ == '__main__': - TestAnalysisHelper('test_analysis_helper').test_analysis_helper() + unittest.main() diff --git a/tests/test_l2_loss.py b/tests/test_l2_loss.py index ef15c1f1..8c99f13b 100644 --- a/tests/test_l2_loss.py +++ b/tests/test_l2_loss.py @@ -26,7 +26,7 @@ class TestL2Loss(StaticCase): student_main = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_main, student_startup): - input = paddle.data(name="image", shape=[None, 3, 224, 224]) + input = fluid.data(name="image", shape=[None, 3, 224, 224]) conv1 = conv_bn_layer(input, 8, 3, "conv1") conv2 = conv_bn_layer(conv1, 8, 3, "conv2") student_predict = conv1 + conv2 @@ -34,7 +34,7 @@ class TestL2Loss(StaticCase): teacher_main = fluid.Program() teacher_startup = fluid.Program() with fluid.program_guard(teacher_main, teacher_startup): - input = paddle.data(name="image", shape=[None, 3, 224, 224]) + input = fluid.data(name="image", shape=[None, 3, 224, 224]) conv1 = conv_bn_layer(input, 8, 3, "conv1") conv2 = conv_bn_layer(conv1, 8, 3, "conv2") sum1 = conv1 + conv2 diff --git a/tests/test_rl_nas.py b/tests/test_rl_nas.py index b1dc7be2..fd6323df 100644 --- a/tests/test_rl_nas.py +++ b/tests/test_rl_nas.py @@ -14,6 +14,7 @@ import sys sys.path.append("../") import unittest +import paddle import paddle.fluid as fluid from paddleslim.nas import RLNAS from paddleslim.analysis import flops @@ -34,6 +35,7 @@ def compute_op_num(program): class TestRLNAS(StaticCase): def setUp(self): + paddle.enable_static() self.init_test_case() port = np.random.randint(8337, 8773) self.rlnas = RLNAS( diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py index 3d8de844..3a8e4f3d 100644 --- a/tests/test_sa_nas.py +++ b/tests/test_sa_nas.py @@ -16,6 +16,7 @@ sys.path.append("../") import os import sys import unittest +import paddle import paddle.fluid as fluid from static_case import StaticCase from paddleslim.nas import SANAS @@ -36,6 +37,7 @@ def compute_op_num(program): class TestSANAS(StaticCase): def setUp(self): + paddle.enable_static() self.init_test_case() port = np.random.randint(8337, 8773) self.sanas = SANAS( -- GitLab