“34e0bd573f47b8fff0be45648b6778791a2c04d6”上不存在“mobile/tools/python/modeltools/yolo/mdl2fluid.py”
未验证 提交 9fe2d24b 编写于 作者: W whs 提交者: GitHub

Fix unittest for static graph (#485)

* Fix unittest for static graph

* Add some comments

* Fix unittest
上级 35431ce1
...@@ -49,8 +49,8 @@ def eval(args): ...@@ -49,8 +49,8 @@ def eval(args):
else: else:
raise ValueError("{} is not supported.".format(args.data)) raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")] image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format( assert args.model in model_list, "{} is not in lists: {}".format(args.model,
args.model, model_list) model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition # model definition
...@@ -95,6 +95,7 @@ def eval(args): ...@@ -95,6 +95,7 @@ def eval(args):
def main(): def main():
paddle.enable_static()
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
eval(args) eval(args)
......
...@@ -3,9 +3,6 @@ import paddle ...@@ -3,9 +3,6 @@ import paddle
class StaticCase(unittest.TestCase): class StaticCase(unittest.TestCase):
def __init__(self, name): def setUp(self):
super(StaticCase, self).__init__() # switch mode
paddle.enable_static() paddle.enable_static()
def runTest(self):
pass
...@@ -70,4 +70,4 @@ class TestAnalysisHelper(StaticCase): ...@@ -70,4 +70,4 @@ class TestAnalysisHelper(StaticCase):
if __name__ == '__main__': if __name__ == '__main__':
TestAnalysisHelper('test_analysis_helper').test_analysis_helper() unittest.main()
...@@ -26,7 +26,7 @@ class TestL2Loss(StaticCase): ...@@ -26,7 +26,7 @@ class TestL2Loss(StaticCase):
student_main = fluid.Program() student_main = fluid.Program()
student_startup = fluid.Program() student_startup = fluid.Program()
with fluid.program_guard(student_main, student_startup): with fluid.program_guard(student_main, student_startup):
input = paddle.data(name="image", shape=[None, 3, 224, 224]) input = fluid.data(name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1") conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2") conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
student_predict = conv1 + conv2 student_predict = conv1 + conv2
...@@ -34,7 +34,7 @@ class TestL2Loss(StaticCase): ...@@ -34,7 +34,7 @@ class TestL2Loss(StaticCase):
teacher_main = fluid.Program() teacher_main = fluid.Program()
teacher_startup = fluid.Program() teacher_startup = fluid.Program()
with fluid.program_guard(teacher_main, teacher_startup): with fluid.program_guard(teacher_main, teacher_startup):
input = paddle.data(name="image", shape=[None, 3, 224, 224]) input = fluid.data(name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1") conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2") conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2 sum1 = conv1 + conv2
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import sys import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
from paddleslim.analysis import flops from paddleslim.analysis import flops
...@@ -34,6 +35,7 @@ def compute_op_num(program): ...@@ -34,6 +35,7 @@ def compute_op_num(program):
class TestRLNAS(StaticCase): class TestRLNAS(StaticCase):
def setUp(self): def setUp(self):
paddle.enable_static()
self.init_test_case() self.init_test_case()
port = np.random.randint(8337, 8773) port = np.random.randint(8337, 8773)
self.rlnas = RLNAS( self.rlnas = RLNAS(
......
...@@ -16,6 +16,7 @@ sys.path.append("../") ...@@ -16,6 +16,7 @@ sys.path.append("../")
import os import os
import sys import sys
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from static_case import StaticCase from static_case import StaticCase
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
...@@ -36,6 +37,7 @@ def compute_op_num(program): ...@@ -36,6 +37,7 @@ def compute_op_num(program):
class TestSANAS(StaticCase): class TestSANAS(StaticCase):
def setUp(self): def setUp(self):
paddle.enable_static()
self.init_test_case() self.init_test_case()
port = np.random.randint(8337, 8773) port = np.random.randint(8337, 8773)
self.sanas = SANAS( self.sanas = SANAS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册