提交 ed5bd5e5 编写于 作者: P peizhilin

test=develop

上级 19ebd8b4
......@@ -16,9 +16,7 @@ if (CUPTI_FOUND)
list(APPEND CUDA_SRCS cupti.cc)
endif(CUPTI_FOUND)
nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader)
if (NOT WIN32)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
endif(NOT WIN32)
if (WITH_MKLML)
cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml)
endif()
......
......@@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cudnn_func = decltype(&::__name); \
std::call_once(cudnn_dso_flag, []() { \
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
......
......@@ -201,6 +201,8 @@ void* GetCurandDsoHandle() {
void* GetWarpCTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "warpctc.dll");
#else
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so");
#endif
......
......@@ -18,6 +18,12 @@ namespace paddle {
namespace platform {
namespace dynload {
#ifndef _WIN32
#define DECLARE_TYPE(__name, ...) decltype(__name(__VA_ARGS__))
#else
#define DECLARE_TYPE(__name, ...) decltype(auto)
#endif
void* GetCublasDsoHandle();
void* GetCUDNNDsoHandle();
void* GetCUPTIDsoHandle();
......
......@@ -34,7 +34,7 @@ extern void* mklml_dso_handle;
#define DYNAMIC_LOAD_MKLML_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using mklmlFunc = decltype(&::__name); \
std::call_once(mklml_dso_flag, []() { \
mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
......
......@@ -33,7 +33,7 @@ extern void* tensorrt_dso_handle;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = \
......
......@@ -34,7 +34,7 @@ extern void* warpctc_dso_handle;
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using warpctcFunc = decltype(&::__name); \
std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册