提交 2f6e7a58 编写于 作者: G gangliao 提交者: GitHub

Merge pull request #3055 from reyoung/feature/unify_enforce_error_to_make_it_catchable

Make PADDLE_ENFORCE and PADDLE_THROW catchable
...@@ -69,7 +69,7 @@ TEST(OpKernel, all) { ...@@ -69,7 +69,7 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx); net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
} }
TEST(AddBackwardOp, TestGradOp) { TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>(); auto net = std::make_shared<PlainNet>();
......
...@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
caught = false; caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) { ...@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
...@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) { ...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false; bool caught = false;
try { try {
src_tensor.data<double>(); src_tensor.data<double>();
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
...@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) {
bool caught = false; bool caught = false;
try { try {
dst_tensor.ShareDataWith<float>(src_tensor); dst_tensor.ShareDataWith<float>(src_tensor);
} catch (std::runtime_error& err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
......
...@@ -36,6 +36,21 @@ limitations under the License. */ ...@@ -36,6 +36,21 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct EnforceNotMet : public std::exception {
std::exception_ptr exp_;
std::string err_str_;
EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
try {
std::rethrow_exception(exp_);
} catch (const std::exception& exp) {
err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
}
}
const char* what() const noexcept { return err_str_.c_str(); }
};
// Because most enforce conditions would evaluate to true, we can use // Because most enforce conditions would evaluate to true, we can use
// __builtin_expect to instruct the C++ compiler to generate code that // __builtin_expect to instruct the C++ compiler to generate code that
// always forces branch prediction of true. // always forces branch prediction of true.
...@@ -52,9 +67,7 @@ template <typename... Args> ...@@ -52,9 +67,7 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) { int stat, const Args&... args) {
if (UNLIKELY(!(stat))) { if (UNLIKELY(!(stat))) {
throw std::runtime_error( throw std::runtime_error(string::Sprintf(args...));
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
} }
} }
...@@ -64,12 +77,8 @@ template <typename... Args> ...@@ -64,12 +77,8 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) { cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) { if (UNLIKELY(e)) {
// clang-format off throw thrust::system_error(e, thrust::cuda_category(),
throw thrust::system_error( string::Sprintf(args...));
e, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -77,12 +86,8 @@ template <typename... Args> ...@@ -77,12 +86,8 @@ template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) { curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) { if (stat != CURAND_STATUS_SUCCESS) {
// clang-format off throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
throw thrust::system_error( string::Sprintf(args...));
cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if (stat == CUDNN_STATUS_SUCCESS) { if (stat == CUDNN_STATUS_SUCCESS) {
return; return;
} else { } else {
// clang-format off throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
throw std::runtime_error( string::Sprintf(args...));
platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
// clang-format on
} }
} }
...@@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
err = "CUBLAS: license error, "; err = "CUBLAS: license error, ";
} }
throw std::runtime_error(err + string::Sprintf(args...) + throw std::runtime_error(err + string::Sprintf(args...));
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
} }
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
do { \ do { \
throw std::runtime_error( \ throw ::paddle::platform::EnforceNotMet( \
string::Sprintf(__VA_ARGS__) + \ std::make_exception_ptr( \
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \ std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0) } while (0)
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(...) \
do { \ do { \
::paddle::platform::throw_on_error(__VA_ARGS__); \ try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0) } while (0)
} // namespace platform } // namespace platform
......
...@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) { ...@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
bool in_catch = false; bool in_catch = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (const std::runtime_error& error) { } catch (paddle::platform::EnforceNotMet error) {
// your error handling code here // your error handling code here
in_catch = true; in_catch = true;
std::string msg = "Enforce is not ok 123 at all"; std::string msg = "Enforce is not ok 123 at all";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册