未验证 提交 ea7e5325 编写于 作者: G GaoWei8 提交者: GitHub

Refine PADDLE_ENFORCE (#25369)

* refine PADDLE_ENFORCE
test=develop
上级 ff7af219
...@@ -43,7 +43,9 @@ void SetNumThreads(int num_threads) { ...@@ -43,7 +43,9 @@ void SetNumThreads(int num_threads) {
platform::dynload::MKL_Set_Num_Threads(real_num_threads); platform::dynload::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(real_num_threads); omp_set_num_threads(real_num_threads);
#else #else
PADDLE_ENFORCE(false, "To be implemented."); PADDLE_THROW(platform::errors::Unimplemented(
"The library (except OPENBLAS, MKLML) is to be implemented, thus "
"number of threads can not be set."));
#endif #endif
} }
......
...@@ -26,13 +26,13 @@ void CudaProfilerInit(std::string output_file, std::string output_mode, ...@@ -26,13 +26,13 @@ void CudaProfilerInit(std::string output_file, std::string output_mode,
std::string config_file) { std::string config_file) {
PADDLE_ENFORCE(output_mode == "kvp" || output_mode == "csv"); PADDLE_ENFORCE(output_mode == "kvp" || output_mode == "csv");
cudaOutputMode_t mode = output_mode == "csv" ? cudaCSV : cudaKeyValuePair; cudaOutputMode_t mode = output_mode == "csv" ? cudaCSV : cudaKeyValuePair;
PADDLE_ENFORCE( PADDLE_ENFORCE_CUDA_SUCCESS(
cudaProfilerInitialize(config_file.c_str(), output_file.c_str(), mode)); cudaProfilerInitialize(config_file.c_str(), output_file.c_str(), mode));
} }
void CudaProfilerStart() { PADDLE_ENFORCE(cudaProfilerStart()); } void CudaProfilerStart() { PADDLE_ENFORCE_CUDA_SUCCESS(cudaProfilerStart()); }
void CudaProfilerStop() { PADDLE_ENFORCE(cudaProfilerStop()); } void CudaProfilerStop() { PADDLE_ENFORCE_CUDA_SUCCESS(cudaProfilerStop()); }
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -103,7 +103,8 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { ...@@ -103,7 +103,8 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
case PoolingMode::kMaximum: case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX; return CUDNN_POOLING_MAX;
default: default:
PADDLE_THROW("Unexpected pooling mode."); PADDLE_THROW(
platform::errors::Unimplemented("Unexpected CUDNN pooling mode."));
} }
} }
#else #else
...@@ -119,7 +120,8 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) { ...@@ -119,7 +120,8 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
case PoolingMode::kMaximum: case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX; return CUDNN_POOLING_MAX;
default: default:
PADDLE_THROW("Unexpected pooling mode."); PADDLE_THROW(
platform::errors::Unimplemented("Unexpected CUDNN pooling mode."));
} }
} }
#endif // CUDNN_VERSION < 6000 #endif // CUDNN_VERSION < 6000
...@@ -140,7 +142,8 @@ inline ActivationMode StringToActivationMode(const std::string& str) { ...@@ -140,7 +142,8 @@ inline ActivationMode StringToActivationMode(const std::string& str) {
} else if (str == "bandpass") { } else if (str == "bandpass") {
return ActivationMode::kBandPass; return ActivationMode::kBandPass;
} else { } else {
PADDLE_THROW("Unknown activation string: %s", str); PADDLE_THROW(
platform::errors::Unimplemented("Unknown activation string: %s.", str));
} }
} }
...@@ -208,7 +211,8 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat( ...@@ -208,7 +211,8 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat(
case DataLayout::kNDHWC: case DataLayout::kNDHWC:
return CUDNN_TENSOR_NHWC; // add, liyamei return CUDNN_TENSOR_NHWC; // add, liyamei
default: default:
PADDLE_THROW("Unknown cudnn equivalent for order"); PADDLE_THROW(platform::errors::Unimplemented(
"CUDNN has no equivalent dataLayout for input order."));
} }
return CUDNN_TENSOR_NCHW; return CUDNN_TENSOR_NCHW;
} }
...@@ -329,18 +333,28 @@ class ScopedConvolutionDescriptor { ...@@ -329,18 +333,28 @@ class ScopedConvolutionDescriptor {
inline cudnnConvolutionDescriptor_t descriptor( inline cudnnConvolutionDescriptor_t descriptor(
cudnnDataType_t type, const std::vector<int>& pads, cudnnDataType_t type, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations) { const std::vector<int>& strides, const std::vector<int>& dilations) {
PADDLE_ENFORCE_EQ(pads.size(), strides.size()); PADDLE_ENFORCE_EQ(pads.size(), strides.size(),
PADDLE_ENFORCE_EQ(pads.size(), dilations.size()); platform::errors::InvalidArgument(
"The size of pads and strides should be equal. But "
"received size of pads is %d, size of strides is %d.",
pads.size(), strides.size()));
PADDLE_ENFORCE_EQ(
pads.size(), dilations.size(),
platform::errors::InvalidArgument(
"The size of pads and dilations should be equal. But received size "
"of pads is %d, size of dilations is %d.",
pads.size(), dilations.size()));
#if !CUDNN_VERSION_MIN(6, 0, 0) #if !CUDNN_VERSION_MIN(6, 0, 0)
// cudnn v5 does not support dilation conv, the argument is called upscale // cudnn v5 does not support dilation conv, the argument is called upscale
// instead of dilations and it is must be one. // instead of dilations and it is must be one.
for (size_t i = 0; i < dilations.size(); ++i) { for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(dilations[i], 1,
dilations[i], 1, platform::errors::InvalidArgument(
"Dilations conv is not supported in this cuDNN version(%d.%d.%d).", "Dilations conv is not supported in this cuDNN "
"version(%d.%d.%d).",
CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100, CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100,
CUDNN_VERSION % 100); CUDNN_VERSION % 100));
} }
#endif #endif
...@@ -377,8 +391,17 @@ class ScopedPoolingDescriptor { ...@@ -377,8 +391,17 @@ class ScopedPoolingDescriptor {
const std::vector<int>& kernel, const std::vector<int>& kernel,
const std::vector<int>& pads, const std::vector<int>& pads,
const std::vector<int>& strides) { const std::vector<int>& strides) {
PADDLE_ENFORCE_EQ(kernel.size(), pads.size()); PADDLE_ENFORCE_EQ(kernel.size(), pads.size(),
PADDLE_ENFORCE_EQ(kernel.size(), strides.size()); platform::errors::InvalidArgument(
"The size of kernel and pads should be equal. But "
"received size of kernel is %d, size of pads is %d.",
kernel.size(), pads.size()));
PADDLE_ENFORCE_EQ(
kernel.size(), strides.size(),
platform::errors::InvalidArgument(
"The size of kernel and strides should be equal. But "
"received size of kernel is %d, size of strides is %d.",
kernel.size(), strides.size()));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetPoolingNdDescriptor( PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetPoolingNdDescriptor(
desc_, (GetPoolingMode(mode)), desc_, (GetPoolingMode(mode)),
CUDNN_PROPAGATE_NAN, // Always propagate nans. CUDNN_PROPAGATE_NAN, // Always propagate nans.
...@@ -456,8 +479,9 @@ class ScopedActivationDescriptor { ...@@ -456,8 +479,9 @@ class ScopedActivationDescriptor {
mode = CUDNN_ACTIVATION_TANH; mode = CUDNN_ACTIVATION_TANH;
break; break;
default: default:
PADDLE_THROW("unrecognized activation mode: %d .", PADDLE_THROW(platform::errors::Unimplemented(
static_cast<int>(activation_mode)); "Unrecognized CUDNN activation mode: %d.",
static_cast<int>(activation_mode)));
} }
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetActivationDescriptor( PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetActivationDescriptor(
desc_, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling)); desc_, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling));
......
...@@ -59,12 +59,11 @@ DeviceContextPool* DeviceContextPool::pool = nullptr; ...@@ -59,12 +59,11 @@ DeviceContextPool* DeviceContextPool::pool = nullptr;
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
auto it = device_contexts_.find(place); auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) { if (it == device_contexts_.end()) {
PADDLE_THROW( PADDLE_THROW(platform::errors::Unimplemented(
"Place %s is not supported, Please check that your paddle compiles " "Place %s is not supported. Please check that your paddle compiles "
"with WITH_GPU " "with WITH_GPU option or check that your train process hold the "
"option or check that your train process hold the correct gpu_id if " "correct gpu_id if you use Executor.",
"you use Executor", place));
place);
} }
return it->second.get().get(); return it->second.get().get();
} }
...@@ -84,7 +83,11 @@ inline void EmplaceDeviceContext( ...@@ -84,7 +83,11 @@ inline void EmplaceDeviceContext(
DeviceContextPool::DeviceContextPool( DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) { const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0); PADDLE_ENFORCE_GT(
places.size(), 0,
platform::errors::InvalidArgument("The number of platform places should "
"be larger than 0. But received %d.",
places.size()));
std::set<Place> set; std::set<Place> set;
for (auto& p : places) { for (auto& p : places) {
set.insert(p); set.insert(p);
...@@ -100,18 +103,18 @@ DeviceContextPool::DeviceContextPool( ...@@ -100,18 +103,18 @@ DeviceContextPool::DeviceContextPool(
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p); EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else #else
PADDLE_THROW( PADDLE_THROW(platform::errors::Unimplemented(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "'CUDAPlace is not supported. Please re-compile with WITH_GPU."
"option"); "option"));
#endif #endif
} else if (platform::is_cuda_pinned_place(p)) { } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>( EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
&device_contexts_, p); &device_contexts_, p);
#else #else
PADDLE_THROW( PADDLE_THROW(platform::errors::Unimplemented(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU " "'CUDAPlace' is not supported. Please re-compile with WITH_GPU."
"option"); "option"));
#endif #endif
} }
} }
......
...@@ -575,7 +575,8 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -575,7 +575,8 @@ class DeviceTracerImpl : public DeviceTracer {
} else if (platform::is_cuda_pinned_place(r.place)) { } else if (platform::is_cuda_pinned_place(r.place)) {
event->set_place(proto::MemEvent::CUDAPinnedPlace); event->set_place(proto::MemEvent::CUDAPinnedPlace);
} else { } else {
PADDLE_THROW("The current place is not supported."); PADDLE_THROW(platform::errors::Unimplemented(
"The current place is not supported."));
} }
event->set_alloc_in(r.alloc_in); event->set_alloc_in(r.alloc_in);
event->set_free_in(r.free_in); event->set_free_in(r.free_in);
......
...@@ -319,9 +319,11 @@ void* GetMKLMLDsoHandle() { ...@@ -319,9 +319,11 @@ void* GetMKLMLDsoHandle() {
void* GetOpDsoHandle(const std::string& dso_name) { void* GetOpDsoHandle(const std::string& dso_name) {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
PADDLE_THROW("Do not support Apple."); PADDLE_THROW(platform::errors::Unimplemented(
"Create custom cpp op outside framework do not support Apple."));
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
PADDLE_THROW("Do not support Windows."); PADDLE_THROW(platform::errors::Unimplemented(
"Create custom cpp op outside framework do not support Windows."));
#else #else
return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name); return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name);
#endif #endif
......
...@@ -20,37 +20,43 @@ limitations under the License. */ ...@@ -20,37 +20,43 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
TEST(ENFORCE, OK) { TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); PADDLE_ENFORCE(true, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE is ok %d now %f.", 123, 0.345));
size_t val = 1; size_t val = 1;
const size_t limit = 10; const size_t limit = 10;
PADDLE_ENFORCE(val < limit, "Enforce is OK too"); PADDLE_ENFORCE(val < limit, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE tests failed."));
} }
TEST(ENFORCE, FAILED) { TEST(ENFORCE, FAILED) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE won't work %d at all.", 123));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
EXPECT_TRUE(ex_msg.find("Enforce is not ok 123 at all") != EXPECT_TRUE(ex_msg.find("PADDLE_ENFORCE won't work 123 at all.") !=
std::string::npos); std::string::npos);
} }
EXPECT_TRUE(caught_exception); EXPECT_TRUE(caught_exception);
caught_exception = false; caught_exception = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok at all"); PADDLE_ENFORCE(false, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE won't work at all."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
EXPECT_TRUE(ex_msg.find("Enforce is not ok at all") != std::string::npos); EXPECT_TRUE(ex_msg.find("PADDLE_ENFORCE won't work at all.") !=
std::string::npos);
} }
EXPECT_TRUE(caught_exception); EXPECT_TRUE(caught_exception);
caught_exception = false; caught_exception = false;
try { try {
PADDLE_ENFORCE(false); PADDLE_ENFORCE(false, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE won't work at all."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
EXPECT_NE(std::string(error.what()).find(" at "), 0UL); EXPECT_NE(std::string(error.what()).find(" at "), 0UL);
...@@ -61,9 +67,11 @@ TEST(ENFORCE, FAILED) { ...@@ -61,9 +67,11 @@ TEST(ENFORCE, FAILED) {
TEST(ENFORCE, NO_ARG_OK) { TEST(ENFORCE, NO_ARG_OK) {
int a = 2; int a = 2;
int b = 2; int b = 2;
PADDLE_ENFORCE_EQ(a, b); PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_EQ tests failed."));
// test enforce with extra message. // test enforce with extra message.
PADDLE_ENFORCE_EQ(a, b, "some thing wrong %s", "info"); PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"Some %s wrong in PADDLE_ENFORCE_EQ.", "info"));
} }
TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
...@@ -71,7 +79,7 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { ...@@ -71,7 +79,7 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, 1 + 3, paddle::platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(a, 1 + 3, paddle::platform::errors::InvalidArgument(
"the result is not equal correct result.")); "The result is not equal correct result."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -86,7 +94,7 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { ...@@ -86,7 +94,7 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, 1 + 3, paddle::platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(a, 1 + 3, paddle::platform::errors::InvalidArgument(
"the result is not equal correct result.")); "The result is not equal correct result."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -98,15 +106,19 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { ...@@ -98,15 +106,19 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
} }
TEST(ENFORCE_NE, OK) { TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE(1, 2); PADDLE_ENFORCE_NE(1, 2, paddle::platform::errors::Unavailable(
PADDLE_ENFORCE_NE(1.0, 2UL); "PADDLE_ENFORCE_NE tests failed."));
PADDLE_ENFORCE_NE(1.0, 2UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_NE tests failed."));
} }
TEST(ENFORCE_NE, FAIL) { TEST(ENFORCE_NE, FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
// 2UL here to check data type compatible // 2UL here to check data type compatible
PADDLE_ENFORCE_NE(1.0, 1UL); PADDLE_ENFORCE_NE(1.0, 1UL,
paddle::platform::errors::Unavailable(
"Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -116,11 +128,15 @@ TEST(ENFORCE_NE, FAIL) { ...@@ -116,11 +128,15 @@ TEST(ENFORCE_NE, FAIL) {
EXPECT_TRUE(caught_exception); EXPECT_TRUE(caught_exception);
} }
TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); } TEST(ENFORCE_GT, OK) {
PADDLE_ENFORCE_GT(2, 1, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GT tests failed."));
}
TEST(ENFORCE_GT, FAIL) { TEST(ENFORCE_GT, FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GT(1, 2); PADDLE_ENFORCE_GT(1, 2, paddle::platform::errors::InvalidArgument(
"Expected 1 > 2, but received 1:1 <= 2:2."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -131,14 +147,18 @@ TEST(ENFORCE_GT, FAIL) { ...@@ -131,14 +147,18 @@ TEST(ENFORCE_GT, FAIL) {
} }
TEST(ENFORCE_GE, OK) { TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE(2, 2); PADDLE_ENFORCE_GE(2, 2, paddle::platform::errors::Unavailable(
PADDLE_ENFORCE_GE(3, 2); "PADDLE_ENFORCE_GE tests failed."));
PADDLE_ENFORCE_GE(3.21, 2.0); PADDLE_ENFORCE_GE(3, 2, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GE tests failed."));
PADDLE_ENFORCE_GE(3.21, 2.0, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GE tests failed."));
} }
TEST(ENFORCE_GE, FAIL) { TEST(ENFORCE_GE, FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GE(1, 2); PADDLE_ENFORCE_GE(1, 2, paddle::platform::errors::InvalidArgument(
"Expected 1 >= 2, but received 1:1 < 2:2."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -149,16 +169,22 @@ TEST(ENFORCE_GE, FAIL) { ...@@ -149,16 +169,22 @@ TEST(ENFORCE_GE, FAIL) {
} }
TEST(ENFORCE_LE, OK) { TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE(1, 1); PADDLE_ENFORCE_LE(1, 1, paddle::platform::errors::Unavailable(
PADDLE_ENFORCE_LE(1UL, 1UL); "PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2, 3); PADDLE_ENFORCE_LE(1UL, 1UL, paddle::platform::errors::Unavailable(
PADDLE_ENFORCE_LE(2UL, 3UL); "PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2.0, 3.2); PADDLE_ENFORCE_LE(2, 3, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2UL, 3UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2.0, 3.2, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
} }
TEST(ENFORCE_LE, FAIL) { TEST(ENFORCE_LE, FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_GT(1, 2); PADDLE_ENFORCE_GT(1, 2, paddle::platform::errors::InvalidArgument(
"Expected 1 > 2, but received 1:1 <= 2:2."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -169,14 +195,20 @@ TEST(ENFORCE_LE, FAIL) { ...@@ -169,14 +195,20 @@ TEST(ENFORCE_LE, FAIL) {
} }
TEST(ENFORCE_LT, OK) { TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT(3, 10); PADDLE_ENFORCE_LT(3, 10, paddle::platform::errors::Unavailable(
PADDLE_ENFORCE_LT(2UL, 3UL); "PADDLE_ENFORCE_LT tests failed."));
PADDLE_ENFORCE_LT(2, 3); PADDLE_ENFORCE_LT(2UL, 3UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LT tests failed."));
PADDLE_ENFORCE_LT(2, 3, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LT tests failed."));
} }
TEST(ENFORCE_LT, FAIL) { TEST(ENFORCE_LT, FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_LT(1UL, 0.12); PADDLE_ENFORCE_LT(
1UL, 0.12,
paddle::platform::errors::InvalidArgument(
"Expected 1UL < 0.12, but received 1UL:1 >= 0.12:0.12."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
...@@ -189,18 +221,20 @@ TEST(ENFORCE_LT, FAIL) { ...@@ -189,18 +221,20 @@ TEST(ENFORCE_LT, FAIL) {
TEST(ENFORCE_NOT_NULL, OK) { TEST(ENFORCE_NOT_NULL, OK) {
int* a = new int; int* a = new int;
PADDLE_ENFORCE_NOT_NULL(a); PADDLE_ENFORCE_NOT_NULL(a, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_NOT_NULL tests failed."));
delete a; delete a;
} }
TEST(ENFORCE_NOT_NULL, FAIL) { TEST(ENFORCE_NOT_NULL, FAIL) {
bool caught_exception = false; bool caught_exception = false;
try { try {
int* a = nullptr; int* a = nullptr;
PADDLE_ENFORCE_NOT_NULL(a); PADDLE_ENFORCE_NOT_NULL(
a, paddle::platform::errors::Unavailable("The a should not be null."));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true; caught_exception = true;
std::string ex_msg = error.what(); std::string ex_msg = error.what();
EXPECT_TRUE(ex_msg.find("a should not be null") != std::string::npos); EXPECT_TRUE(ex_msg.find("The a should not be null.") != std::string::npos);
} }
EXPECT_TRUE(caught_exception); EXPECT_TRUE(caught_exception);
} }
...@@ -233,14 +267,16 @@ std::ostream& operator<<(std::ostream& os, const Dims& d) { ...@@ -233,14 +267,16 @@ std::ostream& operator<<(std::ostream& os, const Dims& d) {
TEST(ENFORCE_USER_DEFINED_CLASS, EQ) { TEST(ENFORCE_USER_DEFINED_CLASS, EQ) {
Dims a{{1, 2, 3, 4}}, b{{1, 2, 3, 4}}; Dims a{{1, 2, 3, 4}}, b{{1, 2, 3, 4}};
PADDLE_ENFORCE_EQ(a, b); PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_EQ tests failed."));
} }
TEST(ENFORCE_USER_DEFINED_CLASS, NE) { TEST(ENFORCE_USER_DEFINED_CLASS, NE) {
Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}}; Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}};
bool caught_exception = false; bool caught_exception = false;
try { try {
PADDLE_ENFORCE_EQ(a, b); PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_EQ tests failed."));
} catch (paddle::platform::EnforceNotMet&) { } catch (paddle::platform::EnforceNotMet&) {
caught_exception = true; caught_exception = true;
} }
...@@ -329,12 +365,15 @@ TEST(enforce, cannot_to_string_type) { ...@@ -329,12 +365,15 @@ TEST(enforce, cannot_to_string_type) {
"int can be converted to string"); "int can be converted to string");
CannotToStringType obj1(3), obj2(4), obj3(3); CannotToStringType obj1(3), obj2(4), obj3(3);
PADDLE_ENFORCE_NE(obj1, obj2, "Object 1 is not equal to Object 2"); PADDLE_ENFORCE_NE(obj1, obj2, paddle::platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(obj1, obj3, "Object 1 is equal to Object 3"); "Object 1 is not equal to Object 2"));
PADDLE_ENFORCE_EQ(obj1, obj3, paddle::platform::errors::InvalidArgument(
"Object 1 is equal to Object 3"));
std::string msg = "Compare obj1 with obj2"; std::string msg = "Compare obj1 with obj2";
try { try {
PADDLE_ENFORCE_EQ(obj1, obj2, msg); PADDLE_ENFORCE_EQ(obj1, obj2,
paddle::platform::errors::InvalidArgument(msg));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what(); std::string ex_msg = error.what();
LOG(INFO) << ex_msg; LOG(INFO) << ex_msg;
...@@ -347,7 +386,7 @@ TEST(enforce, cannot_to_string_type) { ...@@ -347,7 +386,7 @@ TEST(enforce, cannot_to_string_type) {
msg = "Compare x with y"; msg = "Compare x with y";
try { try {
int x = 3, y = 2; int x = 3, y = 2;
PADDLE_ENFORCE_EQ(x, y, msg); PADDLE_ENFORCE_EQ(x, y, paddle::platform::errors::InvalidArgument(msg));
} catch (paddle::platform::EnforceNotMet& error) { } catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what(); std::string ex_msg = error.what();
LOG(INFO) << ex_msg; LOG(INFO) << ex_msg;
...@@ -357,14 +396,22 @@ TEST(enforce, cannot_to_string_type) { ...@@ -357,14 +396,22 @@ TEST(enforce, cannot_to_string_type) {
} }
std::set<int> set; std::set<int> set;
PADDLE_ENFORCE_EQ(set.begin(), set.end()); PADDLE_ENFORCE_EQ(set.begin(), set.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of set is not equal."));
set.insert(3); set.insert(3);
PADDLE_ENFORCE_NE(set.begin(), set.end()); PADDLE_ENFORCE_NE(set.begin(), set.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of set is equal."));
std::list<float> list; std::list<float> list;
PADDLE_ENFORCE_EQ(list.begin(), list.end()); PADDLE_ENFORCE_EQ(list.begin(), list.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of list is not equal."));
list.push_back(4); list.push_back(4);
PADDLE_ENFORCE_NE(list.begin(), list.end()); PADDLE_ENFORCE_NE(list.begin(), list.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of list is equal."));
} }
TEST(GET_DATA_SAFELY_MACRO, SUCCESS) { TEST(GET_DATA_SAFELY_MACRO, SUCCESS) {
......
...@@ -145,7 +145,9 @@ TEST(float16, lod_tensor_cpu) { ...@@ -145,7 +145,9 @@ TEST(float16, lod_tensor_cpu) {
TEST(float16, floating) { TEST(float16, floating) {
// compile time assert. // compile time assert.
PADDLE_ENFORCE_EQ(std::is_floating_point<float16>::value, true); PADDLE_ENFORCE_EQ(
std::is_floating_point<float16>::value, true,
platform::errors::Unavailable("The float16 support in CPU failed."));
} }
TEST(float16, print) { TEST(float16, print) {
......
...@@ -261,8 +261,12 @@ TEST(float16, typeid) { ...@@ -261,8 +261,12 @@ TEST(float16, typeid) {
int b(0); int b(0);
// compile time assert // compile time assert
PADDLE_ENFORCE_EQ(functor(a), true); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(functor2(b), false); functor(a), true,
platform::errors::Unavailable("The float16 support in GPU failed."));
PADDLE_ENFORCE_EQ(
functor2(b), false,
platform::errors::Unavailable("The float16 support in GPU failed."));
} }
// GPU test // GPU test
......
...@@ -243,7 +243,9 @@ size_t GpuMaxAllocSize() { ...@@ -243,7 +243,9 @@ size_t GpuMaxAllocSize() {
static size_t GpuAllocSize(bool realloc) { static size_t GpuAllocSize(bool realloc) {
size_t available_to_alloc = GpuAvailableMemToAlloc(); size_t available_to_alloc = GpuAvailableMemToAlloc();
PADDLE_ENFORCE_GT(available_to_alloc, 0, "No enough available GPU memory"); PADDLE_ENFORCE_GT(
available_to_alloc, 0,
platform::errors::ResourceExhausted("Not enough available GPU memory."));
// If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
// allocated by fraction // allocated by fraction
size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
...@@ -251,8 +253,9 @@ static size_t GpuAllocSize(bool realloc) { ...@@ -251,8 +253,9 @@ static size_t GpuAllocSize(bool realloc) {
size_t alloc_bytes = size_t alloc_bytes =
(flag_mb > 0ul ? flag_mb << 20 : available_to_alloc * (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
FLAGS_fraction_of_gpu_memory_to_use); FLAGS_fraction_of_gpu_memory_to_use);
PADDLE_ENFORCE_GE(available_to_alloc, alloc_bytes, PADDLE_ENFORCE_GE(
"No enough available GPU memory"); available_to_alloc, alloc_bytes,
platform::errors::ResourceExhausted("Not enough available GPU memory."));
VLOG(10) << "Alloc size is " << (alloc_bytes >> 20) VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
<< " MiB, is it Re-alloc: " << realloc; << " MiB, is it Re-alloc: " << realloc;
return alloc_bytes; return alloc_bytes;
......
...@@ -98,9 +98,8 @@ void InitP2P(std::vector<int> devices) { ...@@ -98,9 +98,8 @@ void InitP2P(std::vector<int> devices) {
for (int j = 0; j < count; ++j) { for (int j = 0; j < count; ++j) {
if (devices[i] == devices[j]) continue; if (devices[i] == devices[j]) continue;
int can_acess = -1; int can_acess = -1;
PADDLE_ENFORCE( PADDLE_ENFORCE_CUDA_SUCCESS(
cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]), cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]));
"Failed to test P2P access.");
if (can_acess != 1) { if (can_acess != 1) {
LOG(WARNING) << "Cannot enable P2P access from " << devices[i] LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
<< " to " << devices[j]; << " to " << devices[j];
......
...@@ -172,7 +172,9 @@ class MKLDNNHandlerT { ...@@ -172,7 +172,9 @@ class MKLDNNHandlerT {
const std::string key_fwd_pd = key_common_ + "@forward_pd"; const std::string key_fwd_pd = key_common_ + "@forward_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>( fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd)); dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL(fwd_pd_); PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
const std::string key_pd = key_ + "@backward_pd"; const std::string key_pd = key_ + "@backward_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>( bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd)); dev_ctx_.GetBlob(key_pd));
...@@ -1450,8 +1452,10 @@ static void SetDstMemoryQuantized( ...@@ -1450,8 +1452,10 @@ static void SetDstMemoryQuantized(
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size(); const size_t dst_dims = dst_tz.size();
MKLDNNMemoryFormat dst_fmt; MKLDNNMemoryFormat dst_fmt;
PADDLE_ENFORCE_LE(dst_dims, 5, PADDLE_ENFORCE_LE(dst_dims, 5, platform::errors::InvalidArgument(
"Dst memory for quantization can not have dims > 5"); "Dst memory for quantization can not have "
"dims > 5. But received dst_dims is %d.",
dst_dims));
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, output_format); dst_fmt = platform::MKLDNNFormatForSize(dst_dims, output_format);
auto tmp_dst_md = platform::MKLDNNMemDesc( auto tmp_dst_md = platform::MKLDNNMemDesc(
......
...@@ -46,7 +46,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) { ...@@ -46,7 +46,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
} else if (type == framework::proto::VarType::FP16) { } else if (type == framework::proto::VarType::FP16) {
return ncclFloat16; return ncclFloat16;
} else { } else {
PADDLE_THROW("Not supported"); PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
} }
} }
...@@ -95,7 +96,8 @@ struct NCCLContextMap { ...@@ -95,7 +96,8 @@ struct NCCLContextMap {
explicit NCCLContextMap(const std::vector<platform::Place> &places, explicit NCCLContextMap(const std::vector<platform::Place> &places,
ncclUniqueId *nccl_id = nullptr, ncclUniqueId *nccl_id = nullptr,
size_t num_trainers = 1, size_t trainer_id = 0) { size_t num_trainers = 1, size_t trainer_id = 0) {
PADDLE_ENFORCE_EQ(!places.empty(), true); PADDLE_ENFORCE_EQ(!places.empty(), true, platform::errors::InvalidArgument(
"The NCCL place is empty."));
order_.reserve(places.size()); order_.reserve(places.size());
for (auto &p : places) { for (auto &p : places) {
int dev_id = BOOST_GET_CONST(CUDAPlace, p).device; int dev_id = BOOST_GET_CONST(CUDAPlace, p).device;
...@@ -104,7 +106,8 @@ struct NCCLContextMap { ...@@ -104,7 +106,8 @@ struct NCCLContextMap {
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
order_.size(), contexts_.size(), order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device"); platform::errors::Unavailable("NCCL Context Map does not support "
"contain two or more same device."));
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]); std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// if num_trainers == 1, should create a new nccl id for local comms. // if num_trainers == 1, should create a new nccl id for local comms.
...@@ -113,7 +116,8 @@ struct NCCLContextMap { ...@@ -113,7 +116,8 @@ struct NCCLContextMap {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
comms.get(), static_cast<int>(order_.size()), order_.data())); comms.get(), static_cast<int>(order_.size()), order_.data()));
} else { } else {
PADDLE_ENFORCE_NOT_NULL(nccl_id); PADDLE_ENFORCE_NOT_NULL(nccl_id, platform::errors::InvalidArgument(
"The NCCL id should not be null."));
{ {
int nranks = num_trainers * order_.size(); int nranks = num_trainers * order_.size();
NCCLGroupGuard gurad; NCCLGroupGuard gurad;
...@@ -263,13 +267,17 @@ class NCCLCommunicator { ...@@ -263,13 +267,17 @@ class NCCLCommunicator {
size_t trainers_num, size_t trainer_id, size_t trainers_num, size_t trainer_id,
size_t inter_trainers_num, size_t inter_trainers_num,
size_t exter_trainers_num) { size_t exter_trainers_num) {
PADDLE_ENFORCE_EQ(trainers_num, inter_trainers_num * exter_trainers_num, PADDLE_ENFORCE_EQ(
trainers_num, inter_trainers_num * exter_trainers_num,
platform::errors::InvalidArgument(
"trainers_num:%llu != inter_trainers_num:%llu * " "trainers_num:%llu != inter_trainers_num:%llu * "
"exter_trainers_num:%llu", "exter_trainers_num:%llu",
trainers_num, inter_trainers_num, exter_trainers_num); trainers_num, inter_trainers_num, exter_trainers_num));
PADDLE_ENFORCE_GT(inter_trainers_num, 1, "inter_trainers_num:%llu must > 1", PADDLE_ENFORCE_GT(
inter_trainers_num); inter_trainers_num, 1,
platform::errors::InvalidArgument("inter_trainers_num:%llu must > 1",
inter_trainers_num));
int inter_trainer_id = trainer_id % inter_trainers_num; int inter_trainer_id = trainer_id % inter_trainers_num;
for (size_t i = 0; i < inter_nccl_ids.size(); i++) { for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
...@@ -300,14 +308,16 @@ class NCCLCommunicator { ...@@ -300,14 +308,16 @@ class NCCLCommunicator {
bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; } bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const { NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
PADDLE_ENFORCE(h_inter_ctxs_.size() > 0, PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0,
"must init hierarchical ctxs first!"); platform::errors::InvalidArgument(
"Hierarchical ctxs should be initialized firstly!"));
return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get(); return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
} }
NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const { NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
PADDLE_ENFORCE(h_exter_ctxs_.size() > 0, PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0,
"must init hierarchical ctxs first!"); platform::errors::InvalidArgument(
"Hierarchical ctxs should be initialized firstly!"));
return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get(); return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册