diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index b9535ee493892667c99d3da35c8ab8462c4e589e..4d8bd101258664f6cafd71784ae070e0cb8b9215 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -101,6 +101,8 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle) +cc_test(exception_holder_test SRCS exception_holder_test.cc ) + set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass fuse_elewise_add_act_pass fuse_bn_act_pass diff --git a/paddle/fluid/framework/details/exception_holder.h b/paddle/fluid/framework/details/exception_holder.h index 6bb5a2954b17beba1703f9eacd4bf36bf58faa8c..25c62877bf7127fee7df80bc30546e733eb4286f 100644 --- a/paddle/fluid/framework/details/exception_holder.h +++ b/paddle/fluid/framework/details/exception_holder.h @@ -15,9 +15,11 @@ #pragma once #include +#include // NOLINT #include #include "glog/logging.h" +#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -29,15 +31,16 @@ class ExceptionHolder { void Catch(std::exception_ptr eptr) { try { std::rethrow_exception(eptr); + } catch (memory::allocation::BadAlloc& exp) { + Catch(exp); } catch (platform::EOFException& exp) { Catch(exp); } catch (platform::EnforceNotMet& exp) { Catch(exp); } catch (std::exception& ex) { - PADDLE_THROW(platform::errors::Fatal( - "Unknown std::exception caught:\n%s.", ex.what())); + Catch(ex); } catch (...) { - PADDLE_THROW(platform::errors::Fatal("Unknown exception caught.")); + LOG(FATAL) << "Unknown exception caught."; } } @@ -59,6 +62,15 @@ class ExceptionHolder { auto e = *static_cast(exception_.get()); throw e; } + case kBadAlloc: { + auto e = *static_cast( + exception_.get()); + throw e; + } + case kBaseException: { + auto e = *static_cast(exception_.get()); + throw e; + } } ClearImpl(); } @@ -79,6 +91,12 @@ class ExceptionHolder { case kEOF: { return "EOF"; } + case kBadAlloc: { + return "BadAlloc"; + } + case kBaseException: { + return "BaseException"; + } } return "unknown"; } @@ -95,16 +113,39 @@ class ExceptionHolder { type_ = kEnforceNotMet; } + void Catch(const memory::allocation::BadAlloc& exp) { + std::lock_guard lock(mu_); + // BadAlloc have the highest priority + if (exception_.get() != nullptr) { + VLOG(2) << "exception is reset by BadAlloc, the original error message is" + << exception_->what(); + } + exception_.reset(new paddle::memory::allocation::BadAlloc(exp)); + type_ = kBadAlloc; + } + void Catch(const platform::EOFException& exp) { std::lock_guard lock(mu_); // EOFException will not cover up existing EnforceNotMet. if (exception_.get() == nullptr) { exception_.reset(new platform::EOFException(exp)); type_ = kEOF; + } else { + VLOG(2) << "EOFException is skip, the error message of EOFException is " + << exception_->what(); + } + } + + void Catch(const std::exception& exp) { + std::lock_guard lock(mu_); + // std::exception will not cover anything + if (exception_.get() == nullptr) { + exception_.reset(new std::exception(exp)); + type_ = kBaseException; } } - enum ExceptionType { kNone, kEnforceNotMet, kEOF }; + enum ExceptionType { kNone, kEnforceNotMet, kEOF, kBadAlloc, kBaseException }; ExceptionType type_{kNone}; std::unique_ptr exception_; diff --git a/paddle/fluid/framework/details/exception_holder_test.cc b/paddle/fluid/framework/details/exception_holder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..48a250a331dc61d45394b894765cadb814243685 --- /dev/null +++ b/paddle/fluid/framework/details/exception_holder_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/exception_holder.h" +#include +#include +#include "gtest/gtest.h" +#include "paddle/fluid/memory/allocation/allocator.h" + +namespace paddle { +namespace framework { +namespace details { +namespace f = paddle::framework; +namespace p = paddle::platform; + +TEST(ExceptionHolderTester, TestBadAllocCatch) { + ExceptionHolder exception_holder; + + try { + throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0); + } catch (...) { + exception_holder.Catch(std::current_exception()); + } + ASSERT_TRUE(exception_holder.IsCaught()); + ASSERT_EQ(exception_holder.Type(), "BadAlloc"); + + bool catch_bad_alloc = false; + try { + exception_holder.ReThrow(); + } catch (memory::allocation::BadAlloc& ex) { + catch_bad_alloc = true; + } catch (...) { + catch_bad_alloc = false; + } + + ASSERT_TRUE(catch_bad_alloc); +} + +TEST(ExceptionHolderTester, TestBaseExpceptionCatch) { + ExceptionHolder exception_holder; + + try { + throw std::exception(); + } catch (...) { + exception_holder.Catch(std::current_exception()); + } + ASSERT_TRUE(exception_holder.IsCaught()); + ASSERT_EQ(exception_holder.Type(), "BaseException"); + + bool catch_base_exception = false; + try { + exception_holder.ReThrow(); + } catch (std::exception& ex) { + catch_base_exception = true; + } catch (...) { + catch_base_exception = false; + } + + ASSERT_TRUE(catch_base_exception); +} + +TEST(ExceptionHolderTester, TestBadAllocCatchReplace) { + ExceptionHolder exception_holder; + try { + throw std::exception(); + } catch (...) { + exception_holder.Catch(std::current_exception()); + } + ASSERT_TRUE(exception_holder.IsCaught()); + ASSERT_EQ(exception_holder.Type(), "BaseException"); + + try { + throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0); + } catch (...) { + exception_holder.Catch(std::current_exception()); + } + ASSERT_TRUE(exception_holder.IsCaught()); + ASSERT_EQ(exception_holder.Type(), "BadAlloc"); + + try { + throw platform::EOFException("eof test", "test_file", 0); + } catch (...) { + exception_holder.Catch(std::current_exception()); + } + ASSERT_EQ(exception_holder.Type(), "BadAlloc"); +} + +} // namespace details +} // namespace framework +} // namespace paddle