未验证 提交 40effc61 编写于 作者: Z Zeng Jinle 提交者: GitHub

Refine py_reader exit (#20331)

* refine py_reader exit, test=develop

* fix multiprocess_reader exception unittest, test=develop

* increase code coverage for legacy fluid.layers.py_reader, test=develop
上级 a9c8bdad
...@@ -33,7 +33,7 @@ class BlockingQueue { ...@@ -33,7 +33,7 @@ class BlockingQueue {
// doesn't support GPU and it implements on buffered blocking queue. // doesn't support GPU and it implements on buffered blocking queue.
public: public:
explicit BlockingQueue(size_t capacity, bool speed_test_mode = false) explicit BlockingQueue(size_t capacity, bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode), closed_(false) { : capacity_(capacity), speed_test_mode_(speed_test_mode) {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
capacity_, static_cast<size_t>(0), capacity_, static_cast<size_t>(0),
"The capacity of a reader::BlockingQueue must be greater than 0."); "The capacity of a reader::BlockingQueue must be greater than 0.");
...@@ -41,7 +41,9 @@ class BlockingQueue { ...@@ -41,7 +41,9 @@ class BlockingQueue {
bool Send(const T& elem) { bool Send(const T& elem) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; }); send_cv_.wait(
lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; });
EnforceNotKilled();
if (closed_) { if (closed_) {
VLOG(5) VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue."; << "WARNING: Sending an element to a closed reader::BlokcingQueue.";
...@@ -55,7 +57,9 @@ class BlockingQueue { ...@@ -55,7 +57,9 @@ class BlockingQueue {
bool Send(T&& elem) { bool Send(T&& elem) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; }); send_cv_.wait(
lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; });
EnforceNotKilled();
if (closed_) { if (closed_) {
VLOG(5) VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue."; << "WARNING: Sending an element to a closed reader::BlokcingQueue.";
...@@ -69,7 +73,9 @@ class BlockingQueue { ...@@ -69,7 +73,9 @@ class BlockingQueue {
bool Receive(T* elem) { bool Receive(T* elem) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
receive_cv_.wait(lock, [&] { return !queue_.empty() || closed_; }); receive_cv_.wait(lock,
[&] { return !queue_.empty() || closed_ || killed_; });
EnforceNotKilled();
if (!queue_.empty()) { if (!queue_.empty()) {
PADDLE_ENFORCE_NOT_NULL(elem); PADDLE_ENFORCE_NOT_NULL(elem);
*elem = queue_.front(); *elem = queue_.front();
...@@ -87,6 +93,7 @@ class BlockingQueue { ...@@ -87,6 +93,7 @@ class BlockingQueue {
void ReOpen() { void ReOpen() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
EnforceNotKilled();
VLOG(1) << "reopen queue"; VLOG(1) << "reopen queue";
closed_ = false; closed_ = false;
std::deque<T> new_deque; std::deque<T> new_deque;
...@@ -118,10 +125,27 @@ class BlockingQueue { ...@@ -118,10 +125,27 @@ class BlockingQueue {
return queue_.size(); return queue_.size();
} }
// Marks the queue as both closed and killed, then notifies every thread
// waiting on send_cv_ or receive_cv_. The flag updates and the notifications
// are all performed while mutex_ is held.
void Kill() {
std::lock_guard<std::mutex> lock(mutex_);
VLOG(1) << "kill queue";
closed_ = true;
killed_ = true;
// Wake all blocked senders and receivers; their wait predicates include
// killed_, so they stop waiting once the flag is set.
send_cv_.notify_all();
receive_cv_.notify_all();
}
private:
// Checks via PADDLE_ENFORCE_NE that killed_ is not true; the failure message
// attributes the kill to an exception raised by the data reader.
// NOTE(review): this helper does not lock mutex_ itself. The visible call
// sites (Send, Receive, ReOpen) invoke it after acquiring the lock — confirm
// that any other caller does the same.
inline void EnforceNotKilled() {
PADDLE_ENFORCE_NE(
killed_, true,
"Blocking queue is killed because the data reader raises an exception");
}
private: private:
size_t capacity_; size_t capacity_;
bool speed_test_mode_; bool speed_test_mode_;
bool closed_; bool closed_{false};
bool killed_{false}; // the queue is broken since exception raises
std::deque<T> queue_; std::deque<T> queue_;
mutable std::mutex mutex_; mutable std::mutex mutex_;
......
...@@ -26,7 +26,10 @@ BufferedReader::~BufferedReader() { ...@@ -26,7 +26,10 @@ BufferedReader::~BufferedReader() {
VLOG(1) << "~BufferedReader"; VLOG(1) << "~BufferedReader";
reader_->Shutdown(); reader_->Shutdown();
while (!position_.empty()) { while (!position_.empty()) {
position_.front().wait(); auto &front = position_.front();
if (front.valid()) {
front.wait();
}
position_.pop(); position_.pop();
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -65,6 +65,8 @@ class LoDTensorBlockingQueue { ...@@ -65,6 +65,8 @@ class LoDTensorBlockingQueue {
inline bool IsClosed() const { return queue_.IsClosed(); } inline bool IsClosed() const { return queue_.IsClosed(); }
inline void Kill() { queue_.Kill(); }
private: private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_; BlockingQueue<std::vector<framework::LoDTensor>> queue_;
}; };
......
...@@ -860,6 +860,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -860,6 +860,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("size", &LoDTensorBlockingQueue::Size) .def("size", &LoDTensorBlockingQueue::Size)
.def("capacity", &LoDTensorBlockingQueue::Cap) .def("capacity", &LoDTensorBlockingQueue::Cap)
.def("close", &LoDTensorBlockingQueue::Close) .def("close", &LoDTensorBlockingQueue::Close)
.def("kill", &LoDTensorBlockingQueue::Kill)
.def("is_closed", &LoDTensorBlockingQueue::IsClosed); .def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue", m.def("init_lod_tensor_blocking_queue",
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/reader_py.h"
#include <exception>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -30,12 +31,6 @@ namespace pybind { ...@@ -30,12 +31,6 @@ namespace pybind {
namespace py = pybind11; namespace py = pybind11;
static void RaiseStopIterationException() {
VLOG(2) << "Raise StopIteration Exception in Python";
py::gil_scoped_acquire guard;
throw py::stop_iteration();
}
class MultiDeviceFeedReader { class MultiDeviceFeedReader {
public: public:
using ResultDictList = using ResultDictList =
...@@ -71,17 +66,12 @@ class MultiDeviceFeedReader { ...@@ -71,17 +66,12 @@ class MultiDeviceFeedReader {
futures_.resize(dst_places.size()); futures_.resize(dst_places.size());
ret_.resize(dst_places.size()); ret_.resize(dst_places.size());
exceptions_.assign(dst_places.size(), nullptr);
ReadAsync(); ReadAsync();
} }
ResultDictList ReadNext() { ResultDictList ReadNext() {
bool success = WaitFutures(); CheckNextStatus();
if (!success) {
RaiseStopIterationException();
return {};
}
ResultDictList result(ret_.size()); ResultDictList result(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) { for (size_t i = 0; i < ret_.size(); ++i) {
for (size_t j = 0; j < names_.size(); ++j) { for (size_t j = 0; j < names_.size(); ++j) {
...@@ -93,12 +83,7 @@ class MultiDeviceFeedReader { ...@@ -93,12 +83,7 @@ class MultiDeviceFeedReader {
} }
ResultList ReadNextList() { ResultList ReadNextList() {
bool success = WaitFutures(); CheckNextStatus();
if (!success) {
RaiseStopIterationException();
return {};
}
ResultList result; ResultList result;
result.reserve(ret_.size()); result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) { for (size_t i = 0; i < ret_.size(); ++i) {
...@@ -120,12 +105,32 @@ class MultiDeviceFeedReader { ...@@ -120,12 +105,32 @@ class MultiDeviceFeedReader {
} }
private: private:
bool WaitFutures() { enum Status {
bool success = true; kSuccess = 0, // Read next data successfully
for (auto &f : futures_) { kEOF = 1, // Reach EOF
success &= f.get(); kException = 2 // Exception raises when reading
};
Status WaitFutures(std::exception_ptr *excep) {
bool is_success = true;
*excep = nullptr;
for (size_t i = 0; i < futures_.size(); ++i) {
auto each_status = futures_[i].get();
if (UNLIKELY(each_status != Status::kSuccess)) {
is_success = false;
if (UNLIKELY(each_status == Status::kException)) {
PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
*excep = exceptions_[i];
exceptions_[i] = nullptr;
}
}
}
if (UNLIKELY(*excep)) {
return Status::kException;
} else {
return is_success ? Status::kSuccess : Status::kEOF;
} }
return success;
} }
void Shutdown() { void Shutdown() {
...@@ -139,19 +144,44 @@ class MultiDeviceFeedReader { ...@@ -139,19 +144,44 @@ class MultiDeviceFeedReader {
void ReadAsync() { void ReadAsync() {
for (size_t i = 0; i < readers_.size(); ++i) { for (size_t i = 0; i < readers_.size(); ++i) {
futures_[i] = pool_->enqueue([this, i] { futures_[i] = pool_->enqueue([this, i] {
readers_[i]->ReadNext(&ret_[i]); try {
return !ret_[i].empty(); readers_[i]->ReadNext(&ret_[i]);
return ret_[i].empty() ? Status::kEOF : Status::kSuccess;
} catch (...) {
exceptions_[i] = std::current_exception();
return Status::kException;
}
}); });
} }
} }
// Waits for the pending reads through WaitFutures() and converts the outcome
// into control flow for the caller:
//   - if an exception was captured, it is rethrown unchanged;
//   - if the status is kEOF, py::stop_iteration is thrown;
//   - otherwise the status is enforced to be kSuccess and the call returns.
void CheckNextStatus() {
std::exception_ptr excep;
Status status = WaitFutures(&excep);
if (UNLIKELY(excep)) {
// A non-null exception pointer is expected to come with kException.
PADDLE_ENFORCE_EQ(status, Status::kException);
std::rethrow_exception(excep);
}
if (UNLIKELY(status == Status::kEOF)) {
VLOG(2) << "Raise StopIteration Exception in Python";
// The GIL is acquired before the pybind11 stop_iteration exception is thrown.
py::gil_scoped_acquire guard;
throw py::stop_iteration();
}
PADDLE_ENFORCE_EQ(status, Status::kSuccess);
}
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_; std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
std::vector<std::string> names_; std::vector<std::string> names_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<std::unique_ptr<framework::ReaderHolder>> readers_; std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
std::vector<std::future<bool>> futures_; std::vector<std::future<Status>> futures_;
std::vector<std::exception_ptr> exceptions_;
std::vector<std::vector<framework::LoDTensor>> ret_; std::vector<std::vector<framework::LoDTensor>> ret_;
}; };
......
...@@ -469,7 +469,7 @@ def _py_reader(capacity, ...@@ -469,7 +469,7 @@ def _py_reader(capacity,
break break
feed_queue.close() feed_queue.close()
except Exception as ex: except Exception as ex:
feed_queue.close() feed_queue.kill()
logging.warn('Your decorated reader has raised an exception!') logging.warn('Your decorated reader has raised an exception!')
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
......
...@@ -482,7 +482,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -482,7 +482,7 @@ class GeneratorLoader(DataLoaderBase):
self._queue.close() self._queue.close()
self._thread = None self._thread = None
except Exception as ex: except Exception as ex:
self._queue.close() self._queue.kill()
self._thread = None self._thread = None
logging.warn('Your reader has raised an exception!') logging.warn('Your reader has raised an exception!')
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
......
...@@ -20,6 +20,10 @@ import six ...@@ -20,6 +20,10 @@ import six
import sys import sys
class ReaderException(Exception):
pass
class TestMultiprocessReaderException(unittest.TestCase): class TestMultiprocessReaderException(unittest.TestCase):
def setUp(self): def setUp(self):
self.use_pipe = False self.use_pipe = False
...@@ -31,10 +35,13 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -31,10 +35,13 @@ class TestMultiprocessReaderException(unittest.TestCase):
else: else:
return [fluid.CPUPlace()] return [fluid.CPUPlace()]
def main_impl(self, place, iterable): def main_impl(self, place, iterable, use_legacy_py_reader):
sample_num = 40
batch_size = 4
def fake_reader(): def fake_reader():
def __impl__(): def __impl__():
for _ in range(40): for _ in range(sample_num):
if not self.raise_exception: if not self.raise_exception:
yield list( yield list(
np.random.uniform( np.random.uniform(
...@@ -45,37 +52,54 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -45,37 +52,54 @@ class TestMultiprocessReaderException(unittest.TestCase):
return __impl__ return __impl__
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
image = fluid.layers.data(name='image', dtype='float32', shape=[10]) if not use_legacy_py_reader:
image = fluid.data(
name='image', dtype='float32', shape=[None, 10])
reader = fluid.io.PyReader( reader = fluid.io.PyReader(
feed_list=[image], capacity=2, iterable=iterable) feed_list=[image], capacity=2, iterable=iterable)
else:
reader = fluid.layers.py_reader(
capacity=2, shapes=[[-1, 10], ], dtypes=['float32', ])
image = fluid.layers.read_file(reader)
image_p_1 = image + 1 image_p_1 = image + 1
decorated_reader = multiprocess_reader( decorated_reader = multiprocess_reader(
[fake_reader(), fake_reader()], use_pipe=self.use_pipe) [fake_reader(), fake_reader()], use_pipe=self.use_pipe)
if isinstance(place, fluid.CUDAPlace): if use_legacy_py_reader:
reader.decorate_sample_generator( reader.decorate_paddle_reader(
decorated_reader, batch_size=4, places=fluid.cuda_places()) fluid.io.batch(
decorated_reader, batch_size=batch_size))
else: else:
reader.decorate_sample_generator( if isinstance(place, fluid.CUDAPlace):
decorated_reader, batch_size=4, places=fluid.cpu_places()) reader.decorate_sample_generator(
decorated_reader,
batch_size=batch_size,
places=fluid.cuda_places())
else:
reader.decorate_sample_generator(
decorated_reader,
batch_size=batch_size,
places=fluid.cpu_places())
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
batch_num = int(sample_num * 2 / batch_size)
if iterable: if iterable:
for _ in range(3): for _ in range(3):
num = 0 num = 0
for data in reader(): try:
exe.run(feed=data, fetch_list=[image_p_1]) for data in reader():
num += 1 exe.run(feed=data, fetch_list=[image_p_1])
if not self.raise_exception: num += 1
self.assertEquals(num, 20) self.assertEquals(num, batch_num)
else: except fluid.core.EnforceNotMet as ex:
self.assertEquals(num, 0) self.assertEquals(num, 0)
raise ValueError('Reader raises exception') raise ReaderException()
else: else:
for _ in range(3): for _ in range(3):
num = 0 num = 0
...@@ -86,22 +110,26 @@ class TestMultiprocessReaderException(unittest.TestCase): ...@@ -86,22 +110,26 @@ class TestMultiprocessReaderException(unittest.TestCase):
num += 1 num += 1
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
if not self.raise_exception: self.assertFalse(self.raise_exception)
self.assertEquals(num, 20) self.assertEquals(num, batch_num)
else: except fluid.core.EnforceNotMet as ex:
self.assertEquals(num, 0) self.assertTrue(self.raise_exception)
raise ValueError('Reader raises exception') self.assertEquals(num, 0)
raise ReaderException()
def test_main(self): def test_main(self):
for p in self.places(): for p in self.places():
for iterable in [False, True]: for iterable in [False, True]:
try: use_legacy_py_reader_range = [False
with fluid.scope_guard(fluid.Scope()): ] if iterable else [False, True]
self.main_impl(p, iterable) for use_legacy_py_reader in use_legacy_py_reader_range:
try:
with fluid.scope_guard(fluid.Scope()):
self.main_impl(p, iterable, use_legacy_py_reader)
self.assertTrue(not self.raise_exception) self.assertTrue(not self.raise_exception)
except ValueError: except ReaderException:
self.assertTrue(self.raise_exception) self.assertTrue(self.raise_exception)
class TestCase1(TestMultiprocessReaderException): class TestCase1(TestMultiprocessReaderException):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册