未验证 提交 af9066e8 编写于 作者: Z Zhou Wei 提交者: GitHub

[Custom OP]add PD_THROW and PD_CHECK for User Error message (#31253)

* [Custom OP]add PD_THROW and PD_CHECK for User error message

* PD_THROW and PD_CHECK, fix comment

* fix Windows error message

* fix Windows error message

* fix CI
上级 8c94d8cb
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "ext_dispatch.h" // NOLINT #include "ext_dispatch.h" // NOLINT
#include "ext_dtype.h" // NOLINT #include "ext_dtype.h" // NOLINT
#include "ext_exception.h" // NOLINT
#include "ext_op_meta_info.h" // NOLINT #include "ext_op_meta_info.h" // NOLINT
#include "ext_place.h" // NOLINT #include "ext_place.h" // NOLINT
#include "ext_tensor.h" // NOLINT #include "ext_tensor.h" // NOLINT
...@@ -14,7 +14,8 @@ limitations under the License. */ ...@@ -14,7 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "ext_dtype.h" // NOLINT #include "ext_dtype.h" // NOLINT
#include "ext_exception.h" // NOLINT
namespace paddle { namespace paddle {
...@@ -32,19 +33,18 @@ namespace paddle { ...@@ -32,19 +33,18 @@ namespace paddle {
///////// Floating Dispatch Marco /////////// ///////// Floating Dispatch Marco ///////////
#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \ [&] { \
const auto& __dtype__ = TYPE; \ const auto& __dtype__ = TYPE; \
switch (__dtype__) { \ switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \ PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \ __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \ PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \ __VA_ARGS__) \
default: \ default: \
throw std::runtime_error("function " #NAME \ PD_THROW("function " #NAME " is not implemented for data type `", \
" not implemented for data type `" + \ ::paddle::ToString(__dtype__), "`"); \
::paddle::ToString(__dtype__) + "`"); \ } \
} \
}() }()
///////// Integral Dispatch Marco /////////// ///////// Integral Dispatch Marco ///////////
...@@ -63,9 +63,8 @@ namespace paddle { ...@@ -63,9 +63,8 @@ namespace paddle {
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \ __VA_ARGS__) \
default: \ default: \
throw std::runtime_error("function " #NAME \ PD_THROW("function " #NAME " is not implemented for data type `" + \
" not implemented for data type `" + \ ::paddle::ToString(__dtype__) + "`"); \
::paddle::ToString(__dtype__) + "`"); \
} \ } \
}() }()
...@@ -89,9 +88,8 @@ namespace paddle { ...@@ -89,9 +88,8 @@ namespace paddle {
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \ PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \ __VA_ARGS__) \
default: \ default: \
throw std::runtime_error("function " #NAME \ PD_THROW("function " #NAME " is not implemented for data type `" + \
" not implemented for data type `" + \ ::paddle::ToString(__dtype__) + "`"); \
::paddle::ToString(__dtype__) + "`"); \
} \ } \
}() }()
......
...@@ -14,9 +14,10 @@ limitations under the License. */ ...@@ -14,9 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <stdexcept>
#include <string> #include <string>
#include "ext_exception.h" // NOLINT
namespace paddle { namespace paddle {
enum class DataType { enum class DataType {
...@@ -50,7 +51,7 @@ inline std::string ToString(DataType dtype) { ...@@ -50,7 +51,7 @@ inline std::string ToString(DataType dtype) {
case DataType::FLOAT64: case DataType::FLOAT64:
return "double"; return "double";
default: default:
throw std::runtime_error("Unsupported paddle enum data type."); PD_THROW("Unsupported paddle enum data type.");
} }
} }
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <sstream>
#include <string>
namespace paddle {
//////////////// Exception handling and Error Message /////////////////
#if !defined(_WIN32)
#define PD_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#define PD_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#else
#define PD_UNLIKELY(expr) (expr)
#define PD_LIKELY(expr) (expr)
#endif
struct PD_Exception : public std::exception {
public:
template <typename... Args>
explicit PD_Exception(const std::string& msg, const char* file, int line,
const char* default_msg) {
std::ostringstream sout;
if (msg.empty()) {
sout << default_msg << "\n [" << file << ":" << line << "]";
} else {
sout << msg << "\n [" << file << ":" << line << "]";
}
err_msg_ = sout.str();
}
const char* what() const noexcept override { return err_msg_.c_str(); }
private:
std::string err_msg_;
};
class ErrorMessage {
public:
template <typename... Args>
explicit ErrorMessage(const Args&... args) {
build_string(args...);
}
void build_string() { oss << ""; }
template <typename T>
void build_string(const T& t) {
oss << t;
}
template <typename T, typename... Args>
void build_string(const T& t, const Args&... args) {
build_string(t);
build_string(args...);
}
std::string to_string() { return oss.str(); }
private:
std::ostringstream oss;
};
#if defined _WIN32
#define HANDLE_THE_ERROR try {
#define END_HANDLE_THE_ERROR \
} \
catch (const std::exception& e) { \
std::cerr << e.what() << std::endl; \
throw e; \
}
#else
#define HANDLE_THE_ERROR
#define END_HANDLE_THE_ERROR
#endif
#define PD_CHECK(COND, ...) \
do { \
if (PD_UNLIKELY(!(COND))) { \
auto __message__ = ::paddle::ErrorMessage(__VA_ARGS__).to_string(); \
throw ::paddle::PD_Exception(__message__, __FILE__, __LINE__, \
"Expected " #COND \
", but it's not satisfied."); \
} \
} while (0)
#define PD_THROW(...) \
do { \
auto __message__ = ::paddle::ErrorMessage(__VA_ARGS__).to_string(); \
throw ::paddle::PD_Exception(__message__, __FILE__, __LINE__, \
"An error occured."); \
} while (0)
} // namespace paddle
...@@ -21,8 +21,9 @@ limitations under the License. */ ...@@ -21,8 +21,9 @@ limitations under the License. */
#include <boost/any.hpp> #include <boost/any.hpp>
#include "ext_dll_decl.h" // NOLINT #include "ext_dll_decl.h" // NOLINT
#include "ext_tensor.h" // NOLINT #include "ext_exception.h" // NOLINT
#include "ext_tensor.h" // NOLINT
/** /**
* Op Meta Info Related Define. * Op Meta Info Related Define.
...@@ -47,26 +48,6 @@ using Tensor = paddle::Tensor; ...@@ -47,26 +48,6 @@ using Tensor = paddle::Tensor;
classname& operator=(const classname&) = delete; \ classname& operator=(const classname&) = delete; \
classname& operator=(classname&&) = delete classname& operator=(classname&&) = delete
#if defined _WIN32
#define HANDLE_THE_ERROR try {
#define END_HANDLE_THE_ERROR \
} \
catch (const std::exception& e) { \
std::cerr << e.what() << std::endl; \
throw e; \
}
#else
#define HANDLE_THE_ERROR
#define END_HANDLE_THE_ERROR
#endif
#define PD_THROW(err_msg) \
do { \
HANDLE_THE_ERROR \
throw std::runtime_error(err_msg); \
END_HANDLE_THE_ERROR \
} while (0)
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \ struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
......
...@@ -23,6 +23,8 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120) ...@@ -23,6 +23,8 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120) set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
if(NOT LINUX) if(NOT LINUX)
return() return()
endif() endif()
......
...@@ -79,7 +79,7 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) { ...@@ -79,7 +79,7 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
} else if (x.place() == paddle::PlaceType::kGPU) { } else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_forward(x); return relu_cuda_forward(x);
} else { } else {
throw std::runtime_error("Not implemented."); PD_THROW("Not implemented.");
} }
} }
...@@ -92,7 +92,7 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x, ...@@ -92,7 +92,7 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
} else if (x.place() == paddle::PlaceType::kGPU) { } else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_backward(x, out, grad_out); return relu_cuda_backward(x, out, grad_out);
} else { } else {
throw std::runtime_error("Not implemented."); PD_THROW("Not implemented.");
} }
} }
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/extension/include/ext_exception.h"
TEST(PD_THROW, empty) {
bool caught_exception = false;
try {
PD_THROW();
} catch (const std::exception& e) {
caught_exception = true;
std::string err_msg = e.what();
EXPECT_TRUE(err_msg.find("An error occured.") != std::string::npos);
#if _WIN32
EXPECT_TRUE(err_msg.find("tests\\custom_op\\test_check_error.cc:20") !=
std::string::npos);
#else
EXPECT_TRUE(
err_msg.find(
"python/paddle/fluid/tests/custom_op/test_check_error.cc:20") !=
std::string::npos);
#endif
}
EXPECT_TRUE(caught_exception);
}
TEST(PD_THROW, non_empty) {
bool caught_exception = false;
try {
PD_THROW("PD_THROW returns ",
false,
". DataType of ",
1,
" is INT. ",
"DataType of ",
0.23,
" is FLOAT. ");
} catch (const std::exception& e) {
caught_exception = true;
std::string err_msg = e.what();
EXPECT_TRUE(err_msg.find("PD_THROW returns 0. DataType of 1 is INT. ") !=
std::string::npos);
#if _WIN32
EXPECT_TRUE(err_msg.find("tests\\custom_op\\test_check_error.cc") !=
std::string::npos);
#else
EXPECT_TRUE(
err_msg.find(
"python/paddle/fluid/tests/custom_op/test_check_error.cc") !=
std::string::npos);
#endif
}
EXPECT_TRUE(caught_exception);
}
TEST(PD_CHECK, OK) {
PD_CHECK(true);
PD_CHECK(true, "PD_CHECK returns ", true, "now");
const size_t a = 1;
const size_t b = 10;
PD_CHECK(a < b);
PD_CHECK(a < b, "PD_CHECK returns ", true, a, "should < ", b);
}
TEST(PD_CHECK, FAILED) {
bool caught_exception = false;
try {
PD_CHECK(false);
} catch (const std::exception& e) {
caught_exception = true;
std::string err_msg = e.what();
EXPECT_TRUE(err_msg.find("Expected false, but it's not satisfied.") !=
std::string::npos);
#if _WIN32
EXPECT_TRUE(err_msg.find("tests\\custom_op\\test_check_error.cc") !=
std::string::npos);
#else
EXPECT_TRUE(
err_msg.find(
"python/paddle/fluid/tests/custom_op/test_check_error.cc") !=
std::string::npos);
#endif
}
EXPECT_TRUE(caught_exception);
caught_exception = false;
try {
PD_CHECK(false,
"PD_CHECK returns ",
false,
". DataType of ",
1,
" is INT. ",
"DataType of ",
0.23,
" is FLOAT. ");
} catch (const std::exception& e) {
caught_exception = true;
std::string err_msg = e.what();
EXPECT_TRUE(err_msg.find("PD_CHECK returns 0. DataType of 1 is INT. ") !=
std::string::npos);
#if _WIN32
EXPECT_TRUE(err_msg.find("tests\\custom_op\\test_check_error.cc") !=
std::string::npos);
#else
EXPECT_TRUE(
err_msg.find(
"python/paddle/fluid/tests/custom_op/test_check_error.cc") !=
std::string::npos);
#endif
}
EXPECT_TRUE(caught_exception);
const size_t a = 1;
const size_t b = 10;
caught_exception = false;
try {
PD_CHECK(a > b);
} catch (const std::exception& e) {
caught_exception = true;
std::string err_msg = e.what();
EXPECT_TRUE(err_msg.find("Expected a > b, but it's not satisfied.") !=
std::string::npos);
#if _WIN32
EXPECT_TRUE(err_msg.find("tests\\custom_op\\test_check_error.cc") !=
std::string::npos);
#else
EXPECT_TRUE(
err_msg.find(
"python/paddle/fluid/tests/custom_op/test_check_error.cc") !=
std::string::npos);
#endif
}
EXPECT_TRUE(caught_exception);
const size_t c = 123;
const float d = 0.345;
caught_exception = false;
try {
PD_CHECK(c < d, "PD_CHECK returns ", false, ", because ", c, " > ", d);
} catch (const std::exception& e) {
caught_exception = true;
std::string err_msg = e.what();
EXPECT_TRUE(err_msg.find("PD_CHECK returns 0, because 123 > 0.345") !=
std::string::npos);
#if _WIN32
EXPECT_TRUE(err_msg.find("tests\\custom_op\\test_check_error.cc") !=
std::string::npos);
#else
EXPECT_TRUE(
err_msg.find(
"python/paddle/fluid/tests/custom_op/test_check_error.cc") !=
std::string::npos);
#endif
}
EXPECT_TRUE(caught_exception);
}
...@@ -19,7 +19,7 @@ import paddle ...@@ -19,7 +19,7 @@ import paddle
import numpy as np import numpy as np
from paddle.utils.cpp_extension import load, get_build_directory from paddle.utils.cpp_extension import load, get_build_directory
from paddle.utils.cpp_extension.extension_utils import run_cmd from paddle.utils.cpp_extension.extension_utils import run_cmd
from utils import paddle_includes, extra_compile_args from utils import paddle_includes, extra_compile_args, IS_WINDOWS
from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static
# Because Windows don't use docker, the shared lib already exists in the # Because Windows don't use docker, the shared lib already exists in the
...@@ -84,6 +84,40 @@ class TestJITLoad(unittest.TestCase): ...@@ -84,6 +84,40 @@ class TestJITLoad(unittest.TestCase):
"custom op x grad: {},\n paddle api x grad: {}".format( "custom op x grad: {},\n paddle api x grad: {}".format(
x_grad, pd_x_grad)) x_grad, pd_x_grad))
def test_exception(self):
caught_exception = False
try:
x = np.random.uniform(-1, 1, [4, 8]).astype('int32')
custom_relu_dynamic(custom_module.custom_relu, 'cpu', 'float32', x)
except OSError as e:
caught_exception = True
self.assertTrue(
"function \"relu_cpu_forward\" is not implemented for data type `int32_t`"
in str(e))
if IS_WINDOWS:
self.assertTrue(
r"python\paddle\fluid\tests\custom_op\custom_relu_op.cc:48"
in str(e))
else:
self.assertTrue(
"python/paddle/fluid/tests/custom_op/custom_relu_op.cc:48"
in str(e))
self.assertTrue(caught_exception)
caught_exception = False
try:
x = np.random.uniform(-1, 1, [4, 8]).astype('int64')
custom_relu_dynamic(custom_module.custom_relu, 'gpu', 'float32', x)
except OSError as e:
caught_exception = True
self.assertTrue(
"function \"relu_cuda_forward_kernel\" is not implemented for data type `int64_t`"
in str(e))
self.assertTrue(
"python/paddle/fluid/tests/custom_op/custom_relu_op.cu:49" in
str(e))
self.assertTrue(caught_exception)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册