提交 ed5bd5e5 编写于 作者: P peizhilin

test=develop

上级 19ebd8b4
...@@ -16,9 +16,7 @@ if (CUPTI_FOUND) ...@@ -16,9 +16,7 @@ if (CUPTI_FOUND)
list(APPEND CUDA_SRCS cupti.cc) list(APPEND CUDA_SRCS cupti.cc)
endif(CUPTI_FOUND) endif(CUPTI_FOUND)
nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader)
if (NOT WIN32)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
endif(NOT WIN32)
if (WITH_MKLML) if (WITH_MKLML)
cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml)
endif() endif()
......
...@@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cudnn_func = decltype(&::__name); \ using cudnn_func = decltype(&::__name); \
std::call_once(cudnn_dso_flag, []() { \ std::call_once(cudnn_dso_flag, []() { \
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \ cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
......
...@@ -201,6 +201,8 @@ void* GetCurandDsoHandle() { ...@@ -201,6 +201,8 @@ void* GetCurandDsoHandle() {
void* GetWarpCTCDsoHandle() { void* GetWarpCTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib"); return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "warpctc.dll");
#else #else
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so"); return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so");
#endif #endif
......
...@@ -18,6 +18,12 @@ namespace paddle { ...@@ -18,6 +18,12 @@ namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
#ifndef _WIN32
#define DECLARE_TYPE(__name, ...) decltype(__name(__VA_ARGS__))
#else
#define DECLARE_TYPE(__name, ...) decltype(auto)
#endif
void* GetCublasDsoHandle(); void* GetCublasDsoHandle();
void* GetCUDNNDsoHandle(); void* GetCUDNNDsoHandle();
void* GetCUPTIDsoHandle(); void* GetCUPTIDsoHandle();
......
...@@ -34,7 +34,7 @@ extern void* mklml_dso_handle; ...@@ -34,7 +34,7 @@ extern void* mklml_dso_handle;
#define DYNAMIC_LOAD_MKLML_WRAP(__name) \ #define DYNAMIC_LOAD_MKLML_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using mklmlFunc = decltype(&::__name); \ using mklmlFunc = decltype(&::__name); \
std::call_once(mklml_dso_flag, []() { \ std::call_once(mklml_dso_flag, []() { \
mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \ mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
......
...@@ -33,7 +33,7 @@ extern void* tensorrt_dso_handle; ...@@ -33,7 +33,7 @@ extern void* tensorrt_dso_handle;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_func = decltype(__name(args...)) (*)(Args...); \ using tensorrt_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(tensorrt_dso_flag, []() { \ std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = \ tensorrt_dso_handle = \
......
...@@ -34,7 +34,7 @@ extern void* warpctc_dso_handle; ...@@ -34,7 +34,7 @@ extern void* warpctc_dso_handle;
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \ #define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using warpctcFunc = decltype(&::__name); \ using warpctcFunc = decltype(&::__name); \
std::call_once(warpctc_dso_flag, []() { \ std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \ warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册