提交 3424a4c0 编写于 作者: G gangliao 提交者: luotao1

Fix bug and redundant code in hl_dso_loader.cc (#306)

上级 ed83a1d6
...@@ -44,62 +44,25 @@ void* cudnn_dso_handle = nullptr; ...@@ -44,62 +44,25 @@ void* cudnn_dso_handle = nullptr;
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ #define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
cudnnStatus_t operator()(Args... args) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
typedef cudnnStatus_t (*cudnnFunc)(Args...); \ using cudnn_func = decltype(__name(args...))(*)(Args...); \
std::call_once(cudnn_dso_flag, GetCudnnDsoHandle, \ std::call_once(cudnn_dso_flag, GetCudnnDsoHandle, \
&cudnn_dso_handle); \ &cudnn_dso_handle); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnnFunc>(p_##__name)(args...); \ return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ } __name; /* struct DynLoad__##__name */
struct DynLoad__cudnnGetVersion {
template <typename... Args>
size_t operator()(Args... args) {
typedef size_t (*cudnnFunc)(Args...);
std::call_once(cudnn_dso_flag, GetCudnnDsoHandle,
&cudnn_dso_handle);
void* p_name = dlsym(cudnn_dso_handle, "cudnnGetVersion");
return reinterpret_cast<cudnnFunc>(p_name)(args...);
}
} cudnnGetVersion; /* struct DynLoad__##__name */
struct DynLoad__cudnnGetErrorString {
template <typename... Args>
const char* operator()(Args... args) {
typedef const char* (*cudnnFunc)(Args...);
std::call_once(cudnn_dso_flag, GetCudnnDsoHandle,
&cudnn_dso_handle);
void* p_name = dlsym(cudnn_dso_handle, "cudnnGetErrorString");
return reinterpret_cast<cudnnFunc>(p_name)(args...);
}
} cudnnGetErrorString; /* struct DynLoad__##__name */
#else #else
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ #define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
cudnnStatus_t operator()(Args... args) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \ return __name(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ } __name; /* struct DynLoad__##__name */
struct DynLoad__cudnnGetVersion {
template <typename... Args>
size_t operator()(Args... args) {
return cudnnGetVersion(args...);
}
} cudnnGetVersion; /* struct DynLoad__##__name */
struct DynLoad__cudnnGetErrorString {
template <typename... Args>
const char* operator()(Args... args) {
return cudnnGetErrorString(args...);
}
} cudnnGetErrorString; /* struct DynLoad__##__name */
#endif #endif
/** /**
...@@ -133,7 +96,9 @@ struct DynLoad__cudnnGetErrorString { ...@@ -133,7 +96,9 @@ struct DynLoad__cudnnGetErrorString {
__macro(cudnnPoolingForward) \ __macro(cudnnPoolingForward) \
__macro(cudnnPoolingBackward) \ __macro(cudnnPoolingBackward) \
__macro(cudnnSoftmaxBackward) \ __macro(cudnnSoftmaxBackward) \
__macro(cudnnSoftmaxForward) __macro(cudnnSoftmaxForward) \
__macro(cudnnGetVersion) \
__macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
......
...@@ -85,44 +85,24 @@ void* cudart_dso_handle = nullptr; ...@@ -85,44 +85,24 @@ void* cudart_dso_handle = nullptr;
#define DYNAMIC_LOAD_CUDART_WRAP(__name) \ #define DYNAMIC_LOAD_CUDART_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
cudaError_t operator()(Args... args) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
typedef cudaError_t (*cudartFunc)(Args...); \ using cudart_func = decltype(__name(args...))(*)(Args...); \
std::call_once(cudart_dso_flag, GetCudartDsoHandle, \ std::call_once(cudart_dso_flag, GetCudartDsoHandle, \
&cudart_dso_handle); \ &cudart_dso_handle); \
void* p_##__name = dlsym(cudart_dso_handle, #__name); \ void* p_##__name = dlsym(cudart_dso_handle, #__name); \
return reinterpret_cast<cudartFunc>(p_##__name)(args...); \ return reinterpret_cast<cudart_func>(p_##__name)(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ } __name; /* struct DynLoad__##__name */
#else #else
#define DYNAMIC_LOAD_CUDART_WRAP(__name) \ #define DYNAMIC_LOAD_CUDART_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
cudaError_t operator()(Args... args) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \ return __name(args...); \
} \ } \
} __name; /* struct DynLoad__##__name */ } __name; /* struct DynLoad__##__name */
#endif #endif
#ifdef PADDLE_USE_DSO
struct DynLoad__cudaGetErrorString {
template <typename... Args>
const char* operator()(Args... args) {
typedef const char* (*cudaFunc)(Args...);
std::call_once(cudart_dso_flag, GetCudartDsoHandle,
&cudart_dso_handle);
void* p_func = dlsym(cudart_dso_handle, "cudaGetErrorString");
return reinterpret_cast<cudaFunc>(p_func)(args...);
}
} cudaGetErrorString; /* struct DynLoad__cudaGetErrorString */
#else
struct DynLoad__cudaGetErrorString {
template <typename... Args>
const char* operator()(Args... args) {
return cudaGetErrorString(args...);
}
} cudaGetErrorString; /* struct DynLoad__cudaGetErrorString */
#endif
/* include all needed cuda functions in HPPL */ /* include all needed cuda functions in HPPL */
#define CUDA_ROUTINE_EACH(__macro) \ #define CUDA_ROUTINE_EACH(__macro) \
__macro(cudaMalloc) \ __macro(cudaMalloc) \
...@@ -152,7 +132,8 @@ struct DynLoad__cudaGetErrorString { ...@@ -152,7 +132,8 @@ struct DynLoad__cudaGetErrorString {
__macro(cudaSetDeviceFlags) \ __macro(cudaSetDeviceFlags) \
__macro(cudaGetLastError) \ __macro(cudaGetLastError) \
__macro(cudaFuncSetCacheConfig) \ __macro(cudaFuncSetCacheConfig) \
__macro(cudaRuntimeGetVersion) __macro(cudaRuntimeGetVersion) \
__macro(cudaGetErrorString)
CUDA_ROUTINE_EACH(DYNAMIC_LOAD_CUDART_WRAP) CUDA_ROUTINE_EACH(DYNAMIC_LOAD_CUDART_WRAP)
......
...@@ -49,14 +49,14 @@ static inline std::string join(const std::string& part1, const std::string& part ...@@ -49,14 +49,14 @@ static inline std::string join(const std::string& part1, const std::string& part
static inline void GetDsoHandleFromDefaultPath( static inline void GetDsoHandleFromDefaultPath(
std::string& dso_path, void** dso_handle, int dynload_flags) { std::string& dso_path, void** dso_handle, int dynload_flags) {
LOG(INFO) << "Try to find cuda library: " << dso_path LOG(INFO) << "Try to find cuda library: " << dso_path
<< "from default system path."; << " from default system path.";
// default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH
*dso_handle = dlopen(dso_path.c_str(), dynload_flags); *dso_handle = dlopen(dso_path.c_str(), dynload_flags);
// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to // DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to
// bring System Integrity Projection (SIP), if dso_handle // bring System Integrity Projection (SIP), if dso_handle
// is null, search from default package path in Mac OS. // is null, search from default package path in Mac OS.
#if defined(__APPLE__) or defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
if (nullptr == *dso_handle) { if (nullptr == *dso_handle) {
dso_path = join("/usr/local/cuda/lib/", dso_path); dso_path = join("/usr/local/cuda/lib/", dso_path);
*dso_handle = dlopen(dso_path.c_str(), dynload_flags); *dso_handle = dlopen(dso_path.c_str(), dynload_flags);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册