diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 7abe2ab89e0798672149e28a8d02f7a58b6de3ea..8435410564b2770b89fa28dbf96c9421335ca889 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -3,3 +3,4 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(must_check_test SRCS must_check_test.cc) +cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h new file mode 100644 index 0000000000000000000000000000000000000000..e501e80c5579bf51f8a423886c75e238cafb29c1 --- /dev/null +++ b/paddle/platform/enforce.h @@ -0,0 +1,116 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#pragma once +#include +#include + +namespace paddle { +namespace platform { + +/** + * @brief Enforce exception. Inherits std::exception + * + * All enforce condition not met, will throw an EnforceNotMet exception. + */ +class EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& msg, const char* file, int fileline) + : file_(file), fileline_(fileline) { + std::ostringstream sout; + sout << msg << " at [" << file_ << ":" << fileline_ << "];"; + all_msg_ = sout.str(); + } + + const char* what() const noexcept override { return all_msg_.c_str(); } + + private: + std::string all_msg_; + const char* file_; + int fileline_; +}; + +namespace details { + +inline void MakeStringInternal(std::ostringstream& stream) {} + +template +inline void MakeStringInternal(std::ostringstream& stream, T v) { + stream << v; +} + +template +inline void MakeStringInternal(std::ostringstream& stream, T v, ARGS... args) { + MakeStringInternal(stream, v); + MakeStringInternal(stream, args...); +}; + +/** + * @brief Make string will concat all args into a string. + */ +template +inline std::string MakeString(ARGS... args) { + std::ostringstream sout; + details::MakeStringInternal(sout, args...); + return sout.str(); +} + +/** + * @brief special handle string + */ +template <> +inline std::string MakeString(std::string str) { + return str; +} + +/** + * @brief special handle const char* + */ +template <> +inline std::string MakeString(const char* str) { + return std::string(str); +} +} // namespace details + +// From https://stackoverflow.com/questions/30130930/ +// __buildin_expect is in C++ 11 standard. Since the condition which enforced +// should be true in most situation, it will make the compiler generate faster +// code by adding `UNLIKELY` macro. +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) + +/** + * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & + * __LINE__ + * + * This macro take __VA_ARGS__, user can pass any type if that type can + * serialize to std::ostream + */ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + ::paddle::platform::details::MakeString(__VA_ARGS__), __FILE__, \ + __LINE__); \ + } while (0) + +/** + * @brief Enforce a condition, otherwise throw an EnforceNotMet + */ +#define PADDLE_ENFORCE(condition, ...) \ + do { \ + if (UNLIKELY(!(condition))) { \ + PADDLE_THROW(__VA_ARGS__); \ + } \ + } while (0) + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d3e945e972e643f61fd96a963b82a195358d2f75 --- /dev/null +++ b/paddle/platform/enforce_test.cc @@ -0,0 +1,25 @@ +#include +#include + +TEST(ENFORCE, OK) { + PADDLE_ENFORCE(true, "Enforce is ok", 123, "now", 0.345); + size_t val = 1; + const size_t limit = 10; + PADDLE_ENFORCE(val < limit, "Enforce is OK too"); +} + +TEST(ENFORCE, FAILED) { + bool in_catch = false; + try { + PADDLE_ENFORCE(false, "Enforce is not ok ", 123, " at all"); + } catch (paddle::platform::EnforceNotMet err) { + in_catch = true; + std::string msg = "Enforce is not ok 123 at all"; + const char* what = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} \ No newline at end of file