未验证 提交 c02ba51d 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #10191 from reyoung/feature/strict_dynload

Make dyload strictly use the same ABI in header
......@@ -14,10 +14,12 @@
#pragma once
#include <cublasXt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <dlfcn.h>
#include <mutex> // NOLINT
#include <type_traits>
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace paddle {
......@@ -37,14 +39,14 @@ extern void *cublas_dso_handle;
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
using FUNC_TYPE = decltype(&::__name); \
template <typename... Args> \
inline cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, []() { \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
}); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
......@@ -71,8 +73,8 @@ extern void *cublas_dso_handle;
__macro(cublasDgemm_v2); \
__macro(cublasHgemm); \
__macro(cublasSgemmEx); \
__macro(cublasSgeam_v2); \
__macro(cublasDgeam_v2); \
__macro(cublasSgeam); \
__macro(cublasDgeam); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
......
......@@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using cudnn_func = decltype(__name(args...)) (*)(Args...); \
using cudnn_func = decltype(&::__name); \
std::call_once(cudnn_dso_flag, []() { \
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
}); \
......
......@@ -41,7 +41,7 @@ extern void *cupti_dso_handle;
struct DynLoad__##__name { \
template <typename... Args> \
inline CUptiResult CUPTIAPI operator()(Args... args) { \
typedef CUptiResult CUPTIAPI (*cuptiFunc)(Args...); \
using cuptiFunc = decltype(&::__name); \
std::call_once(cupti_dso_flag, []() { \
cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
}); \
......
......@@ -30,7 +30,7 @@ extern void *curand_dso_handle;
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
typedef curandStatus_t (*curandFunc)(Args...); \
using curandFunc = decltype(&::__name); \
std::call_once(curand_dso_flag, []() { \
curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
}); \
......
......@@ -33,7 +33,7 @@ extern void* nccl_dso_handle;
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using nccl_func = decltype(__name(args...)) (*)(Args...); \
using nccl_func = decltype(&::__name); \
std::call_once(nccl_dso_flag, []() { \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
}); \
......
......@@ -36,7 +36,7 @@ extern void* warpctc_dso_handle;
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using warpctcFunc = decltype(__name(args...)) (*)(Args...); \
using warpctcFunc = decltype(&::__name); \
std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
}); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册