# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from paddle.reader import multiprocess_reader
import unittest
import numpy as np


class ReaderException(Exception):
    """Signals that the multiprocess reader failed during a test run.

    main_impl raises this after catching the SystemError produced by a
    failing reader, so test_main can tell an expected reader failure apart
    from any other error.
    """


class TestMultiprocessReaderExceptionWithQueueSuccess(unittest.TestCase):
    """Feed a DataLoader from multiprocess_reader and check error propagation.

    This base configuration uses ``use_pipe=False`` and a sample generator
    that never fails, so every pass is expected to deliver all batches.
    Subclasses override setUp to switch ``use_pipe`` and/or make the
    generator raise.
    """

    def setUp(self):
        # use_pipe: forwarded as multiprocess_reader(..., use_pipe=...).
        # raise_exception: when True, fake_reader raises ValueError instead
        # of yielding samples, and the run is expected to fail.
        self.use_pipe = False
        self.raise_exception = False

    def places(self):
        """Return the places to test: CPU always, plus GPU 0 if CUDA is built."""
        places = [fluid.CPUPlace()]
        if fluid.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        return places

    def main_impl(self, place, iterable):
        """Build a tiny static program fed by multiprocess_reader; run it 3 times.

        Args:
            place: fluid place (CPUPlace or CUDAPlace) to execute on.
            iterable: True to create the DataLoader in iterable mode (batches
                fetched via ``reader()`` and fed explicitly), False for
                non-iterable mode (driven by ``reader.start()`` and
                terminated by ``fluid.core.EOFException``).

        Raises:
            ReaderException: if a run raises SystemError. This test expects
                a failure inside the multiprocess reader to surface that way.
            AssertionError: if the observed outcome (batch count, success or
                failure) does not match ``self.raise_exception``.
        """
        sample_num = 40
        batch_size = 4
        # Two fake readers are merged below, so a full pass is expected to
        # produce 2 * sample_num samples.
        batch_num = sample_num * 2 // batch_size

        def fake_reader():
            def __impl__():
                for _ in range(sample_num):
                    if self.raise_exception:
                        raise ValueError()
                    # The trailing comma yields a 1-tuple per sample,
                    # presumably one element per feed_list variable ([image]).
                    yield list(np.random.uniform(low=-1, high=1, size=[10])),

            return __impl__

        with fluid.program_guard(fluid.Program(), fluid.Program()):
            image = fluid.data(name='image', dtype='float32', shape=[None, 10])
            reader = fluid.io.DataLoader.from_generator(feed_list=[image],
                                                        capacity=2,
                                                        iterable=iterable)

            image_p_1 = image + 1

            decorated_reader = multiprocess_reader(
                [fake_reader(), fake_reader()], use_pipe=self.use_pipe)

            if isinstance(place, fluid.CUDAPlace):
                loader_places = fluid.cuda_places(0)
            else:
                loader_places = fluid.cpu_places(1)
            reader.set_sample_generator(decorated_reader,
                                        batch_size=batch_size,
                                        places=loader_places)

            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            if iterable:
                for _ in range(3):
                    num = 0
                    try:
                        for data in reader():
                            exe.run(feed=data, fetch_list=[image_p_1])
                            num += 1
                        # Completed normally: only valid when no failure
                        # was configured (same checks as the branch below).
                        self.assertFalse(self.raise_exception)
                        self.assertEqual(num, batch_num)
                    except SystemError as ex:
                        # A reader failure is only acceptable when it was
                        # requested, and it must occur before any batch.
                        self.assertTrue(self.raise_exception)
                        self.assertEqual(num, 0)
                        raise ReaderException() from ex
            else:
                for _ in range(3):
                    num = 0
                    reader.start()
                    try:
                        while True:
                            exe.run(fetch_list=[image_p_1])
                            num += 1
                    except fluid.core.EOFException:
                        # End of a pass in non-iterable mode; reset before
                        # the next start() (NOTE(review): per Paddle's
                        # non-iterable DataLoader protocol — confirm).
                        reader.reset()
                        self.assertFalse(self.raise_exception)
                        self.assertEqual(num, batch_num)
                    except SystemError as ex:
                        self.assertTrue(self.raise_exception)
                        self.assertEqual(num, 0)
                        raise ReaderException() from ex

    def test_main(self):
        """Run main_impl on every place in both non-iterable and iterable modes.

        A configuration passes if it completes and no failure was requested,
        or if it raises ReaderException and a failure was requested.
        """
        for p in self.places():
            for iterable in [False, True]:
                try:
                    # Each configuration runs in its own fresh fluid.Scope.
                    with fluid.scope_guard(fluid.Scope()):
                        self.main_impl(p, iterable)
                except ReaderException:
                    self.assertTrue(self.raise_exception)
                else:
                    self.assertFalse(self.raise_exception)


class TestMultiprocessReaderExceptionWithQueueFailed(
        TestMultiprocessReaderExceptionWithQueueSuccess):
    """Same scenario with ``use_pipe=False`` and a sample generator that raises.

    Every run is expected to fail with SystemError, which main_impl
    re-raises as ReaderException.
    """

    def setUp(self):
        self.use_pipe = False
        self.raise_exception = True


class TestMultiprocessReaderExceptionWithPipeSuccess(
        TestMultiprocessReaderExceptionWithQueueSuccess):
    """Same scenario with ``use_pipe=True`` and a sample generator that never fails.

    Every pass is expected to deliver all batches.
    """

    def setUp(self):
        self.use_pipe = True
        self.raise_exception = False


class TestMultiprocessReaderExceptionWithPipeFailed(
        TestMultiprocessReaderExceptionWithQueueSuccess):
    """Same scenario with ``use_pipe=True`` and a sample generator that raises.

    Every run is expected to fail with SystemError, which main_impl
    re-raises as ReaderException.
    """

    def setUp(self):
        self.use_pipe = True
        self.raise_exception = True


if __name__ == '__main__':
    # The tests build static-graph programs (fluid.Program, fluid.data,
    # fluid.Executor), so static mode is enabled before unittest runs them.
    # NOTE(review): this only happens when the file is executed directly; a
    # runner that merely imports the module would skip it — confirm intended.
    paddle.enable_static()
    unittest.main()