未验证 提交 ebe24e62 编写于 作者: C Chen Weihang 提交者: GitHub

polish unitest test_multiprocess_reader_exception (#33504)

上级 aa50868f
......@@ -1291,7 +1291,7 @@ class GeneratorLoader(DataLoaderBase):
except Exception as ex:
self._queue.kill()
self._thread = None
logging.warn('Your reader has raised an exception!')
logging.warning('Your reader has raised an exception!')
six.reraise(*sys.exc_info())
self._thread = threading.Thread(
......
......@@ -25,7 +25,7 @@ class ReaderException(Exception):
pass
class TestMultiprocessReaderException(unittest.TestCase):
class TestMultiprocessReaderExceptionWithQueueSuccess(unittest.TestCase):
def setUp(self):
self.use_pipe = False
self.raise_exception = False
......@@ -36,7 +36,7 @@ class TestMultiprocessReaderException(unittest.TestCase):
else:
return [fluid.CPUPlace()]
def main_impl(self, place, iterable, use_legacy_py_reader):
def main_impl(self, place, iterable):
sample_num = 40
batch_size = 4
......@@ -53,37 +53,25 @@ class TestMultiprocessReaderException(unittest.TestCase):
return __impl__
with fluid.program_guard(fluid.Program(), fluid.Program()):
if not use_legacy_py_reader:
image = fluid.data(
name='image', dtype='float32', shape=[None, 10])
reader = fluid.io.PyReader(
image = fluid.data(name='image', dtype='float32', shape=[None, 10])
reader = fluid.io.DataLoader.from_generator(
feed_list=[image], capacity=2, iterable=iterable)
else:
reader = fluid.layers.py_reader(
capacity=2, shapes=[[-1, 10], ], dtypes=['float32', ])
image = fluid.layers.read_file(reader)
image_p_1 = image + 1
decorated_reader = multiprocess_reader(
[fake_reader(), fake_reader()], use_pipe=self.use_pipe)
if use_legacy_py_reader:
reader.decorate_paddle_reader(
fluid.io.batch(
decorated_reader, batch_size=batch_size))
else:
if isinstance(place, fluid.CUDAPlace):
reader.decorate_sample_generator(
reader.set_sample_generator(
decorated_reader,
batch_size=batch_size,
places=fluid.cuda_places(0))
else:
reader.decorate_sample_generator(
reader.set_sample_generator(
decorated_reader,
batch_size=batch_size,
places=fluid.cpu_places())
places=fluid.cpu_places(1))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -97,9 +85,9 @@ class TestMultiprocessReaderException(unittest.TestCase):
for data in reader():
exe.run(feed=data, fetch_list=[image_p_1])
num += 1
self.assertEquals(num, batch_num)
self.assertEqual(num, batch_num)
except SystemError as ex:
self.assertEquals(num, 0)
self.assertEqual(num, 0)
raise ReaderException()
else:
for _ in range(3):
......@@ -112,40 +100,40 @@ class TestMultiprocessReaderException(unittest.TestCase):
except fluid.core.EOFException:
reader.reset()
self.assertFalse(self.raise_exception)
self.assertEquals(num, batch_num)
self.assertEqual(num, batch_num)
except SystemError as ex:
self.assertTrue(self.raise_exception)
self.assertEquals(num, 0)
self.assertEqual(num, 0)
raise ReaderException()
def test_main(self):
for p in self.places():
for iterable in [False, True]:
use_legacy_py_reader_range = [False
] if iterable else [False, True]
for use_legacy_py_reader in use_legacy_py_reader_range:
try:
with fluid.scope_guard(fluid.Scope()):
self.main_impl(p, iterable, use_legacy_py_reader)
self.main_impl(p, iterable)
self.assertTrue(not self.raise_exception)
except ReaderException:
self.assertTrue(self.raise_exception)
class TestCase1(TestMultiprocessReaderException):
class TestMultiprocessReaderExceptionWithQueueFailed(
TestMultiprocessReaderExceptionWithQueueSuccess):
def setUp(self):
self.use_pipe = False
self.raise_exception = True
class TestCase2(TestMultiprocessReaderException):
class TestMultiprocessReaderExceptionWithPipeSuccess(
TestMultiprocessReaderExceptionWithQueueSuccess):
def setUp(self):
self.use_pipe = True
self.raise_exception = False
class TestCase3(TestMultiprocessReaderException):
class TestMultiprocessReaderExceptionWithPipeFailed(
TestMultiprocessReaderExceptionWithQueueSuccess):
def setUp(self):
self.use_pipe = True
self.raise_exception = True
......
......@@ -18,6 +18,7 @@ import multiprocessing
import six
import sys
import warnings
import logging
from six.moves.queue import Queue
from six.moves import zip_longest
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册