# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

import paddle.vision.transforms as T
from paddle.vision.datasets import DatasetFolder, ImageFolder, MNIST, FashionMNIST, Flowers
from paddle.dataset.common import _check_exists_and_download
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph


class TestFolderDatasets(unittest.TestCase):
    def setUp(self):
        self.data_dir = tempfile.mkdtemp()
        self.empty_dir = tempfile.mkdtemp()
        for i in range(2):
            sub_dir = os.path.join(self.data_dir, 'class_' + str(i))
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir)
            for j in range(2):
                fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
                cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)

    def tearDown(self):
        shutil.rmtree(self.data_dir)

W
wanghuancoder 已提交
43
    def func_test_dataset(self):
44 45 46 47 48 49 50 51 52 53 54 55
        dataset_folder = DatasetFolder(self.data_dir)

        for _ in dataset_folder:
            pass

        assert len(dataset_folder) == 4
        assert len(dataset_folder.classes) == 2

        dataset_folder = DatasetFolder(self.data_dir)
        for _ in dataset_folder:
            pass

W
wanghuancoder 已提交
56 57 58 59 60 61
    def test_dataset(self):
        with _test_eager_guard():
            self.func_test_dataset()
        self.func_test_dataset()

    def func_test_folder(self):
62 63 64 65 66 67 68 69 70 71 72
        loader = ImageFolder(self.data_dir)

        for _ in loader:
            pass

        loader = ImageFolder(self.data_dir)
        for _ in loader:
            pass

        assert len(loader) == 4

W
wanghuancoder 已提交
73 74 75 76 77 78
    def test_folder(self):
        with _test_eager_guard():
            self.func_test_folder()
        self.func_test_folder()

    def func_test_transform(self):
79 80 81 82 83 84 85 86 87 88 89 90 91
        def fake_transform(img):
            return img

        transfrom = fake_transform
        dataset_folder = DatasetFolder(self.data_dir, transform=transfrom)

        for _ in dataset_folder:
            pass

        loader = ImageFolder(self.data_dir, transform=transfrom)
        for _ in loader:
            pass

W
wanghuancoder 已提交
92 93 94 95 96 97
    def test_transform(self):
        with _test_eager_guard():
            self.func_test_transform()
        self.func_test_transform()

    def func_test_errors(self):
98 99 100 101 102 103 104 105
        with self.assertRaises(RuntimeError):
            ImageFolder(self.empty_dir)
        with self.assertRaises(RuntimeError):
            DatasetFolder(self.empty_dir)

        with self.assertRaises(ValueError):
            _check_exists_and_download('temp_paddle', None, None, None, False)

W
wanghuancoder 已提交
106 107 108 109 110
    def test_errors(self):
        with _test_eager_guard():
            self.func_test_errors()
        self.func_test_errors()


class TestMNISTTest(unittest.TestCase):
    """Checks the size and a random sample of the MNIST ``test`` split."""

    def func_test_main(self):
        transform = T.Transpose()
        mnist = MNIST(mode='test', transform=transform)
        self.assertEqual(len(mnist), 10000)

        # np.random.randint excludes its upper bound, so pass len(mnist) to
        # keep the last sample reachable.
        i = np.random.randint(0, len(mnist))
        image, label = mnist[i]
        # Expect a (1, 28, 28) image and a one-element label in [0, 9].
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


class TestMNISTTrain(unittest.TestCase):
    """Checks the MNIST ``train`` split with the default and cv2 backends."""

    def _check_sample(self, image, label):
        # Expect a (1, 28, 28) image and a one-element label in [0, 9].
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def func_test_main(self):
        transform = T.Transpose()
        mnist = MNIST(mode='train', transform=transform)
        self.assertEqual(len(mnist), 60000)

        # np.random.randint excludes its upper bound, so pass len(mnist) to
        # keep the last sample reachable.
        i = np.random.randint(0, len(mnist))
        image, label = mnist[i]
        self._check_sample(image, label)

        # test cv2 backend: only the first sample is inspected.
        mnist = MNIST(mode='train', transform=transform, backend='cv2')
        self.assertEqual(len(mnist), 60000)

        image, label = mnist[0]
        self._check_sample(image, label)

        # An invalid backend value (1) must raise ValueError.
        with self.assertRaises(ValueError):
            MNIST(mode='train', transform=transform, backend=1)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


class TestFASHIONMNISTTest(unittest.TestCase):
    """Checks the size and a random sample of the FashionMNIST ``test`` split."""

    def func_test_main(self):
        transform = T.Transpose()
        mnist = FashionMNIST(mode='test', transform=transform)
        self.assertEqual(len(mnist), 10000)

        # np.random.randint excludes its upper bound, so pass len(mnist) to
        # keep the last sample reachable.
        i = np.random.randint(0, len(mnist))
        image, label = mnist[i]
        # Expect a (1, 28, 28) image and a one-element label in [0, 9].
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


class TestFASHIONMNISTTrain(unittest.TestCase):
    """Checks the FashionMNIST ``train`` split: samples, backends, pixel mean."""

    def _check_sample(self, image, label):
        # Expect a (1, 28, 28) image and a one-element label in [0, 9].
        self.assertEqual(tuple(image.shape[:3]), (1, 28, 28))
        self.assertEqual(label.shape[0], 1)
        self.assertTrue(0 <= int(label) <= 9)

    def func_test_main(self):
        transform = T.Transpose()
        mnist = FashionMNIST(mode='train', transform=transform)
        self.assertEqual(len(mnist), 60000)

        # np.random.randint excludes its upper bound, so pass len(mnist) to
        # keep the last sample reachable.
        i = np.random.randint(0, len(mnist))
        image, label = mnist[i]
        self._check_sample(image, label)

        # test cv2 backend: only the first sample is inspected.
        mnist = FashionMNIST(mode='train', transform=transform, backend='cv2')
        self.assertEqual(len(mnist), 60000)

        image, label = mnist[0]
        self._check_sample(image, label)

        # An invalid backend value (1) must raise ValueError.
        with self.assertRaises(ValueError):
            FashionMNIST(mode='train', transform=transform, backend=1)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()

    def func_test_dataset_value(self):
        """Compare the mean pixel value of the whole split to a reference."""
        fmnist = FashionMNIST(mode='train')
        value = np.mean([np.array(x[0]) for x in fmnist])

        # The reference value 72.94035223214286 was obtained from
        # competitive products.
        np.testing.assert_allclose(value, 72.94035223214286)

    def test_dataset_value(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_dataset_value()
        self.func_test_dataset_value()


class TestFlowersTrain(unittest.TestCase):
    """Checks the size and a random sample of the Flowers ``train`` split."""

    def func_test_main(self):
        flowers = Flowers(mode='train')
        self.assertEqual(len(flowers), 6149)

        # Traversing the whole dataset may take a long time, so randomly
        # check a single sample (randint's upper bound is exclusive).
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        image = np.array(image)
        # Expect a 3-D image whose last axis has 3 channels.
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


class TestFlowersValid(unittest.TestCase):
    """Checks the size and a random sample of the Flowers ``valid`` split."""

    def func_test_main(self):
        flowers = Flowers(mode='valid')
        self.assertEqual(len(flowers), 1020)

        # Traversing the whole dataset may take a long time, so randomly
        # check a single sample (randint's upper bound is exclusive).
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        image = np.array(image)
        # Expect a 3-D image whose last axis has 3 channels.
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


class TestFlowersTest(unittest.TestCase):
    """Checks the Flowers ``test`` split with the default and cv2 backends."""

    def func_test_main(self):
        flowers = Flowers(mode='test')
        self.assertEqual(len(flowers), 1020)

        # Traversing the whole dataset may take a long time, so randomly
        # check a single sample (randint's upper bound is exclusive).
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        image = np.array(image)
        # Expect a 3-D image whose last axis has 3 channels.
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

        # test cv2 backend; the sample is checked without np.array().
        flowers = Flowers(mode='test', backend='cv2')
        self.assertEqual(len(flowers), 1020)

        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]

        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)

        # An invalid backend value (1) must raise ValueError.
        with self.assertRaises(ValueError):
            Flowers(mode='test', backend=1)

    def test_main(self):
        # Run the check once under _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when executed
    # directly as a script.
    unittest.main(module='__main__')