提交 fb82aadb 编写于 作者: D dengkaipeng

add unittest

上级 49784f25
...@@ -24,7 +24,7 @@ import numpy as np ...@@ -24,7 +24,7 @@ import numpy as np
from paddle import fluid from paddle import fluid
from paddle.fluid.optimizer import Momentum from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from .vision.datasets import MNIST as MnistDataset from vision.datasets import MNIST as MnistDataset
from model import Model, CrossEntropy, Input, set_device from model import Model, CrossEntropy, Input, set_device
from metrics import Accuracy from metrics import Accuracy
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
# when test, you should add hapi root path to the PYTHONPATH, # when test, you should add hapi root path to the PYTHONPATH,
# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH # export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH
import unittest import unittest
import numpy as np
from datasets.folder import DatasetFolder from vision.datasets import *
class TestFolderDatasets(unittest.TestCase): class TestFolderDatasets(unittest.TestCase):
...@@ -30,5 +32,71 @@ class TestFolderDatasets(unittest.TestCase): ...@@ -30,5 +32,71 @@ class TestFolderDatasets(unittest.TestCase):
assert len(dataset_folder.classes) == 2 assert len(dataset_folder.classes) == 2
class TestMNISTTest(unittest.TestCase):
def test_main(self):
mnist = MNIST(mode='test')
self.assertTrue(len(mnist) == 10000)
for i in range(len(mnist)):
image, label = mnist[i]
self.assertTrue(image.shape[0] == 784)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
class TestMNISTTrain(unittest.TestCase):
def test_main(self):
mnist = MNIST(mode='train')
self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)):
image, label = mnist[i]
self.assertTrue(image.shape[0] == 784)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
class TestFlowersTrain(unittest.TestCase):
def test_main(self):
flowers = Flowers(mode='train')
self.assertTrue(len(flowers) == 6149)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 6149)
image, label = flowers[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
class TestFlowersValid(unittest.TestCase):
def test_main(self):
flowers = Flowers(mode='valid')
self.assertTrue(len(flowers) == 1020)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1020)
image, label = flowers[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
class TestFlowersTest(unittest.TestCase):
def test_main(self):
flowers = Flowers(mode='test')
self.assertTrue(len(flowers) == 1020)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1020)
image, label = flowers[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册