未验证 提交 40effc61 编写于 作者: Z Zeng Jinle 提交者: GitHub

Refine py_reader exit (#20331)

* refine py_reader exit, test=develop

* fix multiprocess_reader exception unittest, test=develop

* increase code coverage for legacy fluid.layers.py_reader, test=develop
上级 a9c8bdad
......@@ -33,7 +33,7 @@ class BlockingQueue {
// doesn't support GPU and it implements on buffered blocking queue.
public:
explicit BlockingQueue(size_t capacity, bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode), closed_(false) {
: capacity_(capacity), speed_test_mode_(speed_test_mode) {
PADDLE_ENFORCE_GT(
capacity_, static_cast<size_t>(0),
"The capacity of a reader::BlockingQueue must be greater than 0.");
......@@ -41,7 +41,9 @@ class BlockingQueue {
bool Send(const T& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
send_cv_.wait(
lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; });
EnforceNotKilled();
if (closed_) {
VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
......@@ -55,7 +57,9 @@ class BlockingQueue {
bool Send(T&& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
send_cv_.wait(
lock, [&] { return queue_.size() < capacity_ || closed_ || killed_; });
EnforceNotKilled();
if (closed_) {
VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
......@@ -69,7 +73,9 @@ class BlockingQueue {
bool Receive(T* elem) {
std::unique_lock<std::mutex> lock(mutex_);
receive_cv_.wait(lock, [&] { return !queue_.empty() || closed_; });
receive_cv_.wait(lock,
[&] { return !queue_.empty() || closed_ || killed_; });
EnforceNotKilled();
if (!queue_.empty()) {
PADDLE_ENFORCE_NOT_NULL(elem);
*elem = queue_.front();
......@@ -87,6 +93,7 @@ class BlockingQueue {
void ReOpen() {
std::lock_guard<std::mutex> lock(mutex_);
EnforceNotKilled();
VLOG(1) << "reopen queue";
closed_ = false;
std::deque<T> new_deque;
......@@ -118,10 +125,27 @@ class BlockingQueue {
return queue_.size();
}
void Kill() {
std::lock_guard<std::mutex> lock(mutex_);
VLOG(1) << "kill queue";
closed_ = true;
killed_ = true;
send_cv_.notify_all();
receive_cv_.notify_all();
}
private:
inline void EnforceNotKilled() {
PADDLE_ENFORCE_NE(
killed_, true,
"Blocking queue is killed because the data reader raises an exception");
}
private:
size_t capacity_;
bool speed_test_mode_;
bool closed_;
bool closed_{false};
bool killed_{false}; // the queue is broken since exception raises
std::deque<T> queue_;
mutable std::mutex mutex_;
......
......@@ -26,7 +26,10 @@ BufferedReader::~BufferedReader() {
VLOG(1) << "~BufferedReader";
reader_->Shutdown();
while (!position_.empty()) {
position_.front().wait();
auto &front = position_.front();
if (front.valid()) {
front.wait();
}
position_.pop();
}
#ifdef PADDLE_WITH_CUDA
......
......@@ -65,6 +65,8 @@ class LoDTensorBlockingQueue {
inline bool IsClosed() const { return queue_.IsClosed(); }
inline void Kill() { queue_.Kill(); }
private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_;
};
......
......@@ -860,6 +860,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("size", &LoDTensorBlockingQueue::Size)
.def("capacity", &LoDTensorBlockingQueue::Cap)
.def("close", &LoDTensorBlockingQueue::Close)
.def("kill", &LoDTensorBlockingQueue::Kill)
.def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/pybind/reader_py.h"
#include <exception>
#include <memory>
#include <string>
#include <unordered_map>
......@@ -30,12 +31,6 @@ namespace pybind {
namespace py = pybind11;
static void RaiseStopIterationException() {
VLOG(2) << "Raise StopIteration Exception in Python";
py::gil_scoped_acquire guard;
throw py::stop_iteration();
}
class MultiDeviceFeedReader {
public:
using ResultDictList =
......@@ -71,17 +66,12 @@ class MultiDeviceFeedReader {
futures_.resize(dst_places.size());
ret_.resize(dst_places.size());
exceptions_.assign(dst_places.size(), nullptr);
ReadAsync();
}
ResultDictList ReadNext() {
bool success = WaitFutures();
if (!success) {
RaiseStopIterationException();
return {};
}
CheckNextStatus();
ResultDictList result(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
for (size_t j = 0; j < names_.size(); ++j) {
......@@ -93,12 +83,7 @@ class MultiDeviceFeedReader {
}
ResultList ReadNextList() {
bool success = WaitFutures();
if (!success) {
RaiseStopIterationException();
return {};
}
CheckNextStatus();
ResultList result;
result.reserve(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
......@@ -120,12 +105,32 @@ class MultiDeviceFeedReader {
}
private:
bool WaitFutures() {
bool success = true;
for (auto &f : futures_) {
success &= f.get();
enum Status {
kSuccess = 0, // Read next data successfully
kEOF = 1, // Reach EOF
kException = 2 // Exception raises when reading
};
Status WaitFutures(std::exception_ptr *excep) {
bool is_success = true;
*excep = nullptr;
for (size_t i = 0; i < futures_.size(); ++i) {
auto each_status = futures_[i].get();
if (UNLIKELY(each_status != Status::kSuccess)) {
is_success = false;
if (UNLIKELY(each_status == Status::kException)) {
PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
*excep = exceptions_[i];
exceptions_[i] = nullptr;
}
}
}
if (UNLIKELY(*excep)) {
return Status::kException;
} else {
return is_success ? Status::kSuccess : Status::kEOF;
}
return success;
}
void Shutdown() {
......@@ -139,19 +144,44 @@ class MultiDeviceFeedReader {
void ReadAsync() {
for (size_t i = 0; i < readers_.size(); ++i) {
futures_[i] = pool_->enqueue([this, i] {
try {
readers_[i]->ReadNext(&ret_[i]);
return !ret_[i].empty();
return ret_[i].empty() ? Status::kEOF : Status::kSuccess;
} catch (...) {
exceptions_[i] = std::current_exception();
return Status::kException;
}
});
}
}
void CheckNextStatus() {
std::exception_ptr excep;
Status status = WaitFutures(&excep);
if (UNLIKELY(excep)) {
PADDLE_ENFORCE_EQ(status, Status::kException);
std::rethrow_exception(excep);
}
if (UNLIKELY(status == Status::kEOF)) {
VLOG(2) << "Raise StopIteration Exception in Python";
py::gil_scoped_acquire guard;
throw py::stop_iteration();
}
PADDLE_ENFORCE_EQ(status, Status::kSuccess);
}
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
std::vector<std::string> names_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
std::vector<std::future<bool>> futures_;
std::vector<std::future<Status>> futures_;
std::vector<std::exception_ptr> exceptions_;
std::vector<std::vector<framework::LoDTensor>> ret_;
};
......
......@@ -469,7 +469,7 @@ def _py_reader(capacity,
break
feed_queue.close()
except Exception as ex:
feed_queue.close()
feed_queue.kill()
logging.warn('Your decorated reader has raised an exception!')
six.reraise(*sys.exc_info())
......
......@@ -482,7 +482,7 @@ class GeneratorLoader(DataLoaderBase):
self._queue.close()
self._thread = None
except Exception as ex:
self._queue.close()
self._queue.kill()
self._thread = None
logging.warn('Your reader has raised an exception!')
six.reraise(*sys.exc_info())
......
......@@ -20,6 +20,10 @@ import six
import sys
class ReaderException(Exception):
pass
class TestMultiprocessReaderException(unittest.TestCase):
def setUp(self):
self.use_pipe = False
......@@ -31,10 +35,13 @@ class TestMultiprocessReaderException(unittest.TestCase):
else:
return [fluid.CPUPlace()]
def main_impl(self, place, iterable):
def main_impl(self, place, iterable, use_legacy_py_reader):
sample_num = 40
batch_size = 4
def fake_reader():
def __impl__():
for _ in range(40):
for _ in range(sample_num):
if not self.raise_exception:
yield list(
np.random.uniform(
......@@ -45,37 +52,54 @@ class TestMultiprocessReaderException(unittest.TestCase):
return __impl__
with fluid.program_guard(fluid.Program(), fluid.Program()):
image = fluid.layers.data(name='image', dtype='float32', shape=[10])
if not use_legacy_py_reader:
image = fluid.data(
name='image', dtype='float32', shape=[None, 10])
reader = fluid.io.PyReader(
feed_list=[image], capacity=2, iterable=iterable)
else:
reader = fluid.layers.py_reader(
capacity=2, shapes=[[-1, 10], ], dtypes=['float32', ])
image = fluid.layers.read_file(reader)
image_p_1 = image + 1
decorated_reader = multiprocess_reader(
[fake_reader(), fake_reader()], use_pipe=self.use_pipe)
if use_legacy_py_reader:
reader.decorate_paddle_reader(
fluid.io.batch(
decorated_reader, batch_size=batch_size))
else:
if isinstance(place, fluid.CUDAPlace):
reader.decorate_sample_generator(
decorated_reader, batch_size=4, places=fluid.cuda_places())
decorated_reader,
batch_size=batch_size,
places=fluid.cuda_places())
else:
reader.decorate_sample_generator(
decorated_reader, batch_size=4, places=fluid.cpu_places())
decorated_reader,
batch_size=batch_size,
places=fluid.cpu_places())
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch_num = int(sample_num * 2 / batch_size)
if iterable:
for _ in range(3):
num = 0
try:
for data in reader():
exe.run(feed=data, fetch_list=[image_p_1])
num += 1
if not self.raise_exception:
self.assertEquals(num, 20)
else:
self.assertEquals(num, batch_num)
except fluid.core.EnforceNotMet as ex:
self.assertEquals(num, 0)
raise ValueError('Reader raises exception')
raise ReaderException()
else:
for _ in range(3):
num = 0
......@@ -86,21 +110,25 @@ class TestMultiprocessReaderException(unittest.TestCase):
num += 1
except fluid.core.EOFException:
reader.reset()
if not self.raise_exception:
self.assertEquals(num, 20)
else:
self.assertFalse(self.raise_exception)
self.assertEquals(num, batch_num)
except fluid.core.EnforceNotMet as ex:
self.assertTrue(self.raise_exception)
self.assertEquals(num, 0)
raise ValueError('Reader raises exception')
raise ReaderException()
def test_main(self):
for p in self.places():
for iterable in [False, True]:
use_legacy_py_reader_range = [False
] if iterable else [False, True]
for use_legacy_py_reader in use_legacy_py_reader_range:
try:
with fluid.scope_guard(fluid.Scope()):
self.main_impl(p, iterable)
self.main_impl(p, iterable, use_legacy_py_reader)
self.assertTrue(not self.raise_exception)
except ValueError:
except ReaderException:
self.assertTrue(self.raise_exception)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册