未验证 提交 e185502e 编写于 作者: Y Yi Wang 提交者: GitHub

Fix cpplint errors with paddle/fluid/platform/dynload (#9715)

* Update source files.

* Update headers

* Update

* Update

* Update

* Update

* Fix a CMake dependency
上级 544254fe
...@@ -62,7 +62,8 @@ ExternalProject_Add( ...@@ -62,7 +62,8 @@ ExternalProject_Add(
) )
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}") MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include warpctc headers.
ADD_LIBRARY(warpctc SHARED IMPORTED GLOBAL) ADD_LIBRARY(warpctc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET warpctc PROPERTY IMPORTED_LOCATION ${WARPCTC_LIBRARIES}) SET_PROPERTY(TARGET warpctc PROPERTY IMPORTED_LOCATION ${WARPCTC_LIBRARIES})
......
...@@ -16,6 +16,6 @@ else() ...@@ -16,6 +16,6 @@ else()
endif() endif()
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps}) scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
...@@ -40,9 +40,9 @@ extern void *cublas_dso_handle; ...@@ -40,9 +40,9 @@ extern void *cublas_dso_handle;
template <typename... Args> \ template <typename... Args> \
inline cublasStatus_t operator()(Args... args) { \ inline cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \ typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, \ std::call_once(cublas_dso_flag, []() { \
paddle::platform::dynload::GetCublasDsoHandle, \ cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
&cublas_dso_handle); \ }); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \ void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \ return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
} \ } \
......
...@@ -44,7 +44,8 @@ CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP); ...@@ -44,7 +44,8 @@ CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP);
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
bool HasCUDNN() { bool HasCUDNN() {
std::call_once(cudnn_dso_flag, GetCUDNNDsoHandle, &cudnn_dso_handle); std::call_once(cudnn_dso_flag,
[]() { cudnn_dso_handle = GetCUDNNDsoHandle(); });
return cudnn_dso_handle != nullptr; return cudnn_dso_handle != nullptr;
} }
......
...@@ -35,9 +35,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -35,9 +35,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
using cudnn_func = decltype(__name(args...)) (*)(Args...); \ using cudnn_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(cudnn_dso_flag, \ std::call_once(cudnn_dso_flag, []() { \
paddle::platform::dynload::GetCUDNNDsoHandle, \ cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
&cudnn_dso_handle); \ }); \
EnforceCUDNNLoaded(#__name); \ EnforceCUDNNLoaded(#__name); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \ return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
......
...@@ -11,14 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,14 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#ifdef PADDLE_WITH_CUPTI #ifdef PADDLE_WITH_CUPTI
#include <cuda.h> #include <cuda.h>
#include <cupti.h> #include <cupti.h>
#include <dlfcn.h> #include <dlfcn.h>
#include <mutex> #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace paddle { namespace paddle {
...@@ -41,9 +42,9 @@ extern void *cupti_dso_handle; ...@@ -41,9 +42,9 @@ extern void *cupti_dso_handle;
template <typename... Args> \ template <typename... Args> \
inline CUptiResult CUPTIAPI operator()(Args... args) { \ inline CUptiResult CUPTIAPI operator()(Args... args) { \
typedef CUptiResult CUPTIAPI (*cuptiFunc)(Args...); \ typedef CUptiResult CUPTIAPI (*cuptiFunc)(Args...); \
std::call_once(cupti_dso_flag, \ std::call_once(cupti_dso_flag, []() { \
paddle::platform::dynload::GetCUPTIDsoHandle, \ cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
&cupti_dso_handle); \ }); \
void *p_##__name = dlsym(cupti_dso_handle, #__name); \ void *p_##__name = dlsym(cupti_dso_handle, #__name); \
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \ return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
} \ } \
......
...@@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <curand.h> #include <curand.h>
#include <dlfcn.h> #include <dlfcn.h>
#include <mutex>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace paddle { namespace paddle {
...@@ -30,9 +31,9 @@ extern void *curand_dso_handle; ...@@ -30,9 +31,9 @@ extern void *curand_dso_handle;
template <typename... Args> \ template <typename... Args> \
curandStatus_t operator()(Args... args) { \ curandStatus_t operator()(Args... args) { \
typedef curandStatus_t (*curandFunc)(Args...); \ typedef curandStatus_t (*curandFunc)(Args...); \
std::call_once(curand_dso_flag, \ std::call_once(curand_dso_flag, []() { \
paddle::platform::dynload::GetCurandDsoHandle, \ curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
&curand_dso_handle); \ }); \
void *p_##__name = dlsym(curand_dso_handle, #__name); \ void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \ return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \ } \
......
...@@ -11,12 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,12 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include <dlfcn.h> #include <dlfcn.h>
#include <memory> #include <memory>
#include <mutex> #include <mutex> // NOLINT
#include <string> #include <string>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/dynload/cupti_lib_path.h"
...@@ -65,22 +67,21 @@ static inline std::string join(const std::string& part1, ...@@ -65,22 +67,21 @@ static inline std::string join(const std::string& part1,
return ret; return ret;
} }
static inline void GetDsoHandleFromDefaultPath(std::string& dso_path, static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path,
void** dso_handle,
int dynload_flags) { int dynload_flags) {
VLOG(3) << "Try to find library: " << dso_path VLOG(3) << "Try to find library: " << dso_path
<< " from default system path."; << " from default system path.";
// default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH
*dso_handle = dlopen(dso_path.c_str(), dynload_flags); void* dso_handle = dlopen(dso_path.c_str(), dynload_flags);
// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to // DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to
// bring System Integrity Projection (SIP), if dso_handle // bring System Integrity Projection (SIP), if dso_handle
// is null, search from default package path in Mac OS. // is null, search from default package path in Mac OS.
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
if (nullptr == *dso_handle) { if (nullptr == dso_handle) {
dso_path = join("/usr/local/cuda/lib/", dso_path); dso_handle =
*dso_handle = dlopen(dso_path.c_str(), dynload_flags); dlopen(join("/usr/local/cuda/lib/", dso_path).c_str(), dynload_flags);
if (nullptr == *dso_handle) { if (nullptr == dso_handle) {
if (dso_path == "libcudnn.dylib") { if (dso_path == "libcudnn.dylib") {
LOG(WARNING) << "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n " LOG(WARNING) << "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n "
"For instance, sudo tar -xzf " "For instance, sudo tar -xzf "
...@@ -91,28 +92,29 @@ static inline void GetDsoHandleFromDefaultPath(std::string& dso_path, ...@@ -91,28 +92,29 @@ static inline void GetDsoHandleFromDefaultPath(std::string& dso_path,
} }
} }
#endif #endif
return dso_handle;
} }
static inline void GetDsoHandleFromSearchPath(const std::string& search_root, static inline void* GetDsoHandleFromSearchPath(const std::string& search_root,
const std::string& dso_name, const std::string& dso_name,
void** dso_handle,
bool throw_on_error = true) { bool throw_on_error = true) {
int dynload_flags = RTLD_LAZY | RTLD_LOCAL; int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
*dso_handle = nullptr; void* dso_handle = nullptr;
std::string dlPath = dso_name; std::string dlPath = dso_name;
if (search_root.empty()) { if (search_root.empty()) {
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); dso_handle = GetDsoHandleFromDefaultPath(dlPath, dynload_flags);
} else { } else {
// search xxx.so from custom path // search xxx.so from custom path
dlPath = join(search_root, dso_name); dlPath = join(search_root, dso_name);
*dso_handle = dlopen(dlPath.c_str(), dynload_flags); dso_handle = dlopen(dlPath.c_str(), dynload_flags);
// if not found, search from default path // if not found, search from default path
if (nullptr == *dso_handle) { if (nullptr == dso_handle) {
LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " (" LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " ("
<< dlerror() << ")"; << dlerror() << ")";
dlPath = dso_name; dlPath = dso_name;
GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); dso_handle = GetDsoHandleFromDefaultPath(dlPath, dynload_flags);
} }
} }
auto error_msg = auto error_msg =
...@@ -124,70 +126,71 @@ static inline void GetDsoHandleFromSearchPath(const std::string& search_root, ...@@ -124,70 +126,71 @@ static inline void GetDsoHandleFromSearchPath(const std::string& search_root,
"using the DYLD_LIBRARY_PATH is impossible unless System " "using the DYLD_LIBRARY_PATH is impossible unless System "
"Integrity Protection (SIP) is disabled."; "Integrity Protection (SIP) is disabled.";
if (throw_on_error) { if (throw_on_error) {
PADDLE_ENFORCE(nullptr != *dso_handle, error_msg, dlPath, dlerror()); PADDLE_ENFORCE(nullptr != dso_handle, error_msg, dlPath, dlerror());
} else if (nullptr == *dso_handle) { } else if (nullptr == dso_handle) {
LOG(WARNING) << string::Sprintf(error_msg, dlPath, dlerror()); LOG(WARNING) << string::Sprintf(error_msg, dlPath, dlerror());
} }
return dso_handle;
} }
void GetCublasDsoHandle(void** dso_handle) { void* GetCublasDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib");
#else #else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so");
#endif #endif
} }
void GetCUDNNDsoHandle(void** dso_handle) { void* GetCUDNNDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle, return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", false);
false);
#else #else
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle, false); return GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", false);
#endif #endif
} }
void GetCUPTIDsoHandle(void** dso_handle) { void* GetCUPTIDsoHandle() {
std::string cupti_path = cupti_lib_path; std::string cupti_path = cupti_lib_path;
if (!FLAGS_cupti_dir.empty()) { if (!FLAGS_cupti_dir.empty()) {
cupti_path = FLAGS_cupti_dir; cupti_path = FLAGS_cupti_dir;
} }
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(cupti_path, "libcupti.dylib", dso_handle, false); return GetDsoHandleFromSearchPath(cupti_path, "libcupti.dylib", false);
#else #else
GetDsoHandleFromSearchPath(cupti_path, "libcupti.so", dso_handle, false); return GetDsoHandleFromSearchPath(cupti_path, "libcupti.so", false);
#endif #endif
} }
void GetCurandDsoHandle(void** dso_handle) { void* GetCurandDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib");
#else #else
GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so");
#endif #endif
} }
void GetWarpCTCDsoHandle(void** dso_handle) { void* GetWarpCTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib");
#else #else
GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so");
#endif #endif
} }
void GetLapackDsoHandle(void** dso_handle) { void* GetLapackDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib");
#else #else
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so");
#endif #endif
} }
void GetNCCLDsoHandle(void** dso_handle) { void* GetNCCLDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__) #if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib");
#else #else
GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so", dso_handle); return GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so");
#endif #endif
} }
......
...@@ -18,55 +18,13 @@ namespace paddle { ...@@ -18,55 +18,13 @@ namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
/** void* GetCublasDsoHandle();
* @brief load the DSO of CUBLAS void* GetCUDNNDsoHandle();
* void* GetCUPTIDsoHandle();
* @param **dso_handle dso handler void* GetCurandDsoHandle();
* void* GetWarpCTCDsoHandle();
*/ void* GetLapackDsoHandle();
void GetCublasDsoHandle(void** dso_handle); void* GetNCCLDsoHandle();
/**
* @brief load the DSO of CUDNN
*
* @param **dso_handle dso handler
*
*/
void GetCUDNNDsoHandle(void** dso_handle);
void GetCUPTIDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CURAND
*
* @param **dso_handle dso handler
*
*/
void GetCurandDsoHandle(void** dso_handle);
/**
* @brief load the DSO of warp-ctc
*
* @param **dso_handle dso handler
*
*/
void GetWarpCTCDsoHandle(void** dso_handle);
/**
* @brief load the DSO of lapack
*
* @param **dso_handle dso handler
*
*/
void GetLapackDsoHandle(void** dso_handle);
/**
* @brief load the DSO of NVIDIA nccl
*
* @param **dso_handle dso handler
*
*/
void GetNCCLDsoHandle(void** dso_handle);
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
......
...@@ -25,11 +25,6 @@ void *nccl_dso_handle; ...@@ -25,11 +25,6 @@ void *nccl_dso_handle;
NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP); NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
void LoadNCCLDSO() {
platform::call_once(nccl_dso_flag,
[] { GetNCCLDsoHandle(&nccl_dso_handle); });
}
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <dlfcn.h> #include <dlfcn.h>
#include <nccl.h> #include <nccl.h>
#include <mutex>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/call_once.h" #include "paddle/fluid/platform/call_once.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
...@@ -28,14 +29,15 @@ extern std::once_flag nccl_dso_flag; ...@@ -28,14 +29,15 @@ extern std::once_flag nccl_dso_flag;
extern void* nccl_dso_handle; extern void* nccl_dso_handle;
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
extern void LoadNCCLDSO();
#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
using nccl_func = decltype(__name(args...)) (*)(Args...); \ using nccl_func = decltype(__name(args...)) (*)(Args...); \
paddle::platform::dynload::LoadNCCLDSO(); \ std::call_once(nccl_dso_flag, []() { \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
}); \
void* p_##__name = dlsym(nccl_dso_handle, #__name); \ void* p_##__name = dlsym(nccl_dso_handle, #__name); \
return reinterpret_cast<nccl_func>(p_##__name)(args...); \ return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \ } \
......
...@@ -15,9 +15,10 @@ limitations under the License. */ ...@@ -15,9 +15,10 @@ limitations under the License. */
#pragma once #pragma once
#include <dlfcn.h> #include <dlfcn.h>
#include <mutex> #include <mutex> // NOLINT
#include "ctc.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "warpctc/include/ctc.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -36,9 +37,9 @@ extern void* warpctc_dso_handle; ...@@ -36,9 +37,9 @@ extern void* warpctc_dso_handle;
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
using warpctcFunc = decltype(__name(args...)) (*)(Args...); \ using warpctcFunc = decltype(__name(args...)) (*)(Args...); \
std::call_once(warpctc_dso_flag, \ std::call_once(warpctc_dso_flag, []() { \
paddle::platform::dynload::GetWarpCTCDsoHandle, \ warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
&warpctc_dso_handle); \ }); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name); \ void* p_##_name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \ return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \ } \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册