未验证 提交 cee0079a 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix readers bug, test=develop (#19868)

上级 747d4498
...@@ -55,11 +55,11 @@ def reader_creator(filename, sub_name, cycle=False): ...@@ -55,11 +55,11 @@ def reader_creator(filename, sub_name, cycle=False):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
with tarfile.open(filename, mode='r') as f: while True:
names = (each_item.name for each_item in f with tarfile.open(filename, mode='r') as f:
if sub_name in each_item.name) names = (each_item.name for each_item in f
if sub_name in each_item.name)
while True:
for name in names: for name in names:
if six.PY2: if six.PY2:
batch = pickle.load(f.extractfile(name)) batch = pickle.load(f.extractfile(name))
...@@ -68,8 +68,9 @@ def reader_creator(filename, sub_name, cycle=False): ...@@ -68,8 +68,9 @@ def reader_creator(filename, sub_name, cycle=False):
f.extractfile(name), encoding='bytes') f.extractfile(name), encoding='bytes')
for item in read_batch(batch): for item in read_batch(batch):
yield item yield item
if not cycle:
break if not cycle:
break
return reader return reader
......
...@@ -38,6 +38,7 @@ endif() ...@@ -38,6 +38,7 @@ endif()
if(WIN32) if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_boxps) LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_trainer_desc) LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)
LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
endif() endif()
LIST(REMOVE_ITEM TEST_OPS test_launch) LIST(REMOVE_ITEM TEST_OPS test_launch)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
class TestCifar10(unittest.TestCase):
    """Checks the ``cycle`` flag of ``paddle.dataset.cifar.train10``."""

    def test_main(self):
        """Compare one non-cyclic pass against a cyclic reader.

        The non-cyclic reader is drained to count the samples of a single
        pass. The cyclic reader is then consumed until twice that many
        samples have been read, which only succeeds if it restarts after the
        first pass instead of stopping.
        """
        reader = paddle.dataset.cifar.train10(cycle=False)
        sample_num = 0
        for _ in reader():
            sample_num += 1

        # With zero samples the break condition below (read_num == 0 after at
        # least one increment) could never hold, so fail fast instead.
        self.assertGreater(sample_num, 0)

        cyclic_reader = paddle.dataset.cifar.train10(cycle=True)
        read_num = 0
        for data in cyclic_reader():
            read_num += 1
            # Each item is expected to be a pair (two elements).
            self.assertEqual(len(data), 2)
            if read_num == sample_num * 2:
                break

        self.assertEqual(read_num, sample_num * 2)
# Run the test case(s) of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.io import multiprocess_reader
import unittest
import numpy as np
import six
import sys
class TestMultiprocessReaderException(unittest.TestCase):
    """Runs ``multiprocess_reader`` through a ``PyReader`` with and without
    a failing sample reader.

    Subclasses override ``setUp`` to choose ``use_pipe`` and
    ``raise_exception``; this base case uses ``use_pipe=False`` with readers
    that do not raise.
    """

    def setUp(self):
        # Passed to multiprocess_reader(..., use_pipe=...).
        self.use_pipe = False
        # When True, the fake sample readers raise ValueError instead of
        # yielding data.
        self.raise_exception = False

    def places(self):
        """Return the places to test: CPU, plus CUDA device 0 if available."""
        if fluid.is_compiled_with_cuda():
            return [fluid.CPUPlace(), fluid.CUDAPlace(0)]
        else:
            return [fluid.CPUPlace()]

    def main_impl(self, place, iterable):
        """Build a tiny program fed by the multiprocess reader and run it.

        Args:
            place: the place the executor runs on.
            iterable: forwarded to ``fluid.io.PyReader(iterable=...)``; also
                selects which of the two consumption loops below is used.

        Raises:
            ValueError: raised by this method itself when
                ``self.raise_exception`` is set, after asserting that no
                batch was consumed. ``test_main`` relies on this.
        """

        def fake_reader():
            def __impl__():
                for _ in range(40):
                    if not self.raise_exception:
                        # The trailing comma is intentional: each sample is
                        # a 1-tuple holding a list of 10 floats.
                        yield list(
                            np.random.uniform(
                                low=-1, high=1, size=[10])),
                    else:
                        raise ValueError()

            return __impl__

        with fluid.program_guard(fluid.Program(), fluid.Program()):
            image = fluid.layers.data(name='image', dtype='float32', shape=[10])

            reader = fluid.io.PyReader(
                feed_list=[image], capacity=2, iterable=iterable)

            image_p_1 = image + 1

            decorated_reader = multiprocess_reader(
                [fake_reader(), fake_reader()], use_pipe=self.use_pipe)

            if isinstance(place, fluid.CUDAPlace):
                reader.decorate_sample_generator(
                    decorated_reader, batch_size=4, places=fluid.cuda_places())
            else:
                reader.decorate_sample_generator(
                    decorated_reader, batch_size=4, places=fluid.cpu_places())

            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            # Expected batch count without errors:
            # 2 readers * 40 samples / batch_size 4 = 20.
            if iterable:
                for _ in range(3):
                    num = 0
                    for data in reader():
                        exe.run(feed=data, fetch_list=[image_p_1])
                        num += 1
                    if not self.raise_exception:
                        self.assertEqual(num, 20)
                    else:
                        self.assertEqual(num, 0)
                        raise ValueError('Reader raises exception')
            else:
                for _ in range(3):
                    num = 0
                    reader.start()
                    try:
                        # Only an EOFException ends this loop.
                        while True:
                            exe.run(fetch_list=[image_p_1])
                            num += 1
                    except fluid.core.EOFException:
                        reader.reset()
                        if not self.raise_exception:
                            self.assertEqual(num, 20)
                        else:
                            self.assertEqual(num, 0)
                            raise ValueError('Reader raises exception')

    def test_main(self):
        """Run ``main_impl`` for every place and both ``iterable`` modes.

        A ``ValueError`` must be raised exactly when ``raise_exception`` is
        set.
        """
        for p in self.places():
            for iterable in [False, True]:
                try:
                    with fluid.scope_guard(fluid.Scope()):
                        self.main_impl(p, iterable)

                    # Reaching this line means no ValueError was raised.
                    self.assertFalse(self.raise_exception)
                except ValueError:
                    self.assertTrue(self.raise_exception)
class TestCase1(TestMultiprocessReaderException):
    """Variant with ``use_pipe`` disabled and sample readers that raise."""

    def setUp(self):
        self.use_pipe, self.raise_exception = False, True
class TestCase2(TestMultiprocessReaderException):
    """Variant with ``use_pipe`` enabled and sample readers that do not raise."""

    def setUp(self):
        self.use_pipe, self.raise_exception = True, False
class TestCase3(TestMultiprocessReaderException):
    """Variant with ``use_pipe`` enabled and sample readers that raise."""

    def setUp(self):
        self.use_pipe, self.raise_exception = True, True
# Run the test case(s) of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -21,6 +21,7 @@ __all__ = [ ...@@ -21,6 +21,7 @@ __all__ = [
from threading import Thread from threading import Thread
import subprocess import subprocess
import multiprocessing import multiprocessing
import six
import sys import sys
from six.moves.queue import Queue from six.moves.queue import Queue
...@@ -390,11 +391,15 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000): ...@@ -390,11 +391,15 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
assert type(readers) is list and len(readers) > 0 assert type(readers) is list and len(readers) > 0
def _read_into_queue(reader, queue): def _read_into_queue(reader, queue):
for sample in reader(): try:
if sample is None: for sample in reader():
raise ValueError("sample has None") if sample is None:
queue.put(sample) raise ValueError("sample has None")
queue.put(None) queue.put(sample)
queue.put(None)
except:
queue.put("")
six.reraise(*sys.exc_info())
def queue_reader(): def queue_reader():
queue = multiprocessing.Queue(queue_size) queue = multiprocessing.Queue(queue_size)
...@@ -409,16 +414,23 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000): ...@@ -409,16 +414,23 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
sample = queue.get() sample = queue.get()
if sample is None: if sample is None:
finish_num += 1 finish_num += 1
elif sample == "":
raise ValueError("multiprocess reader raises an exception")
else: else:
yield sample yield sample
def _read_into_pipe(reader, conn): def _read_into_pipe(reader, conn):
for sample in reader(): try:
if sample is None: for sample in reader():
raise ValueError("sample has None!") if sample is None:
conn.send(json.dumps(sample)) raise ValueError("sample has None!")
conn.send(json.dumps(None)) conn.send(json.dumps(sample))
conn.close() conn.send(json.dumps(None))
conn.close()
except:
conn.send(json.dumps(""))
conn.close()
six.reraise(*sys.exc_info())
def pipe_reader(): def pipe_reader():
conns = [] conns = []
...@@ -442,6 +454,10 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000): ...@@ -442,6 +454,10 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
finish_num += 1 finish_num += 1
conn.close() conn.close()
conn_to_remove.append(conn) conn_to_remove.append(conn)
elif sample == "":
conn.close()
conn_to_remove.append(conn)
raise ValueError("multiprocess reader raises an exception")
else: else:
yield sample yield sample
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册