未验证 提交 9fe2d24b 编写于 作者: W whs 提交者: GitHub

Fix unittest for static graph (#485)

* Fix unittest for static graph

* Add some comments

* Fix unittest
上级 35431ce1
......@@ -49,8 +49,8 @@ def eval(args):
else:
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(
args.model, model_list)
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
......@@ -95,6 +95,7 @@ def eval(args):
def main():
paddle.enable_static()
args = parser.parse_args()
print_arguments(args)
eval(args)
......
......@@ -3,9 +3,6 @@ import paddle
class StaticCase(unittest.TestCase):
def __init__(self, name):
super(StaticCase, self).__init__()
def setUp(self):
# switch mode
paddle.enable_static()
def runTest(self):
pass
......@@ -70,4 +70,4 @@ class TestAnalysisHelper(StaticCase):
if __name__ == '__main__':
TestAnalysisHelper('test_analysis_helper').test_analysis_helper()
unittest.main()
......@@ -26,7 +26,7 @@ class TestL2Loss(StaticCase):
student_main = fluid.Program()
student_startup = fluid.Program()
with fluid.program_guard(student_main, student_startup):
input = paddle.data(name="image", shape=[None, 3, 224, 224])
input = fluid.data(name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
student_predict = conv1 + conv2
......@@ -34,7 +34,7 @@ class TestL2Loss(StaticCase):
teacher_main = fluid.Program()
teacher_startup = fluid.Program()
with fluid.program_guard(teacher_main, teacher_startup):
input = paddle.data(name="image", shape=[None, 3, 224, 224])
input = fluid.data(name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
......
......@@ -14,6 +14,7 @@
import sys
sys.path.append("../")
import unittest
import paddle
import paddle.fluid as fluid
from paddleslim.nas import RLNAS
from paddleslim.analysis import flops
......@@ -34,6 +35,7 @@ def compute_op_num(program):
class TestRLNAS(StaticCase):
def setUp(self):
paddle.enable_static()
self.init_test_case()
port = np.random.randint(8337, 8773)
self.rlnas = RLNAS(
......
......@@ -16,6 +16,7 @@ sys.path.append("../")
import os
import sys
import unittest
import paddle
import paddle.fluid as fluid
from static_case import StaticCase
from paddleslim.nas import SANAS
......@@ -36,6 +37,7 @@ def compute_op_num(program):
class TestSANAS(StaticCase):
def setUp(self):
paddle.enable_static()
self.init_test_case()
port = np.random.randint(8337, 8773)
self.sanas = SANAS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册