test_datasets.py 9.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

22
import paddle.vision.transforms as T
23
from paddle.vision.datasets import DatasetFolder, ImageFolder, MNIST, FashionMNIST, Flowers
24
from paddle.dataset.common import _check_exists_and_download
W
wanghuancoder 已提交
25
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
26 27 28


class TestFolderDatasets(unittest.TestCase):
    """Checks ``DatasetFolder`` and ``ImageFolder`` against a temporary image tree.

    Every ``test_*`` method runs its ``func_test_*`` body twice: once inside
    ``_test_eager_guard()`` and once outside of it.
    """

    def setUp(self):
        """Create one populated image directory and one empty directory.

        ``data_dir`` receives the sub-directories ``class_0`` and ``class_1``,
        each holding two random 32x32x3 ``uint8`` images written as ``.jpg``
        files. ``empty_dir`` stays empty and is used by the error checks.
        """
        self.data_dir = tempfile.mkdtemp()
        self.empty_dir = tempfile.mkdtemp()
        for i in range(2):
            sub_dir = os.path.join(self.data_dir, 'class_' + str(i))
            os.makedirs(sub_dir, exist_ok=True)
            for j in range(2):
                fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
                cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)

    def tearDown(self):
        """Remove both temporary directories created by ``setUp``."""
        # ``empty_dir`` must be removed as well, otherwise every test run
        # leaves a stray temporary directory behind.
        for tmp_dir in (self.data_dir, self.empty_dir):
            shutil.rmtree(tmp_dir)

    def func_test_dataset(self):
        """Iterate ``DatasetFolder`` and check its sample and class counts."""
        dataset_folder = DatasetFolder(self.data_dir)

        for _ in dataset_folder:
            pass

        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run under ``python -O``.
        self.assertEqual(len(dataset_folder), 4)
        self.assertEqual(len(dataset_folder.classes), 2)

        # A freshly constructed dataset over the same tree must iterate too.
        dataset_folder = DatasetFolder(self.data_dir)
        for _ in dataset_folder:
            pass

    def test_dataset(self):
        """Run ``func_test_dataset`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_dataset()
        self.func_test_dataset()

    def func_test_folder(self):
        """Iterate ``ImageFolder`` twice and check that it holds four images."""
        loader = ImageFolder(self.data_dir)

        for _ in loader:
            pass

        loader = ImageFolder(self.data_dir)
        for _ in loader:
            pass

        self.assertEqual(len(loader), 4)

    def test_folder(self):
        """Run ``func_test_folder`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_folder()
        self.func_test_folder()

    def func_test_transform(self):
        """Iterate both folder datasets with an identity ``transform``."""

        def fake_transform(img):
            return img

        transform = fake_transform
        dataset_folder = DatasetFolder(self.data_dir, transform=transform)

        for _ in dataset_folder:
            pass

        loader = ImageFolder(self.data_dir, transform=transform)
        for _ in loader:
            pass

    def test_transform(self):
        """Run ``func_test_transform`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_transform()
        self.func_test_transform()

    def func_test_errors(self):
        """Check the exceptions raised for an empty directory and a bad download call."""
        with self.assertRaises(RuntimeError):
            ImageFolder(self.empty_dir)
        with self.assertRaises(RuntimeError):
            DatasetFolder(self.empty_dir)

        with self.assertRaises(ValueError):
            _check_exists_and_download('temp_paddle', None, None, None, False)

    def test_errors(self):
        """Run ``func_test_errors`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_errors()
        self.func_test_errors()

113 114

class TestMNISTTest(unittest.TestCase):
    """Checks the ``MNIST`` dataset in ``'test'`` mode."""

    def func_test_main(self):
        """Verify the dataset size and the layout of one random sample."""
        transform = T.Transpose()
        mnist = MNIST(mode='test', transform=transform)
        self.assertEqual(len(mnist), 10000)

        # ``np.random.randint`` excludes its upper bound, so ``len(mnist)``
        # (not ``len(mnist) - 1``) is needed for the last sample to be
        # eligible as well.
        i = np.random.randint(0, len(mnist))
        image, label = mnist[i]
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

134 135

class TestMNISTTrain(unittest.TestCase):
    """Checks the ``MNIST`` dataset in ``'train'`` mode, including the cv2 backend."""

    def _assert_sample(self, image, label):
        """Assert one sample is a 1x28x28 image with a single label in 0..9."""
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def func_test_main(self):
        """Verify size and sample layout for the default and ``'cv2'`` backends.

        Also checks that a non-string ``backend`` value is rejected with
        ``ValueError``.
        """
        transform = T.Transpose()
        mnist = MNIST(mode='train', transform=transform)
        self.assertEqual(len(mnist), 60000)

        # ``np.random.randint`` excludes its upper bound, so ``len(mnist)``
        # (not ``len(mnist) - 1``) is needed for the last sample to be
        # eligible as well.
        i = np.random.randint(0, len(mnist))
        self._assert_sample(*mnist[i])

        # test cv2 backend
        mnist = MNIST(mode='train', transform=transform, backend='cv2')
        self.assertEqual(len(mnist), 60000)

        # Only the first sample is inspected; walking the whole dataset
        # would add nothing to this check.
        self._assert_sample(*mnist[0])

        with self.assertRaises(ValueError):
            MNIST(mode='train', transform=transform, backend=1)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

171

L
LielinJiang 已提交
172
class TestFASHIONMNISTTest(unittest.TestCase):
    """Checks the ``FashionMNIST`` dataset in ``'test'`` mode."""

    def func_test_main(self):
        """Verify the dataset size and the layout of one random sample."""
        transform = T.Transpose()
        mnist = FashionMNIST(mode='test', transform=transform)
        self.assertEqual(len(mnist), 10000)

        # ``np.random.randint`` excludes its upper bound, so ``len(mnist)``
        # (not ``len(mnist) - 1``) is needed for the last sample to be
        # eligible as well.
        i = np.random.randint(0, len(mnist))
        image, label = mnist[i]
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

L
LielinJiang 已提交
192 193

class TestFASHIONMNISTTrain(unittest.TestCase):
    """Checks the ``FashionMNIST`` dataset in ``'train'`` mode."""

    def _assert_sample(self, image, label):
        """Assert one sample is a 1x28x28 image with a single label in 0..9."""
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def func_test_main(self):
        """Verify size and sample layout for the default and ``'cv2'`` backends.

        Also checks that a non-string ``backend`` value is rejected with
        ``ValueError``.
        """
        transform = T.Transpose()
        mnist = FashionMNIST(mode='train', transform=transform)
        self.assertEqual(len(mnist), 60000)

        # ``np.random.randint`` excludes its upper bound, so ``len(mnist)``
        # (not ``len(mnist) - 1``) is needed for the last sample to be
        # eligible as well.
        i = np.random.randint(0, len(mnist))
        self._assert_sample(*mnist[i])

        # test cv2 backend
        mnist = FashionMNIST(mode='train', transform=transform, backend='cv2')
        self.assertEqual(len(mnist), 60000)

        # Only the first sample is inspected; walking the whole dataset
        # would add nothing to this check.
        self._assert_sample(*mnist[0])

        with self.assertRaises(ValueError):
            FashionMNIST(mode='train', transform=transform, backend=1)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

    def func_test_dataset_value(self):
        """Compare the mean pixel value of the training images with a reference."""
        fmnist = FashionMNIST(mode='train')
        value = np.mean([np.array(x[0]) for x in fmnist])

        # The reference mean 72.94035223214286 was obtained from competing
        # products.
        np.testing.assert_allclose(value, 72.94035223214286)

    def test_dataset_value(self):
        """Run ``func_test_dataset_value`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_dataset_value()
        self.func_test_dataset_value()

L
LielinJiang 已提交
241

242
class TestFlowersTrain(unittest.TestCase):
    """Checks the ``Flowers`` dataset in ``'train'`` mode."""

    def func_test_main(self):
        """Verify the dataset size and the layout of one random sample."""
        flowers = Flowers(mode='train')
        self.assertEqual(len(flowers), 6149)

        # Traversing the whole dataset may take a long time, so only one
        # randomly chosen sample is checked. The upper bound of
        # ``np.random.randint`` is exclusive, so every index is eligible.
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        # Convert the sample to an ndarray before inspecting its shape.
        image = np.array(image)
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

262 263

class TestFlowersValid(unittest.TestCase):
    """Checks the ``Flowers`` dataset in ``'valid'`` mode."""

    def func_test_main(self):
        """Verify the dataset size and the layout of one random sample."""
        flowers = Flowers(mode='valid')
        self.assertEqual(len(flowers), 1020)

        # Traversing the whole dataset may take a long time, so only one
        # randomly chosen sample is checked. The upper bound of
        # ``np.random.randint`` is exclusive, so every index is eligible.
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        # Convert the sample to an ndarray before inspecting its shape.
        image = np.array(image)
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

283 284

class TestFlowersTest(unittest.TestCase):
    """Checks the ``Flowers`` dataset in ``'test'`` mode, including the cv2 backend."""

    def _assert_sample(self, image, label):
        """Assert ``image`` is rank 3 with 3 entries on its last axis and one label."""
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

    def func_test_main(self):
        """Verify size and sample layout for the default and ``'cv2'`` backends.

        Also checks that a non-string ``backend`` value is rejected with
        ``ValueError``.
        """
        flowers = Flowers(mode='test')
        self.assertEqual(len(flowers), 1020)

        # Traversing the whole dataset may take a long time, so only one
        # randomly chosen sample is checked. The upper bound of
        # ``np.random.randint`` is exclusive, so every index is eligible.
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        # Convert the sample to an ndarray before inspecting its shape.
        self._assert_sample(np.array(image), label)

        # test cv2 backend
        flowers = Flowers(mode='test', backend='cv2')
        self.assertEqual(len(flowers), 1020)

        # Again only one random sample is checked. The image is used as
        # returned, without conversion, so its ``shape`` is read directly.
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        self._assert_sample(image, label)

        with self.assertRaises(ValueError):
            Flowers(mode='test', backend=1)

    def test_main(self):
        """Run ``func_test_main`` inside and outside ``_test_eager_guard``."""
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

320 321 322

# Run the test cases through unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()