未验证 提交 ea7e5325 编写于 作者: G GaoWei8 提交者: GitHub

Refine PADDLE_ENFORCE (#25369)

* refine PADDLE_ENFORCE
test=develop
上级 ff7af219
......@@ -43,7 +43,9 @@ void SetNumThreads(int num_threads) {
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(real_num_threads);
#else
PADDLE_ENFORCE(false, "To be implemented.");
PADDLE_THROW(platform::errors::Unimplemented(
"The library (except OPENBLAS, MKLML) is to be implemented, thus "
"number of threads can not be set."));
#endif
}
......
......@@ -26,13 +26,13 @@ void CudaProfilerInit(std::string output_file, std::string output_mode,
std::string config_file) {
PADDLE_ENFORCE(output_mode == "kvp" || output_mode == "csv");
cudaOutputMode_t mode = output_mode == "csv" ? cudaCSV : cudaKeyValuePair;
PADDLE_ENFORCE(
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaProfilerInitialize(config_file.c_str(), output_file.c_str(), mode));
}
void CudaProfilerStart() { PADDLE_ENFORCE(cudaProfilerStart()); }
void CudaProfilerStart() { PADDLE_ENFORCE_CUDA_SUCCESS(cudaProfilerStart()); }
void CudaProfilerStop() { PADDLE_ENFORCE(cudaProfilerStop()); }
void CudaProfilerStop() { PADDLE_ENFORCE_CUDA_SUCCESS(cudaProfilerStop()); }
} // namespace platform
} // namespace paddle
......@@ -103,7 +103,8 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX;
default:
PADDLE_THROW("Unexpected pooling mode.");
PADDLE_THROW(
platform::errors::Unimplemented("Unexpected CUDNN pooling mode."));
}
}
#else
......@@ -119,7 +120,8 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX;
default:
PADDLE_THROW("Unexpected pooling mode.");
PADDLE_THROW(
platform::errors::Unimplemented("Unexpected CUDNN pooling mode."));
}
}
#endif // CUDNN_VERSION < 6000
......@@ -140,7 +142,8 @@ inline ActivationMode StringToActivationMode(const std::string& str) {
} else if (str == "bandpass") {
return ActivationMode::kBandPass;
} else {
PADDLE_THROW("Unknown activation string: %s", str);
PADDLE_THROW(
platform::errors::Unimplemented("Unknown activation string: %s.", str));
}
}
......@@ -208,7 +211,8 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat(
case DataLayout::kNDHWC:
return CUDNN_TENSOR_NHWC; // add, liyamei
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
PADDLE_THROW(platform::errors::Unimplemented(
"CUDNN has no equivalent dataLayout for input order."));
}
return CUDNN_TENSOR_NCHW;
}
......@@ -329,18 +333,28 @@ class ScopedConvolutionDescriptor {
inline cudnnConvolutionDescriptor_t descriptor(
cudnnDataType_t type, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations) {
PADDLE_ENFORCE_EQ(pads.size(), strides.size());
PADDLE_ENFORCE_EQ(pads.size(), dilations.size());
PADDLE_ENFORCE_EQ(pads.size(), strides.size(),
platform::errors::InvalidArgument(
"The size of pads and strides should be equal. But "
"received size of pads is %d, size of strides is %d.",
pads.size(), strides.size()));
PADDLE_ENFORCE_EQ(
pads.size(), dilations.size(),
platform::errors::InvalidArgument(
"The size of pads and dilations should be equal. But received size "
"of pads is %d, size of dilations is %d.",
pads.size(), dilations.size()));
#if !CUDNN_VERSION_MIN(6, 0, 0)
// cudnn v5 does not support dilation conv, the argument is called upscale
// instead of dilations and it is must be one.
for (size_t i = 0; i < dilations.size(); ++i) {
PADDLE_ENFORCE_EQ(
dilations[i], 1,
"Dilations conv is not supported in this cuDNN version(%d.%d.%d).",
CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100,
CUDNN_VERSION % 100);
PADDLE_ENFORCE_EQ(dilations[i], 1,
platform::errors::InvalidArgument(
"Dilations conv is not supported in this cuDNN "
"version(%d.%d.%d).",
CUDNN_VERSION / 1000, CUDNN_VERSION % 1000 / 100,
CUDNN_VERSION % 100));
}
#endif
......@@ -377,8 +391,17 @@ class ScopedPoolingDescriptor {
const std::vector<int>& kernel,
const std::vector<int>& pads,
const std::vector<int>& strides) {
PADDLE_ENFORCE_EQ(kernel.size(), pads.size());
PADDLE_ENFORCE_EQ(kernel.size(), strides.size());
PADDLE_ENFORCE_EQ(kernel.size(), pads.size(),
platform::errors::InvalidArgument(
"The size of kernel and pads should be equal. But "
"received size of kernel is %d, size of pads is %d.",
kernel.size(), pads.size()));
PADDLE_ENFORCE_EQ(
kernel.size(), strides.size(),
platform::errors::InvalidArgument(
"The size of kernel and strides should be equal. But "
"received size of kernel is %d, size of strides is %d.",
kernel.size(), strides.size()));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetPoolingNdDescriptor(
desc_, (GetPoolingMode(mode)),
CUDNN_PROPAGATE_NAN, // Always propagate nans.
......@@ -456,8 +479,9 @@ class ScopedActivationDescriptor {
mode = CUDNN_ACTIVATION_TANH;
break;
default:
PADDLE_THROW("unrecognized activation mode: %d .",
static_cast<int>(activation_mode));
PADDLE_THROW(platform::errors::Unimplemented(
"Unrecognized CUDNN activation mode: %d.",
static_cast<int>(activation_mode)));
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetActivationDescriptor(
desc_, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling));
......
......@@ -59,12 +59,11 @@ DeviceContextPool* DeviceContextPool::pool = nullptr;
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) {
PADDLE_THROW(
"Place %s is not supported, Please check that your paddle compiles "
"with WITH_GPU "
"option or check that your train process hold the correct gpu_id if "
"you use Executor",
place);
PADDLE_THROW(platform::errors::Unimplemented(
"Place %s is not supported. Please check that your paddle compiles "
"with WITH_GPU option or check that your train process hold the "
"correct gpu_id if you use Executor.",
place));
}
return it->second.get().get();
}
......@@ -84,7 +83,11 @@ inline void EmplaceDeviceContext(
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_GT(
places.size(), 0,
platform::errors::InvalidArgument("The number of platform places should "
"be larger than 0. But received %d.",
places.size()));
std::set<Place> set;
for (auto& p : places) {
set.insert(p);
......@@ -100,18 +103,18 @@ DeviceContextPool::DeviceContextPool(
#ifdef PADDLE_WITH_CUDA
EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option");
PADDLE_THROW(platform::errors::Unimplemented(
"'CUDAPlace is not supported. Please re-compile with WITH_GPU."
"option"));
#endif
} else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
&device_contexts_, p);
#else
PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option");
PADDLE_THROW(platform::errors::Unimplemented(
"'CUDAPlace' is not supported. Please re-compile with WITH_GPU."
"option"));
#endif
}
}
......
......@@ -575,7 +575,8 @@ class DeviceTracerImpl : public DeviceTracer {
} else if (platform::is_cuda_pinned_place(r.place)) {
event->set_place(proto::MemEvent::CUDAPinnedPlace);
} else {
PADDLE_THROW("The current place is not supported.");
PADDLE_THROW(platform::errors::Unimplemented(
"The current place is not supported."));
}
event->set_alloc_in(r.alloc_in);
event->set_free_in(r.free_in);
......
......@@ -319,9 +319,11 @@ void* GetMKLMLDsoHandle() {
void* GetOpDsoHandle(const std::string& dso_name) {
#if defined(__APPLE__) || defined(__OSX__)
PADDLE_THROW("Do not support Apple.");
PADDLE_THROW(platform::errors::Unimplemented(
"Create custom cpp op outside framework do not support Apple."));
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
PADDLE_THROW("Do not support Windows.");
PADDLE_THROW(platform::errors::Unimplemented(
"Create custom cpp op outside framework do not support Windows."));
#else
return GetDsoHandleFromSearchPath(FLAGS_op_dir, dso_name);
#endif
......
......@@ -20,40 +20,46 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
PADDLE_ENFORCE(true, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE is ok %d now %f.", 123, 0.345));
size_t val = 1;
const size_t limit = 10;
PADDLE_ENFORCE(val < limit, "Enforce is OK too");
PADDLE_ENFORCE(val < limit, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE tests failed."));
}
TEST(ENFORCE, FAILED) {
bool caught_exception = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
PADDLE_ENFORCE(false, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE won't work %d at all.", 123));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
EXPECT_TRUE(ex_msg.find("Enforce is not ok 123 at all") !=
EXPECT_TRUE(ex_msg.find("PADDLE_ENFORCE won't work 123 at all.") !=
std::string::npos);
}
EXPECT_TRUE(caught_exception);
caught_exception = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok at all");
PADDLE_ENFORCE(false, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE won't work at all."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
EXPECT_TRUE(ex_msg.find("Enforce is not ok at all") != std::string::npos);
EXPECT_TRUE(ex_msg.find("PADDLE_ENFORCE won't work at all.") !=
std::string::npos);
}
EXPECT_TRUE(caught_exception);
caught_exception = false;
try {
PADDLE_ENFORCE(false);
PADDLE_ENFORCE(false, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE won't work at all."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
EXPECT_NE(std::string(error.what()).find(" at "), 0UL);
EXPECT_NE(std::string(error.what()).find(" at "), 0UL);
}
EXPECT_TRUE(caught_exception);
}
......@@ -61,9 +67,11 @@ TEST(ENFORCE, FAILED) {
TEST(ENFORCE, NO_ARG_OK) {
int a = 2;
int b = 2;
PADDLE_ENFORCE_EQ(a, b);
PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_EQ tests failed."));
// test enforce with extra message.
PADDLE_ENFORCE_EQ(a, b, "some thing wrong %s", "info");
PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"Some %s wrong in PADDLE_ENFORCE_EQ.", "info"));
}
TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
......@@ -71,7 +79,7 @@ TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) {
bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3, paddle::platform::errors::InvalidArgument(
"the result is not equal correct result."));
"The result is not equal correct result."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -86,7 +94,7 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, 1 + 3, paddle::platform::errors::InvalidArgument(
"the result is not equal correct result."));
"The result is not equal correct result."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -98,15 +106,19 @@ TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) {
}
TEST(ENFORCE_NE, OK) {
PADDLE_ENFORCE_NE(1, 2);
PADDLE_ENFORCE_NE(1.0, 2UL);
PADDLE_ENFORCE_NE(1, 2, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_NE tests failed."));
PADDLE_ENFORCE_NE(1.0, 2UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_NE tests failed."));
}
TEST(ENFORCE_NE, FAIL) {
bool caught_exception = false;
try {
// 2UL here to check data type compatible
PADDLE_ENFORCE_NE(1.0, 1UL);
PADDLE_ENFORCE_NE(1.0, 1UL,
paddle::platform::errors::Unavailable(
"Expected 1.0 != 1UL, but received 1.0:1 == 1UL:1."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -116,11 +128,15 @@ TEST(ENFORCE_NE, FAIL) {
EXPECT_TRUE(caught_exception);
}
TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); }
TEST(ENFORCE_GT, OK) {
PADDLE_ENFORCE_GT(2, 1, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GT tests failed."));
}
TEST(ENFORCE_GT, FAIL) {
bool caught_exception = false;
try {
PADDLE_ENFORCE_GT(1, 2);
PADDLE_ENFORCE_GT(1, 2, paddle::platform::errors::InvalidArgument(
"Expected 1 > 2, but received 1:1 <= 2:2."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -131,14 +147,18 @@ TEST(ENFORCE_GT, FAIL) {
}
TEST(ENFORCE_GE, OK) {
PADDLE_ENFORCE_GE(2, 2);
PADDLE_ENFORCE_GE(3, 2);
PADDLE_ENFORCE_GE(3.21, 2.0);
PADDLE_ENFORCE_GE(2, 2, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GE tests failed."));
PADDLE_ENFORCE_GE(3, 2, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GE tests failed."));
PADDLE_ENFORCE_GE(3.21, 2.0, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_GE tests failed."));
}
TEST(ENFORCE_GE, FAIL) {
bool caught_exception = false;
try {
PADDLE_ENFORCE_GE(1, 2);
PADDLE_ENFORCE_GE(1, 2, paddle::platform::errors::InvalidArgument(
"Expected 1 >= 2, but received 1:1 < 2:2."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -149,16 +169,22 @@ TEST(ENFORCE_GE, FAIL) {
}
TEST(ENFORCE_LE, OK) {
PADDLE_ENFORCE_LE(1, 1);
PADDLE_ENFORCE_LE(1UL, 1UL);
PADDLE_ENFORCE_LE(2, 3);
PADDLE_ENFORCE_LE(2UL, 3UL);
PADDLE_ENFORCE_LE(2.0, 3.2);
PADDLE_ENFORCE_LE(1, 1, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(1UL, 1UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2, 3, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2UL, 3UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
PADDLE_ENFORCE_LE(2.0, 3.2, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LE tests failed."));
}
TEST(ENFORCE_LE, FAIL) {
bool caught_exception = false;
try {
PADDLE_ENFORCE_GT(1, 2);
PADDLE_ENFORCE_GT(1, 2, paddle::platform::errors::InvalidArgument(
"Expected 1 > 2, but received 1:1 <= 2:2."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -169,14 +195,20 @@ TEST(ENFORCE_LE, FAIL) {
}
TEST(ENFORCE_LT, OK) {
PADDLE_ENFORCE_LT(3, 10);
PADDLE_ENFORCE_LT(2UL, 3UL);
PADDLE_ENFORCE_LT(2, 3);
PADDLE_ENFORCE_LT(3, 10, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LT tests failed."));
PADDLE_ENFORCE_LT(2UL, 3UL, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LT tests failed."));
PADDLE_ENFORCE_LT(2, 3, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_LT tests failed."));
}
TEST(ENFORCE_LT, FAIL) {
bool caught_exception = false;
try {
PADDLE_ENFORCE_LT(1UL, 0.12);
PADDLE_ENFORCE_LT(
1UL, 0.12,
paddle::platform::errors::InvalidArgument(
"Expected 1UL < 0.12, but received 1UL:1 >= 0.12:0.12."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
......@@ -189,18 +221,20 @@ TEST(ENFORCE_LT, FAIL) {
TEST(ENFORCE_NOT_NULL, OK) {
int* a = new int;
PADDLE_ENFORCE_NOT_NULL(a);
PADDLE_ENFORCE_NOT_NULL(a, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_NOT_NULL tests failed."));
delete a;
}
TEST(ENFORCE_NOT_NULL, FAIL) {
bool caught_exception = false;
try {
int* a = nullptr;
PADDLE_ENFORCE_NOT_NULL(a);
PADDLE_ENFORCE_NOT_NULL(
a, paddle::platform::errors::Unavailable("The a should not be null."));
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
std::string ex_msg = error.what();
EXPECT_TRUE(ex_msg.find("a should not be null") != std::string::npos);
EXPECT_TRUE(ex_msg.find("The a should not be null.") != std::string::npos);
}
EXPECT_TRUE(caught_exception);
}
......@@ -233,14 +267,16 @@ std::ostream& operator<<(std::ostream& os, const Dims& d) {
TEST(ENFORCE_USER_DEFINED_CLASS, EQ) {
Dims a{{1, 2, 3, 4}}, b{{1, 2, 3, 4}};
PADDLE_ENFORCE_EQ(a, b);
PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_EQ tests failed."));
}
TEST(ENFORCE_USER_DEFINED_CLASS, NE) {
Dims a{{1, 2, 3, 4}}, b{{5, 6, 7, 8}};
bool caught_exception = false;
try {
PADDLE_ENFORCE_EQ(a, b);
PADDLE_ENFORCE_EQ(a, b, paddle::platform::errors::Unavailable(
"PADDLE_ENFORCE_EQ tests failed."));
} catch (paddle::platform::EnforceNotMet&) {
caught_exception = true;
}
......@@ -329,12 +365,15 @@ TEST(enforce, cannot_to_string_type) {
"int can be converted to string");
CannotToStringType obj1(3), obj2(4), obj3(3);
PADDLE_ENFORCE_NE(obj1, obj2, "Object 1 is not equal to Object 2");
PADDLE_ENFORCE_EQ(obj1, obj3, "Object 1 is equal to Object 3");
PADDLE_ENFORCE_NE(obj1, obj2, paddle::platform::errors::InvalidArgument(
"Object 1 is not equal to Object 2"));
PADDLE_ENFORCE_EQ(obj1, obj3, paddle::platform::errors::InvalidArgument(
"Object 1 is equal to Object 3"));
std::string msg = "Compare obj1 with obj2";
try {
PADDLE_ENFORCE_EQ(obj1, obj2, msg);
PADDLE_ENFORCE_EQ(obj1, obj2,
paddle::platform::errors::InvalidArgument(msg));
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
LOG(INFO) << ex_msg;
......@@ -347,7 +386,7 @@ TEST(enforce, cannot_to_string_type) {
msg = "Compare x with y";
try {
int x = 3, y = 2;
PADDLE_ENFORCE_EQ(x, y, msg);
PADDLE_ENFORCE_EQ(x, y, paddle::platform::errors::InvalidArgument(msg));
} catch (paddle::platform::EnforceNotMet& error) {
std::string ex_msg = error.what();
LOG(INFO) << ex_msg;
......@@ -357,14 +396,22 @@ TEST(enforce, cannot_to_string_type) {
}
std::set<int> set;
PADDLE_ENFORCE_EQ(set.begin(), set.end());
PADDLE_ENFORCE_EQ(set.begin(), set.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of set is not equal."));
set.insert(3);
PADDLE_ENFORCE_NE(set.begin(), set.end());
PADDLE_ENFORCE_NE(set.begin(), set.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of set is equal."));
std::list<float> list;
PADDLE_ENFORCE_EQ(list.begin(), list.end());
PADDLE_ENFORCE_EQ(list.begin(), list.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of list is not equal."));
list.push_back(4);
PADDLE_ENFORCE_NE(list.begin(), list.end());
PADDLE_ENFORCE_NE(list.begin(), list.end(),
paddle::platform::errors::InvalidArgument(
"The begin and end of list is equal."));
}
TEST(GET_DATA_SAFELY_MACRO, SUCCESS) {
......
......@@ -145,7 +145,9 @@ TEST(float16, lod_tensor_cpu) {
TEST(float16, floating) {
// compile time assert.
PADDLE_ENFORCE_EQ(std::is_floating_point<float16>::value, true);
PADDLE_ENFORCE_EQ(
std::is_floating_point<float16>::value, true,
platform::errors::Unavailable("The float16 support in CPU failed."));
}
TEST(float16, print) {
......
......@@ -261,8 +261,12 @@ TEST(float16, typeid) {
int b(0);
// compile time assert
PADDLE_ENFORCE_EQ(functor(a), true);
PADDLE_ENFORCE_EQ(functor2(b), false);
PADDLE_ENFORCE_EQ(
functor(a), true,
platform::errors::Unavailable("The float16 support in GPU failed."));
PADDLE_ENFORCE_EQ(
functor2(b), false,
platform::errors::Unavailable("The float16 support in GPU failed."));
}
// GPU test
......
......@@ -243,7 +243,9 @@ size_t GpuMaxAllocSize() {
static size_t GpuAllocSize(bool realloc) {
size_t available_to_alloc = GpuAvailableMemToAlloc();
PADDLE_ENFORCE_GT(available_to_alloc, 0, "No enough available GPU memory");
PADDLE_ENFORCE_GT(
available_to_alloc, 0,
platform::errors::ResourceExhausted("Not enough available GPU memory."));
// If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
// allocated by fraction
size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
......@@ -251,8 +253,9 @@ static size_t GpuAllocSize(bool realloc) {
size_t alloc_bytes =
(flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
FLAGS_fraction_of_gpu_memory_to_use);
PADDLE_ENFORCE_GE(available_to_alloc, alloc_bytes,
"No enough available GPU memory");
PADDLE_ENFORCE_GE(
available_to_alloc, alloc_bytes,
platform::errors::ResourceExhausted("Not enough available GPU memory."));
VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
<< " MiB, is it Re-alloc: " << realloc;
return alloc_bytes;
......
......@@ -98,9 +98,8 @@ void InitP2P(std::vector<int> devices) {
for (int j = 0; j < count; ++j) {
if (devices[i] == devices[j]) continue;
int can_acess = -1;
PADDLE_ENFORCE(
cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]),
"Failed to test P2P access.");
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]));
if (can_acess != 1) {
LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
<< " to " << devices[j];
......
......@@ -172,7 +172,9 @@ class MKLDNNHandlerT {
const std::string key_fwd_pd = key_common_ + "@forward_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
const std::string key_pd = key_ + "@backward_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
......@@ -1450,8 +1452,10 @@ static void SetDstMemoryQuantized(
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size();
MKLDNNMemoryFormat dst_fmt;
PADDLE_ENFORCE_LE(dst_dims, 5,
"Dst memory for quantization can not have dims > 5");
PADDLE_ENFORCE_LE(dst_dims, 5, platform::errors::InvalidArgument(
"Dst memory for quantization can not have "
"dims > 5. But received dst_dims is %d.",
dst_dims));
dst_fmt = platform::MKLDNNFormatForSize(dst_dims, output_format);
auto tmp_dst_md = platform::MKLDNNMemDesc(
......
......@@ -46,7 +46,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
} else if (type == framework::proto::VarType::FP16) {
return ncclFloat16;
} else {
PADDLE_THROW("Not supported");
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
}
}
......@@ -95,7 +96,8 @@ struct NCCLContextMap {
explicit NCCLContextMap(const std::vector<platform::Place> &places,
ncclUniqueId *nccl_id = nullptr,
size_t num_trainers = 1, size_t trainer_id = 0) {
PADDLE_ENFORCE_EQ(!places.empty(), true);
PADDLE_ENFORCE_EQ(!places.empty(), true, platform::errors::InvalidArgument(
"The NCCL place is empty."));
order_.reserve(places.size());
for (auto &p : places) {
int dev_id = BOOST_GET_CONST(CUDAPlace, p).device;
......@@ -104,7 +106,8 @@ struct NCCLContextMap {
}
PADDLE_ENFORCE_EQ(
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
platform::errors::Unavailable("NCCL Context Map does not support "
"contain two or more same device."));
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// if num_trainers == 1, should create a new nccl id for local comms.
......@@ -113,7 +116,8 @@ struct NCCLContextMap {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
comms.get(), static_cast<int>(order_.size()), order_.data()));
} else {
PADDLE_ENFORCE_NOT_NULL(nccl_id);
PADDLE_ENFORCE_NOT_NULL(nccl_id, platform::errors::InvalidArgument(
"The NCCL id should not be null."));
{
int nranks = num_trainers * order_.size();
NCCLGroupGuard gurad;
......@@ -263,13 +267,17 @@ class NCCLCommunicator {
size_t trainers_num, size_t trainer_id,
size_t inter_trainers_num,
size_t exter_trainers_num) {
PADDLE_ENFORCE_EQ(trainers_num, inter_trainers_num * exter_trainers_num,
"trainers_num:%llu != inter_trainers_num:%llu * "
"exter_trainers_num:%llu",
trainers_num, inter_trainers_num, exter_trainers_num);
PADDLE_ENFORCE_EQ(
trainers_num, inter_trainers_num * exter_trainers_num,
platform::errors::InvalidArgument(
"trainers_num:%llu != inter_trainers_num:%llu * "
"exter_trainers_num:%llu",
trainers_num, inter_trainers_num, exter_trainers_num));
PADDLE_ENFORCE_GT(inter_trainers_num, 1, "inter_trainers_num:%llu must > 1",
inter_trainers_num);
PADDLE_ENFORCE_GT(
inter_trainers_num, 1,
platform::errors::InvalidArgument("inter_trainers_num:%llu must > 1",
inter_trainers_num));
int inter_trainer_id = trainer_id % inter_trainers_num;
for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
......@@ -300,14 +308,16 @@ class NCCLCommunicator {
bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
PADDLE_ENFORCE(h_inter_ctxs_.size() > 0,
"must init hierarchical ctxs first!");
PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0,
platform::errors::InvalidArgument(
"Hierarchical ctxs should be initialized firstly!"));
return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
}
NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
PADDLE_ENFORCE(h_exter_ctxs_.size() > 0,
"must init hierarchical ctxs first!");
PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0,
platform::errors::InvalidArgument(
"Hierarchical ctxs should be initialized firstly!"));
return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册