未验证 提交 70d7d07f 编写于 作者: H hong 提交者: GitHub

catch bad alloc exception (#25140)

* cat bad alloc exception; test=develop

* add unitest; test=develop

* move bad alloc catch to the first place; test=develop

* polish error message; test=develop

* polish error message; test=develop

* add mutex header; test=develop
上级 30185efd
...@@ -101,6 +101,8 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo ...@@ -101,6 +101,8 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle) cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
cc_test(exception_holder_test SRCS exception_holder_test.cc )
set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass fuse_elewise_add_act_pass fuse_bn_act_pass
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <string> #include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -29,15 +31,16 @@ class ExceptionHolder { ...@@ -29,15 +31,16 @@ class ExceptionHolder {
void Catch(std::exception_ptr eptr) { void Catch(std::exception_ptr eptr) {
try { try {
std::rethrow_exception(eptr); std::rethrow_exception(eptr);
} catch (memory::allocation::BadAlloc& exp) {
Catch(exp);
} catch (platform::EOFException& exp) { } catch (platform::EOFException& exp) {
Catch(exp); Catch(exp);
} catch (platform::EnforceNotMet& exp) { } catch (platform::EnforceNotMet& exp) {
Catch(exp); Catch(exp);
} catch (std::exception& ex) { } catch (std::exception& ex) {
PADDLE_THROW(platform::errors::Fatal( Catch(ex);
"Unknown std::exception caught:\n%s.", ex.what()));
} catch (...) { } catch (...) {
PADDLE_THROW(platform::errors::Fatal("Unknown exception caught.")); LOG(FATAL) << "Unknown exception caught.";
} }
} }
...@@ -59,6 +62,15 @@ class ExceptionHolder { ...@@ -59,6 +62,15 @@ class ExceptionHolder {
auto e = *static_cast<platform::EOFException*>(exception_.get()); auto e = *static_cast<platform::EOFException*>(exception_.get());
throw e; throw e;
} }
case kBadAlloc: {
auto e = *static_cast<paddle::memory::allocation::BadAlloc*>(
exception_.get());
throw e;
}
case kBaseException: {
auto e = *static_cast<std::exception*>(exception_.get());
throw e;
}
} }
ClearImpl(); ClearImpl();
} }
...@@ -79,6 +91,12 @@ class ExceptionHolder { ...@@ -79,6 +91,12 @@ class ExceptionHolder {
case kEOF: { case kEOF: {
return "EOF"; return "EOF";
} }
case kBadAlloc: {
return "BadAlloc";
}
case kBaseException: {
return "BaseException";
}
} }
return "unknown"; return "unknown";
} }
...@@ -95,16 +113,39 @@ class ExceptionHolder { ...@@ -95,16 +113,39 @@ class ExceptionHolder {
type_ = kEnforceNotMet; type_ = kEnforceNotMet;
} }
void Catch(const memory::allocation::BadAlloc& exp) {
std::lock_guard<std::mutex> lock(mu_);
// BadAlloc have the highest priority
if (exception_.get() != nullptr) {
VLOG(2) << "exception is reset by BadAlloc, the original error message is"
<< exception_->what();
}
exception_.reset(new paddle::memory::allocation::BadAlloc(exp));
type_ = kBadAlloc;
}
void Catch(const platform::EOFException& exp) { void Catch(const platform::EOFException& exp) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
// EOFException will not cover up existing EnforceNotMet. // EOFException will not cover up existing EnforceNotMet.
if (exception_.get() == nullptr) { if (exception_.get() == nullptr) {
exception_.reset(new platform::EOFException(exp)); exception_.reset(new platform::EOFException(exp));
type_ = kEOF; type_ = kEOF;
} else {
VLOG(2) << "EOFException is skip, the error message of EOFException is "
<< exception_->what();
}
}
void Catch(const std::exception& exp) {
std::lock_guard<std::mutex> lock(mu_);
// std::exception will not cover anything
if (exception_.get() == nullptr) {
exception_.reset(new std::exception(exp));
type_ = kBaseException;
} }
} }
enum ExceptionType { kNone, kEnforceNotMet, kEOF }; enum ExceptionType { kNone, kEnforceNotMet, kEOF, kBadAlloc, kBaseException };
ExceptionType type_{kNone}; ExceptionType type_{kNone};
std::unique_ptr<std::exception> exception_; std::unique_ptr<std::exception> exception_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/exception_holder.h"
#include <memory>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace framework {
namespace details {
namespace f = paddle::framework;
namespace p = paddle::platform;
TEST(ExceptionHolderTester, TestBadAllocCatch) {
ExceptionHolder exception_holder;
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
bool catch_bad_alloc = false;
try {
exception_holder.ReThrow();
} catch (memory::allocation::BadAlloc& ex) {
catch_bad_alloc = true;
} catch (...) {
catch_bad_alloc = false;
}
ASSERT_TRUE(catch_bad_alloc);
}
TEST(ExceptionHolderTester, TestBaseExpceptionCatch) {
ExceptionHolder exception_holder;
try {
throw std::exception();
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BaseException");
bool catch_base_exception = false;
try {
exception_holder.ReThrow();
} catch (std::exception& ex) {
catch_base_exception = true;
} catch (...) {
catch_base_exception = false;
}
ASSERT_TRUE(catch_base_exception);
}
TEST(ExceptionHolderTester, TestBadAllocCatchReplace) {
ExceptionHolder exception_holder;
try {
throw std::exception();
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BaseException");
try {
throw memory::allocation::BadAlloc("bad alloc test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_TRUE(exception_holder.IsCaught());
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
try {
throw platform::EOFException("eof test", "test_file", 0);
} catch (...) {
exception_holder.Catch(std::current_exception());
}
ASSERT_EQ(exception_holder.Type(), "BadAlloc");
}
} // namespace details
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册