未验证 提交 4061aa64 编写于 作者: C Chen Weihang 提交者: GitHub

Polish ParallelExecutor exception process logic (#25449)

* polish pe exception process logic, test=develop

* fix unittest, test=develop

* add unittests, test=develop
上级 914ff10a
...@@ -107,21 +107,31 @@ class ExceptionHolder { ...@@ -107,21 +107,31 @@ class ExceptionHolder {
type_ = kNone; type_ = kNone;
} }
// NOTE: currently in PE, multiple exceptions may occur in multiple
// threads, and an exception that occurs later will overwrite one that
// occurred earlier, but what we want is the first triggered exception.
// The EOF exception is a lower-priority exception and can be overwritten,
// but any other exception, once caught, should not be overwritten.
void Catch(const platform::EnforceNotMet& exp) { void Catch(const platform::EnforceNotMet& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
exception_.reset(new platform::EnforceNotMet(exp)); if (exception_.get() == nullptr || type_ == kEOF) {
type_ = kEnforceNotMet; exception_.reset(new platform::EnforceNotMet(exp));
type_ = kEnforceNotMet;
} else {
VLOG(2) << "Non-first exception is discarded, the error message is"
<< exception_->what();
}
} }
void Catch(const memory::allocation::BadAlloc& exp) { void Catch(const memory::allocation::BadAlloc& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
// BadAlloc have the highest priority if (exception_.get() == nullptr || type_ == kEOF) {
if (exception_.get() != nullptr) { exception_.reset(new paddle::memory::allocation::BadAlloc(exp));
VLOG(2) << "exception is reset by BadAlloc, the original error message is" type_ = kBadAlloc;
} else {
VLOG(2) << "Non-first exception is discarded, the error message is"
<< exception_->what(); << exception_->what();
} }
exception_.reset(new paddle::memory::allocation::BadAlloc(exp));
type_ = kBadAlloc;
} }
void Catch(const platform::EOFException& exp) { void Catch(const platform::EOFException& exp) {
...@@ -138,10 +148,12 @@ class ExceptionHolder { ...@@ -138,10 +148,12 @@ class ExceptionHolder {
void Catch(const std::exception& exp) { void Catch(const std::exception& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
// std::exception will not cover anything if (exception_.get() == nullptr || type_ == kEOF) {
if (exception_.get() == nullptr) {
exception_.reset(new std::exception(exp)); exception_.reset(new std::exception(exp));
type_ = kBaseException; type_ = kBaseException;
} else {
VLOG(2) << "Non-first exception is discarded, the error message is"
<< exception_->what();
} }
} }
......
...@@ -24,6 +24,29 @@ namespace details { ...@@ -24,6 +24,29 @@ namespace details {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
TEST(ExceptionHolderTester, TestEnforceNotMetCatch) {
ExceptionHolder exception_holder;
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
bool catch_enforce_not_met = false;
try {
exception_holder.ReThrow();
} catch (platform::EnforceNotMet& ex) {
catch_enforce_not_met = true;
} catch (...) {
catch_enforce_not_met = false;
}
ASSERT_TRUE(catch_enforce_not_met);
}
TEST(ExceptionHolderTester, TestBadAllocCatch) { TEST(ExceptionHolderTester, TestBadAllocCatch) {
ExceptionHolder exception_holder; ExceptionHolder exception_holder;
...@@ -70,15 +93,24 @@ TEST(ExceptionHolderTester, TestBaseExpceptionCatch) { ...@@ -70,15 +93,24 @@ TEST(ExceptionHolderTester, TestBaseExpceptionCatch) {
ASSERT_TRUE(catch_base_exception); ASSERT_TRUE(catch_base_exception);
} }
TEST(ExceptionHolderTester, TestBadAllocCatchReplace) { TEST(ExceptionHolderTester, TestExceptionReplace) {
ExceptionHolder exception_holder; ExceptionHolder exception_holder;
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try { try {
throw std::exception(); throw std::exception();
} catch (...) { } catch (...) {
exception_holder.Catch(std::current_exception()); exception_holder.Catch(std::current_exception());
} }
ASSERT_TRUE(exception_holder.IsCaught()); ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BaseException"); ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try { try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0); throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
...@@ -86,13 +118,31 @@ TEST(ExceptionHolderTester, TestBadAllocCatchReplace) { ...@@ -86,13 +118,31 @@ TEST(ExceptionHolderTester, TestBadAllocCatchReplace) {
exception_holder.Catch(std::current_exception()); exception_holder.Catch(std::current_exception());
} }
ASSERT_TRUE(exception_holder.IsCaught()); ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc"); ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
try { try {
throw platform::EOFException("eof test", "test_file", 0); throw platform::EOFException("eof test", "test_file", 0);
} catch (...) { } catch (...) {
exception_holder.Catch(std::current_exception()); exception_holder.Catch(std::current_exception());
} }
ASSERT_EQ(exception_holder.Type(), "EnforceNotMet");
exception_holder.Clear();
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
try {
throw platform::EnforceNotMet("enforce not met test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc"); ASSERT_EQ(exception_holder.Type(), "BadAlloc");
} }
......
...@@ -269,7 +269,14 @@ void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) { ...@@ -269,7 +269,14 @@ void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
void FastThreadedSSAGraphExecutor::ExecutionFinal( void FastThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) { std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it"; VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops); // NOTE: If a new exception occurs in this ClearFetchOp operation, the
// exception that was triggered first will be lost and never rethrown.
// Therefore, the cleanup operation should only be performed when an EOF
// exception is caught. If any other exception was triggered, ClearFetchOp
// should not be run.
if (exception_.Type() == "EOF") {
ClearFetchOp(graph_, fetch_ops);
}
exception_.ReThrow(); exception_.ReThrow();
} }
......
...@@ -36,7 +36,7 @@ OpHandleBase::~OpHandleBase() PADDLE_MAY_THROW { ...@@ -36,7 +36,7 @@ OpHandleBase::~OpHandleBase() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
for (auto &ev : events_) { for (auto &ev : events_) {
if (ev.second) { if (ev.second) {
PADDLE_ENFORCE(cudaEventDestroy(ev.second)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
} }
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册