From 3fab4f65a46f6393de1238b808445dbbb0c3fc33 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 3 Jul 2018 00:01:45 +0800 Subject: [PATCH] Add EOFException to represent EOF in C++ reader --- .../details/data_balance_op_handle.cc | 2 +- .../details/threaded_ssa_graph_executor.cc | 21 ++++++++++++++++--- .../details/threaded_ssa_graph_executor.h | 2 +- paddle/fluid/operators/read_op.cc | 2 +- paddle/fluid/platform/enforce.h | 16 +++++++++++++- paddle/fluid/pybind/exception.cc | 3 +++ .../tests/unittests/test_data_balance.py | 6 ++---- .../tests/unittests/test_multi_file_reader.py | 3 +-- .../tests/unittests/test_multi_pass_reader.py | 3 +-- .../tests/unittests/test_recordio_reader.py | 3 +-- 10 files changed, 44 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index b914851fe..d07235df5 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -62,7 +62,7 @@ std::vector> DataBalanceOpHandle::GetBalancePlan( } if (total_size < device_num) { // No enough data. - PADDLE_THROW("There is no next data."); + PADDLE_THROW_EOF(); } std::sort(size_device_vec.begin(), size_device_vec.end(), [](const std::array &a, const std::array &b) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index b1706eb12..99b10254a 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -98,9 +98,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( if (timeout) { std::lock_guard l(exception_mu_); if (exception_) { - auto exp = *exception_; - exception_.reset(); - throw exp; + std::exception *exp = exception_.get(); + if (dynamic_cast(exp)) { + auto e = *static_cast(exp); + exception_.reset(); + throw e; + } else if (dynamic_cast(exp)) { + auto e = *static_cast(exp); + exception_.reset(); + throw e; + } else { + LOG(FATAL) << "Unknown exception."; + } } else { continue; } @@ -199,6 +208,12 @@ void ThreadedSSAGraphExecutor::RunOp( running_ops_--; ready_var_q->Extend(op->Outputs()); VLOG(10) << op << " " << op->Name() << "Signal posted"; + } catch (platform::EOFException ex) { + std::lock_guard l(exception_mu_); + // EOFException will not cover up existing EnforceNotMet. + if (exception_.get() == nullptr) { + exception_.reset(new platform::EOFException(ex)); + } } catch (platform::EnforceNotMet ex) { std::lock_guard l(exception_mu_); exception_.reset(new platform::EnforceNotMet(ex)); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 90430be99..c69e0487e 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -57,7 +57,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { std::vector places_; platform::DeviceContextPool fetch_ctxs_; std::mutex exception_mu_; - std::unique_ptr exception_; + std::unique_ptr exception_; std::atomic running_ops_; void InsertPendingOp(std::unordered_map *pending_ops, diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 60e4eb757..695d7ea83 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -68,7 +68,7 @@ class ReadOp : public framework::OperatorBase { reader->ReadNext(&ins); if (ins.empty()) { if (Attr("throw_eof_exp")) { - PADDLE_THROW("There is no next data."); + PADDLE_THROW_EOF(); } else { ins.resize(out_arg_names.size()); for (auto& tensor : ins) { diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 70bc9c4e8..3790dd135 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -73,7 +73,7 @@ struct EnforceNotMet : public std::exception { } catch (const std::exception& exp) { std::ostringstream sout; - sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; + sout << string::Sprintf("'%s' at [%s:%d]", exp.what(), f, l) << std::endl; sout << "PaddlePaddle Call Stacks: " << std::endl; void* call_stack[TRACE_STACK_LIMIT]; @@ -102,6 +102,15 @@ struct EnforceNotMet : public std::exception { const char* what() const noexcept { return err_str_.c_str(); } }; +struct EOFException : public std::exception { + std::string err_str_; + EOFException(const char* err_msg, const char* f, int l) { + err_str_ = string::Sprintf("'%s' at [%s:%d]", err_msg, f, l); + } + + const char* what() const noexcept { return err_str_.c_str(); } +}; + // Because most enforce conditions would evaluate to true, we can use // __builtin_expect to instruct the C++ compiler to generate code that // always forces branch prediction of true. @@ -242,6 +251,11 @@ inline void throw_on_error(T e) { #define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__); #endif +#define PADDLE_THROW_EOF() \ + do { \ + throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \ + __LINE__); \ + } while (false) /* * Some enforce helpers here, usage: * int a = 1; diff --git a/paddle/fluid/pybind/exception.cc b/paddle/fluid/pybind/exception.cc index 08a2f185e..831f30e35 100644 --- a/paddle/fluid/pybind/exception.cc +++ b/paddle/fluid/pybind/exception.cc @@ -18,10 +18,13 @@ namespace paddle { namespace pybind { void BindException(pybind11::module* m) { + static pybind11::exception eof(*m, "EOFException"); static pybind11::exception exc(*m, "EnforceNotMet"); pybind11::register_exception_translator([](std::exception_ptr p) { try { if (p) std::rethrow_exception(p); + } catch (const platform::EOFException& e) { + eof(e.what()); } catch (const platform::EnforceNotMet& e) { exc(e.what()); } diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index b558d7c2e..cffa3329a 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -118,8 +118,7 @@ class TestDataBalance(unittest.TestCase): try: image_val, label_val = parallel_exe.run(fetch_list, return_numpy=True) - except fluid.core.EnforceNotMet as ex: - self.assertIn("There is no next data.", ex.message) + except fluid.core.EOFException: break ins_num = image_val.shape[0] broadcasted_label = np.ones( @@ -162,8 +161,7 @@ class TestDataBalance(unittest.TestCase): try: ins_tensor, label_tensor = parallel_exe.run( fetch_list, return_numpy=False) - except fluid.core.EnforceNotMet as ex: - self.assertIn("There is no next data.", ex.message) + except fluid.core.EOFException: break ins_val = np.array(ins_tensor) diff --git a/python/paddle/fluid/tests/unittests/test_multi_file_reader.py b/python/paddle/fluid/tests/unittests/test_multi_file_reader.py index 3f940203b..dbd510e64 100644 --- a/python/paddle/fluid/tests/unittests/test_multi_file_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multi_file_reader.py @@ -64,8 +64,7 @@ class TestMultipleReader(unittest.TestCase): while True: try: img_val, = exe.run(fetch_list=[img]) - except fluid.core.EnforceNotMet as ex: - self.assertIn("There is no next data.", ex.message) + except fluid.core.EOFException: break batch_count += 1 self.assertLessEqual(img_val.shape[0], self.batch_size) diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py index 52e7cc1ff..7fc9f5504 100644 --- a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py @@ -59,8 +59,7 @@ class TestMultipleReader(unittest.TestCase): while True: try: img_val, = exe.run(fetch_list=[img]) - except fluid.core.EnforceNotMet as ex: - self.assertIn("There is no next data.", ex.message) + except fluid.core.EOFException: break batch_count += 1 self.assertLessEqual(img_val.shape[0], self.batch_size) diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index f32050014..69a522e27 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -68,8 +68,7 @@ class TestRecordIO(unittest.TestCase): while True: try: tmp, = exe.run(fetch_list=[avg_loss]) - except fluid.core.EnforceNotMet as ex: - self.assertIn("There is no next data.", ex.message) + except fluid.core.EOFException: break avg_loss_np.append(tmp) -- GitLab