未验证 提交 87fbbd36 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] cache exception in child thread (#36692)

* cache exception in child thread

* add ut

* fix ut
上级 fe6dbdd3
......@@ -23,6 +23,8 @@
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
"Use inplace in new executor");
constexpr const char* kExceptionCaught = "ExceptionCaught";
namespace paddle {
namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it.
......@@ -42,6 +44,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
feed_names_ = feed_names;
exception_notifier_ = main_thread_blocker_.RegisterEvent(
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
// Step1: add feedop and fetchop to main_program
AddFetch(fetch_names);
......@@ -360,6 +365,8 @@ void InterpreterCore::ExecuteInstructionList(
async_work_queue_.PrepareAtomicVarRef(vec_meta_info_);
op_run_number_ = 0;
exception_holder_.Clear();
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
async_work_queue_.AddTask(vec_instr[i].type_,
......@@ -370,6 +377,11 @@ void InterpreterCore::ExecuteInstructionList(
auto event_id = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_id " << event_id;
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(4) << "Exception caught " << exception_holder_.Type();
exception_holder_.ReThrow();
}
PADDLE_ENFORCE_EQ(
op_run_number_.load(), vec_instr.size(),
platform::errors::Fatal(
......@@ -441,11 +453,34 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
instr_id = ready_ops.front();
ready_ops.pop();
auto& instr_node = vec_instruction_[instr_id];
platform::RecordEvent instruction_event(
instr_node.kernel_func_.operator_base_->Type());
auto* op = instr_node.kernel_func_.operator_base_;
platform::RecordEvent instruction_event(op->Type());
event_manager_.WaitEvent(instr_node, place_);
RunInstruction(instr_node);
try {
RunInstruction(instr_node);
} catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
} catch (platform::EOFException&) {
exception_holder_.Catch(std::current_exception());
} catch (std::exception& ex) {
LOG(WARNING) << op->Type() << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", "
<< ex.what();
exception_holder_.Catch(std::current_exception());
} catch (...) {
LOG(WARNING) << op->Type() << " raises an unknown exception";
exception_holder_.Catch(std::current_exception());
}
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(4) << "Exception caught";
if (exception_notifier_ != nullptr) {
exception_notifier_->NotifyEvent();
}
return;
}
event_manager_.RecordEvent(instr_node, place_);
op_run_number_.fetch_add(1, std::memory_order_relaxed);
......
......@@ -19,6 +19,7 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/new_executor/event_manager.h"
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
......@@ -26,6 +27,7 @@
#include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
......@@ -97,6 +99,8 @@ class InterpreterCore {
EventManager event_manager_;
EventsWaiter main_thread_blocker_;
interpretercore::AsyncWorkQueue async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
InterpreterCoreGarbageCollector gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
......
......@@ -22,6 +22,7 @@ namespace paddle {
namespace framework {
constexpr const char* kQueueEmptyEvent = "QueueEmpty";
class EventsWaiter;
struct WorkQueueOptions {
......
......@@ -248,5 +248,48 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
class TestException(unittest.TestCase):
def setUp(self):
self.place = paddle.CPUPlace()
def build_program(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
w = paddle.rand([10, 20])
ids = paddle.static.data(name="id", shape=[5], dtype='int64')
emb = paddle.nn.functional.embedding(
x=ids, weight=w, sparse=False, name="embedding")
return main_program, startup_program, emb
def _run(self, feeds):
paddle.seed(2020)
main_program, startup_program, fetch_vars = self.build_program()
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
for feed in feeds:
out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
return out
def run_new_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
return out
def test_exception(self):
feed = [{
'id': np.array([1, 2, 3, 4, 5]).astype(np.int64)
}, {
'id': np.array([1, 2, 3, 4, 11]).astype(np.int64)
}]
self.assertRaises(ValueError, self.run_new_executor, feed)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册