# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import paddle
import paddle.vision.transforms as transforms
from paddle.fluid.framework import _test_eager_guard
from paddle.io import Dataset


class TestDatasetAbstract(unittest.TestCase):
    """Check that the base ``Dataset`` class cannot be used directly."""

    def func_test_main(self):
        """Indexing and ``len()`` on a bare ``Dataset`` must raise NotImplementedError.

        The base class is expected to leave ``__getitem__`` and ``__len__``
        unimplemented so that concrete datasets are forced to override them.
        """
        dataset = Dataset()
        with self.assertRaises(NotImplementedError):
            _ = dataset[0]
        with self.assertRaises(NotImplementedError):
            _ = len(dataset)

    def test_main(self):
        # Run the same check once inside _test_eager_guard() (presumably the
        # eager execution mode -- confirm against paddle.fluid.framework) and
        # once outside it.
        with _test_eager_guard():
            self.func_test_main()
        self.func_test_main()


class TestDatasetWithDiffOutputPlace(unittest.TestCase):
    """Check which place DataLoader output tensors end up on.

    The CPU checks always run; the GPU checks only run when Paddle is
    compiled with CUDA.
    """

    def setUp(self):
        # The checks below switch the global device with paddle.set_device();
        # remember the current device so tearDown can restore it and the
        # change does not leak into other test cases.
        self._orig_device = paddle.get_device()

    def tearDown(self):
        paddle.set_device(self._orig_device)

    def get_dataloader(self, num_workers):
        """Build a shuffled MNIST test-split DataLoader with batch size 32.

        Args:
            num_workers: number of DataLoader workers; 0 loads batches in the
                main process, >0 uses worker subprocesses.

        Returns:
            A ``paddle.io.DataLoader`` yielding ``(image, label)`` batches.
        """
        dataset = paddle.vision.datasets.MNIST(
            mode='test',
            transform=transforms.Compose(
                [
                    transforms.CenterCrop(20),
                    transforms.RandomResizedCrop(14),
                    transforms.Normalize(),
                    transforms.ToTensor(),
                ]
            ),
        )
        loader = paddle.io.DataLoader(
            dataset, batch_size=32, num_workers=num_workers, shuffle=True
        )
        return loader

    @staticmethod
    def _first_batch(loader):
        """Return the first ``(image, label)`` batch produced by ``loader``."""
        image, label = next(iter(loader))
        return image, label

    def run_check_on_cpu(self, num_workers=1):
        """Assert that with the CPU device both batch tensors are on CPUPlace.

        Args:
            num_workers: forwarded to ``get_dataloader``. Defaults to 1 to
                keep the original behaviour for existing callers.
        """
        paddle.set_device('cpu')
        image, label = self._first_batch(self.get_dataloader(num_workers))
        self.assertTrue(image.place.is_cpu_place())
        self.assertTrue(label.place.is_cpu_place())

    def func_test_single_process(self):
        """Check output places when batches are loaded in the main process."""
        self.run_check_on_cpu(num_workers=0)
        if not paddle.is_compiled_with_cuda():
            return
        # On GPU with in-process loading, the image batch is expected on the
        # GPU place while the label batch is expected in CUDA pinned memory.
        paddle.set_device('gpu')
        image, label = self._first_batch(self.get_dataloader(0))
        self.assertTrue(image.place.is_gpu_place())
        self.assertTrue(label.place.is_cuda_pinned_place())

    def test_single_process(self):
        with _test_eager_guard():
            self.func_test_single_process()
        self.func_test_single_process()

    def func_test_multi_process(self):
        """Check output places when batches come from a worker subprocess."""
        # DataLoader with multi-process mode is not supported on macOS and
        # Windows currently; report a skip there instead of a silent pass.
        if sys.platform in ('darwin', 'win32'):
            self.skipTest(
                'multi-process DataLoader is not supported on macOS/Windows'
            )
        self.run_check_on_cpu(num_workers=1)
        if not paddle.is_compiled_with_cuda():
            return
        # On GPU with a worker subprocess, both the image and the label batch
        # are expected in CUDA pinned memory (not on the GPU place).
        paddle.set_device('gpu')
        image, label = self._first_batch(self.get_dataloader(1))
        self.assertTrue(image.place.is_cuda_pinned_place())
        self.assertTrue(label.place.is_cuda_pinned_place())

    def test_multi_process(self):
        with _test_eager_guard():
            self.func_test_multi_process()
        self.func_test_multi_process()


if __name__ == '__main__':
    # Executed directly (not imported): hand control to unittest's runner,
    # which runs the TestCase classes defined in this module.
    unittest.main()