提交 b3a3e6f6 编写于 作者: C Chen Weihang 提交者: Tao Luo

change cuda enforce & add example (#21142)

上级 37e0e7a9
......@@ -81,7 +81,8 @@ class CUDADeviceContextAllocator : public Allocator {
platform::CUDADeviceGuard guard(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventCreate(&event_, cudaEventDisableTiming),
"Create event failed in CUDADeviceContextAllocator");
platform::errors::External(
"Create event failed in CUDADeviceContextAllocator"));
}
~CUDADeviceContextAllocator() {
......
......@@ -373,6 +373,10 @@ struct EOFException : public std::exception {
// Reports whether a CUDA runtime status denotes failure, i.e. is anything
// other than cudaSuccess.
inline bool is_error(cudaError_t e) {
  const bool succeeded = (e == cudaSuccess);
  return !succeeded;
}
// Builds the exception text for a CUDA runtime error.
//
// The caller-supplied message is returned unchanged; the status code `e` is
// not consulted here.
// NOTE(review): presumably the textual form of `e` is contributed by
// thrust::system_error in the matching throw_on_error overload - confirm.
inline std::string build_ex_string(cudaError_t e, const std::string& msg) {
  std::string text(msg);
  return text;
}
inline void throw_on_error(cudaError_t e, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(e, thrust::cuda_category(), msg);
......@@ -385,6 +389,11 @@ inline bool is_error(curandStatus_t stat) {
return stat != CURAND_STATUS_SUCCESS;
}
// Builds the exception text for a cuRAND failure.
//
// The caller-supplied message is returned unchanged; `stat` is not used to
// decorate it.
inline std::string build_ex_string(curandStatus_t stat,
                                   const std::string& msg) {
  std::string text(msg);
  return text;
}
inline void throw_on_error(curandStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
......@@ -398,11 +407,15 @@ inline bool is_error(cudnnStatus_t stat) {
return stat != CUDNN_STATUS_SUCCESS;
}
// Builds the exception text for a cuDNN failure.
//
// Returns `msg` followed by a new line holding the string that
// cudnnGetErrorString reports for `stat`, wrapped in square brackets.
inline std::string build_ex_string(cudnnStatus_t stat, const std::string& msg) {
  std::string text(msg);
  text += "\n [";
  text += platform::dynload::cudnnGetErrorString(stat);
  text += "]";
  return text;
}
// Reports a cuDNN failure.
//
// `msg` is emitted verbatim. PADDLE_ENFORCE_CUDA_SUCCESS passes the result of
// build_ex_string(stat, ...) as `msg`, and that string already ends with the
// cudnnGetErrorString text, so the status text must not be prepended again
// here.
//
// stat: the failing status; only selects this overload, it is not read.
// msg:  fully formatted error message.
//
// Throws std::runtime_error(msg), or logs at FATAL severity instead when
// REPLACE_ENFORCE_GLOG is defined.
inline void throw_on_error(cudnnStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}
......@@ -410,31 +423,36 @@ inline bool is_error(cublasStatus_t stat) {
return stat != CUBLAS_STATUS_SUCCESS;
}
inline void throw_on_error(cublasStatus_t stat, const std::string& msg) {
inline std::string build_ex_string(cublasStatus_t stat,
const std::string& msg) {
std::string err;
if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
err = "CUBLAS: not initialized, ";
err = "CUBLAS: not initialized.";
} else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
err = "CUBLAS: alloc failed, ";
err = "CUBLAS: alloc failed.";
} else if (stat == CUBLAS_STATUS_INVALID_VALUE) {
err = "CUBLAS: invalid value, ";
err = "CUBLAS: invalid value.";
} else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) {
err = "CUBLAS: arch mismatch, ";
err = "CUBLAS: arch mismatch.";
} else if (stat == CUBLAS_STATUS_MAPPING_ERROR) {
err = "CUBLAS: mapping error, ";
err = "CUBLAS: mapping error.";
} else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) {
err = "CUBLAS: execution failed, ";
err = "CUBLAS: execution failed.";
} else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) {
err = "CUBLAS: internal error, ";
err = "CUBLAS: internal error.";
} else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) {
err = "CUBLAS: not supported, ";
} else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
err = "CUBLAS: license error, ";
err = "CUBLAS: license error.";
}
return msg + "\n [" + err + "]";
}
// Reports a cuBLAS failure.
//
// `msg` is emitted verbatim. PADDLE_ENFORCE_CUDA_SUCCESS passes the result of
// build_ex_string(stat, ...) as `msg`, which already carries the cuBLAS
// status description, so nothing further is attached here.
//
// stat: the failing status; only selects this overload, it is not read.
// msg:  fully formatted error message.
//
// Throws std::runtime_error(msg), or logs at FATAL severity instead when
// REPLACE_ENFORCE_GLOG is defined.
inline void throw_on_error(cublasStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}
......@@ -443,11 +461,17 @@ inline bool is_error(ncclResult_t nccl_result) {
return nccl_result != ncclSuccess;
}
inline void throw_on_error(ncclResult_t stat, const std::string& msg) {
inline std::string build_ex_string(ncclResult_t nccl_result,
const std::string& msg) {
return msg + "\n [" + platform::dynload::ncclGetErrorString(nccl_result) +
"]";
}
// Reports an NCCL failure.
//
// `msg` is emitted verbatim. PADDLE_ENFORCE_CUDA_SUCCESS passes the result of
// build_ex_string(nccl_result, ...) as `msg`, which already ends with the
// ncclGetErrorString text, so the status text is not prepended again here.
//
// nccl_result: the failing status; only selects this overload, it is not
//              read.
// msg:         fully formatted error message.
//
// Throws std::runtime_error(msg), or logs at FATAL severity instead when
// REPLACE_ENFORCE_GLOG is defined.
inline void throw_on_error(ncclResult_t nccl_result, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}
#endif // __APPLE__ and windows
......@@ -480,22 +504,25 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
#endif // PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
// Checks the status returned by a CUDA-family call and raises on failure.
//
// COND is evaluated exactly once. Its type selects the matching
// details::CudaStatusType<...>::kSuccess constant to compare against. On a
// mismatch:
//   1. the variadic arguments are turned into a message via
//      ErrorSummary(...).ToString();
//   2. build_ex_string() for the status type may append status-specific
//      detail to that message;
//   3. throw_on_error() for the status type reports the message;
//   4. any exception raised by step 3 is rewrapped as EnforceNotMet carrying
//      the __FILE__/__LINE__ of the call site.
//
// The body is wrapped in do { ... } while (0) so the macro behaves as a single
// statement.
#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...)                              \
  do {                                                                      \
    auto __cond__ = (COND);                                                 \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                        \
    constexpr auto __success_type__ =                                       \
        ::paddle::platform::details::CudaStatusType<                        \
            __CUDA_STATUS_TYPE__>::kSuccess;                                \
    if (UNLIKELY(__cond__ != __success_type__)) {                           \
      try {                                                                 \
        ::paddle::platform::throw_on_error(                                 \
            __cond__,                                                       \
            ::paddle::platform::build_ex_string(                            \
                __cond__,                                                   \
                ::paddle::platform::ErrorSummary(__VA_ARGS__).ToString())); \
      } catch (...) {                                                       \
        throw ::paddle::platform::EnforceNotMet(std::current_exception(),   \
                                                __FILE__, __LINE__);        \
      }                                                                     \
    }                                                                       \
  } while (0)
#undef DEFINE_CUDA_STATUS_TYPE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册