未验证 提交 ebe24e62 编写于 作者: C Chen Weihang 提交者: GitHub

polish unitest test_multiprocess_reader_exception (#33504)

上级 aa50868f
...@@ -1291,7 +1291,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -1291,7 +1291,7 @@ class GeneratorLoader(DataLoaderBase):
except Exception as ex: except Exception as ex:
self._queue.kill() self._queue.kill()
self._thread = None self._thread = None
logging.warn('Your reader has raised an exception!') logging.warning('Your reader has raised an exception!')
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
self._thread = threading.Thread( self._thread = threading.Thread(
......
...@@ -25,7 +25,7 @@ class ReaderException(Exception): ...@@ -25,7 +25,7 @@ class ReaderException(Exception):
pass pass
class TestMultiprocessReaderException(unittest.TestCase): class TestMultiprocessReaderExceptionWithQueueSuccess(unittest.TestCase):
def setUp(self): def setUp(self):
self.use_pipe = False self.use_pipe = False
self.raise_exception = False self.raise_exception = False
...@@ -36,7 +36,7 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestMultiprocessReaderException(unittest.TestCase):
else: else:
return [fluid.CPUPlace()] return [fluid.CPUPlace()]
def main_impl(self, place, iterable, use_legacy_py_reader): def main_impl(self, place, iterable):
sample_num = 40 sample_num = 40
batch_size = 4 batch_size = 4
...@@ -53,37 +53,25 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -53,37 +53,25 @@ class TestMultiprocessReaderException(unittest.TestCase):
return __impl__ return __impl__
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
if not use_legacy_py_reader: image = fluid.data(name='image', dtype='float32', shape=[None, 10])
image = fluid.data( reader = fluid.io.DataLoader.from_generator(
name='image', dtype='float32', shape=[None, 10])
reader = fluid.io.PyReader(
feed_list=[image], capacity=2, iterable=iterable) feed_list=[image], capacity=2, iterable=iterable)
else:
reader = fluid.layers.py_reader(
capacity=2, shapes=[[-1, 10], ], dtypes=['float32', ])
image = fluid.layers.read_file(reader)
image_p_1 = image + 1 image_p_1 = image + 1
decorated_reader = multiprocess_reader( decorated_reader = multiprocess_reader(
[fake_reader(), fake_reader()], use_pipe=self.use_pipe) [fake_reader(), fake_reader()], use_pipe=self.use_pipe)
if use_legacy_py_reader:
reader.decorate_paddle_reader(
fluid.io.batch(
decorated_reader, batch_size=batch_size))
else:
if isinstance(place, fluid.CUDAPlace): if isinstance(place, fluid.CUDAPlace):
reader.decorate_sample_generator( reader.set_sample_generator(
decorated_reader, decorated_reader,
batch_size=batch_size, batch_size=batch_size,
places=fluid.cuda_places(0)) places=fluid.cuda_places(0))
else: else:
reader.decorate_sample_generator( reader.set_sample_generator(
decorated_reader, decorated_reader,
batch_size=batch_size, batch_size=batch_size,
places=fluid.cpu_places()) places=fluid.cpu_places(1))
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -97,9 +85,9 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -97,9 +85,9 @@ class TestMultiprocessReaderException(unittest.TestCase):
for data in reader(): for data in reader():
exe.run(feed=data, fetch_list=[image_p_1]) exe.run(feed=data, fetch_list=[image_p_1])
num += 1 num += 1
self.assertEquals(num, batch_num) self.assertEqual(num, batch_num)
except SystemError as ex: except SystemError as ex:
self.assertEquals(num, 0) self.assertEqual(num, 0)
raise ReaderException() raise ReaderException()
else: else:
for _ in range(3): for _ in range(3):
...@@ -112,40 +100,40 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -112,40 +100,40 @@ class TestMultiprocessReaderException(unittest.TestCase):
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
self.assertFalse(self.raise_exception) self.assertFalse(self.raise_exception)
self.assertEquals(num, batch_num) self.assertEqual(num, batch_num)
except SystemError as ex: except SystemError as ex:
self.assertTrue(self.raise_exception) self.assertTrue(self.raise_exception)
self.assertEquals(num, 0) self.assertEqual(num, 0)
raise ReaderException() raise ReaderException()
def test_main(self): def test_main(self):
for p in self.places(): for p in self.places():
for iterable in [False, True]: for iterable in [False, True]:
use_legacy_py_reader_range = [False
] if iterable else [False, True]
for use_legacy_py_reader in use_legacy_py_reader_range:
try: try:
with fluid.scope_guard(fluid.Scope()): with fluid.scope_guard(fluid.Scope()):
self.main_impl(p, iterable, use_legacy_py_reader) self.main_impl(p, iterable)
self.assertTrue(not self.raise_exception) self.assertTrue(not self.raise_exception)
except ReaderException: except ReaderException:
self.assertTrue(self.raise_exception) self.assertTrue(self.raise_exception)
class TestCase1(TestMultiprocessReaderException): class TestMultiprocessReaderExceptionWithQueueFailed(
TestMultiprocessReaderExceptionWithQueueSuccess):
def setUp(self): def setUp(self):
self.use_pipe = False self.use_pipe = False
self.raise_exception = True self.raise_exception = True
class TestCase2(TestMultiprocessReaderException): class TestMultiprocessReaderExceptionWithPipeSuccess(
TestMultiprocessReaderExceptionWithQueueSuccess):
def setUp(self): def setUp(self):
self.use_pipe = True self.use_pipe = True
self.raise_exception = False self.raise_exception = False
class TestCase3(TestMultiprocessReaderException): class TestMultiprocessReaderExceptionWithPipeFailed(
TestMultiprocessReaderExceptionWithQueueSuccess):
def setUp(self): def setUp(self):
self.use_pipe = True self.use_pipe = True
self.raise_exception = True self.raise_exception = True
......
...@@ -18,6 +18,7 @@ import multiprocessing ...@@ -18,6 +18,7 @@ import multiprocessing
import six import six
import sys import sys
import warnings import warnings
import logging
from six.moves.queue import Queue from six.moves.queue import Queue
from six.moves import zip_longest from six.moves import zip_longest
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册