提交 f0a3fb6e 编写于 作者: Y Yu Yang

Using paddle::string in enforce

上级 ff000ae7
cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc)
cc_test(enforce_test SRCS enforce_test.cc)
......@@ -10,11 +10,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/string/printf.h>
#include <exception>
#include <sstream>
namespace paddle {
namespace platform {
namespace framework {
/**
* @brief Enforce exception. Inherits std::exception
......@@ -23,10 +24,9 @@ namespace platform {
*/
class EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& msg, const char* file, int fileline)
: file_(file), fileline_(fileline) {
EnforceNotMet(const std::string& msg, const char* file, int fileline) {
std::ostringstream sout;
sout << msg << " at [" << file_ << ":" << fileline_ << "];";
sout << msg << " at [" << file << ":" << fileline << "];";
all_msg_ = sout.str();
}
......@@ -34,52 +34,8 @@ class EnforceNotMet : public std::exception {
private:
std::string all_msg_;
const char* file_;
int fileline_;
};
namespace details {
inline void MakeStringInternal(std::ostringstream& stream) {}
template <typename T>
inline void MakeStringInternal(std::ostringstream& stream, T v) {
stream << v;
}
template <typename T, typename... ARGS>
inline void MakeStringInternal(std::ostringstream& stream, T v, ARGS... args) {
MakeStringInternal(stream, v);
MakeStringInternal(stream, args...);
};
/**
* @brief Make string will concat all args into a string.
*/
template <typename... ARGS>
inline std::string MakeString(ARGS... args) {
std::ostringstream sout;
details::MakeStringInternal(sout, args...);
return sout.str();
}
/**
* @brief special handle string
*/
template <>
inline std::string MakeString<std::string>(std::string str) {
return str;
}
/**
* @brief special handle const char*
*/
template <>
inline std::string MakeString<const char*>(const char* str) {
return std::string(str);
}
} // namespace details
// From https://stackoverflow.com/questions/30130930/
// __buildin_expect is in C++ 11 standard. Since the condition which enforced
// should be true in most situation, it will make the compiler generate faster
......@@ -93,11 +49,10 @@ inline std::string MakeString<const char*>(const char* str) {
* This macro take __VA_ARGS__, user can pass any type if that type can
* serialize to std::ostream
*/
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
::paddle::platform::details::MakeString(__VA_ARGS__), __FILE__, \
__LINE__); \
#define PADDLE_THROW(...) \
do { \
throw ::paddle::framework::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
/**
......@@ -110,5 +65,5 @@ inline std::string MakeString<const char*>(const char* str) {
} \
} while (0)
} // namespace platform
} // namespace framework
} // namespace paddle
......@@ -10,10 +10,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/platform/enforce.h>
#include <paddle/framework/enforce.h>
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok", 123, "now", 0.345);
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
size_t val = 1;
const size_t limit = 10;
PADDLE_ENFORCE(val < limit, "Enforce is OK too");
......@@ -22,8 +22,8 @@ TEST(ENFORCE, OK) {
TEST(ENFORCE, FAILED) {
bool in_catch = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok ", 123, " at all");
} catch (paddle::platform::EnforceNotMet err) {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::framework::EnforceNotMet err) {
in_catch = true;
std::string msg = "Enforce is not ok 123 at all";
const char* what = err.what();
......@@ -31,6 +31,5 @@ TEST(ENFORCE, FAILED) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
\ No newline at end of file
......@@ -2,4 +2,3 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_test(enforce_test SRCS enforce_test.cc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册