Commit 43c6ff21 authored by Yu Yang, committed by GitHub

Feature/nccl dso (#5001)

* "add nccl enforce"

* Dev

* Update comment

* Add nccl test

* Follow comments
Parent fcd74e06
@@ -129,6 +129,7 @@ include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(cudnn) # set cudnn libraries, must before configure
include(nccl) # set nccl libraries
include(configure) # add paddle env configuration
include(generic) # simplify cmake module
include(package) # set paddle packages
@@ -159,7 +160,7 @@ set(EXTERNAL_LIBS
if(WITH_GPU)
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
- list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
+ list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)
endif(WITH_GPU)
@@ -62,12 +62,19 @@ else()
FIND_PACKAGE(CUDA REQUIRED)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 7)
message(FATAL_ERROR "Paddle need CUDA >= 7.0 to compile")
message(FATAL_ERROR "Paddle needs CUDA >= 7.0 to compile")
endif()
if(NOT CUDNN_FOUND)
message(FATAL_ERROR "Paddle need cudnn to compile")
message(FATAL_ERROR "Paddle needs cudnn to compile")
endif()
if (NOT NCCL_INCLUDE_DIR)
message(FATAL_ERROR "Paddle needs nccl header to compile")
endif()
if (NOT WITH_DSO AND NOT NCCL_LIBRARY)
message(FATAL_ERROR "Paddle needs nccl libraries when WITH_DSO=OFF")
endif()
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}")
if (NOT WITH_GPU)
return ()
endif()
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h PATHS
${NCCL_ROOT} ${NCCL_ROOT}/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
NO_DEFAULT_PATH)
get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
set(TARGET_ARCH "x86_64")
if(CMAKE_SYSTEM_PROCESSOR)
set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
endif()
list(APPEND NCCL_CHECK_LIBRARY_DIRS
${NCCL_ROOT}
${NCCL_ROOT}/lib64
${NCCL_ROOT}/lib
${NCCL_ROOT}/lib/${TARGET_ARCH}-linux-gnu
$ENV{NCCL_ROOT}
$ENV{NCCL_ROOT}/lib64
$ENV{NCCL_ROOT}/lib
/usr/lib)
find_library(NCCL_LIBRARY NAMES libnccl.so libnccl.dylib # libnccl_static.a
PATHS ${NCCL_CHECK_LIBRARY_DIRS} ${NCCL_INCLUDE_DIR} ${__libpath_hist}
NO_DEFAULT_PATH
DOC "Path to nccl library.")
@@ -25,3 +25,4 @@ nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context)
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
- nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc DEPS dynamic_loader)
+ nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc DEPS dynamic_loader)
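# Hedged usage note, not part of this commit: assuming WITH_GPU and WITH_TESTING
# are enabled, the new test target can be built and run with something like
#   make nccl_test && ctest -R nccl_test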
@@ -35,6 +35,11 @@ DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
DEFINE_string(nccl_dir, "",
"Specify path for loading nccl library, such as libcublas, "
"libcurand. For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH");
namespace paddle {
namespace platform {
namespace dynload {
@@ -157,6 +162,14 @@ void GetLapackDsoHandle(void** dso_handle) {
#endif
}
void GetNCCLDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so", dso_handle);
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
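A hedged usage sketch, not part of this diff: nccl_dir is an ordinary gflags flag, so a gflags-initialized binary can point the loader at a non-default NCCL directory before the first wrapped call, either by passing --nccl_dir=/opt/nccl/lib on the command line (an example path) or programmatically:
  // Assumes gflags has already been initialized by the host program; the flag
  // name matches the DEFINE_string above, and /opt/nccl/lib is only an example.
  gflags::SetCommandLineOption("nccl_dir", "/opt/nccl/lib");  // google:: in older gflags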
@@ -58,6 +58,14 @@ void GetWarpCTCDsoHandle(void** dso_handle);
*/
void GetLapackDsoHandle(void** dso_handle);
/**
* @brief load the DSO of NVIDIA nccl
*
* @param **dso_handle dso handler
*
*/
void GetNCCLDsoHandle(void** dso_handle);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/dynload/nccl.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag nccl_dso_flag;
void *nccl_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <dlfcn.h>
#include <nccl.h>
#include <mutex>
#include "paddle/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag nccl_dso_flag;
extern void* nccl_dso_handle;
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using nccl_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(nccl_dso_flag, \
paddle::platform::dynload::GetNCCLDsoHandle, \
&nccl_dso_handle); \
void* p_##__name = dlsym(nccl_dso_handle, #__name); \
return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
ncclResult_t operator()(Args... args) { \
return __name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
#define NCCL_RAND_ROUTINE_EACH(__macro) \
__macro(ncclCommInitAll); \
__macro(ncclGetUniqueId); \
__macro(ncclCommInitRank); \
__macro(ncclCommDestroy); \
__macro(ncclCommCount); \
__macro(ncclCommCuDevice); \
__macro(ncclCommUserRank); \
__macro(ncclAllReduce); \
__macro(ncclBcast); \
__macro(ncclAllGather); \
__macro(ncclReduce); \
__macro(ncclGetErrorString);
NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
} // namespace dynload
} // namespace platform
} // namespace paddle
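A minimal call-site sketch, assuming paddle/platform/dynload/nccl.h is included and at least one CUDA device is visible: a wrapped symbol is invoked exactly like the plain NCCL API, and under PADDLE_USE_DSO the first call dlopens libnccl and resolves the symbol via dlsym.
  ncclUniqueId id;
  ncclResult_t ret = paddle::platform::dynload::ncclGetUniqueId(&id);
  if (ret != ncclSuccess) {
    // The error string also comes through the dynamically loaded library.
    throw std::runtime_error(paddle::platform::dynload::ncclGetErrorString(ret));
  }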
@@ -29,11 +29,14 @@ limitations under the License. */
#include <cxxabi.h> // for __cxa_demangle
#endif
#include <glog/logging.h>
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/dynload/nccl.h"
#include <cublas_v2.h>
#include <cudnn.h>
@@ -172,6 +175,17 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
throw std::runtime_error(err + string::Sprintf(args...));
}
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
ncclResult_t stat, const Args&... args) {
if (stat == ncclSuccess) {
return;
} else {
throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) +
string::Sprintf(args...));
}
}
#endif // PADDLE_ONLY_CPU
template <typename T>
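With the new overload, PADDLE_ENFORCE can consume an ncclResult_t directly; a minimal sketch of the pattern (the same one the test below uses), assuming two visible GPUs:
  std::vector<ncclComm_t> comms(2);
  PADDLE_ENFORCE(paddle::platform::dynload::ncclCommInitAll(comms.data(), 2, nullptr),
                 "ncclCommInitAll failed");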
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/dynload/nccl.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h"
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
static int dev_count = 0;
namespace paddle {
namespace platform {
TEST(NCCL, init) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
}
template <typename T>
struct PerThreadData {
thrust::device_vector<T> send_buff;
thrust::device_vector<T> recv_buff;
CUDADeviceContext dev_ctx;
T* SendBuff() { return thrust::raw_pointer_cast(send_buff.data()); }
T* RecvBuff() { return thrust::raw_pointer_cast(recv_buff.data()); }
PerThreadData(int gpu_id, size_t size) : dev_ctx(GPUPlace(gpu_id)) {
send_buff.resize(size);
for (size_t i = 0; i < size; ++i) {
send_buff[i] = static_cast<T>(i);
}
recv_buff.resize(size);
}
};
static constexpr int ELEM_COUNT = 10000;
TEST(NCCL, all_reduce) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
VLOG(1) << "Initializing ncclComm";
auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
PADDLE_ENFORCE(status);
VLOG(1) << "ncclComm initialized";
VLOG(1) << "Creating thread data";
std::vector<std::unique_ptr<PerThreadData<double>>> data;
data.reserve(dev_count);
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Creating thread data for device " << i;
SetDeviceId(i);
data.emplace_back(new PerThreadData<double>(i, ELEM_COUNT));
}
VLOG(1) << "Thread data created";
VLOG(1) << "Check send_buf data";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Check on device " << i;
SetDeviceId(i);
thrust::host_vector<double> tmp = data[i]->send_buff;
for (size_t j = 0; j < tmp.size(); ++j) {
ASSERT_NEAR(static_cast<double>(j), tmp[j], 1e-5);
}
}
VLOG(1) << "Invoking ncclAllReduce";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Invoking ncclAllReduce with device " << i;
SetDeviceId(i);
PADDLE_ENFORCE(dynload::ncclAllReduce(
data[i]->SendBuff(), data[i]->RecvBuff(), ELEM_COUNT, ncclDouble,
ncclSum, comms[i], data[i]->dev_ctx.stream()));
VLOG(1) << "Invoked ncclAllReduce for device " << i;
}
VLOG(1) << "Invoked ncclAllReduce";
VLOG(1) << "Sync devices";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Sync device " << i;
SetDeviceId(i);
data[i]->dev_ctx.Wait();
}
VLOG(1) << "device synced";
for (int i = 0; i < dev_count; ++i) {
SetDeviceId(i);
VLOG(1) << "Checking vector on device " << i;
thrust::host_vector<double> tmp = data[i]->recv_buff;
for (size_t j = 0; j < tmp.size(); ++j) {
auto elem = static_cast<double>(j);
elem *= dev_count;
ASSERT_NEAR(tmp[j], elem, 1e-4);
}
}
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
}
} // namespace platform
} // namespace paddle
int main(int argc, char** argv) {
dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
@@ -35,6 +35,7 @@ struct GPUPlace {
GPUPlace() : GPUPlace(0) {}
explicit GPUPlace(int d) : device(d) {}
inline int GetDeviceId() const { return device; }
// needed for variant equality comparison
inline bool operator==(const GPUPlace &o) const { return device == o.device; }
inline bool operator!=(const GPUPlace &o) const { return !(*this == o); }
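A hedged illustration, assuming Place is the boost::variant over CPUPlace and GPUPlace defined earlier in this header: variant equality compiles only when every alternative defines operator==, which is what the comment above refers to.
  paddle::platform::Place a = paddle::platform::GPUPlace(0);
  paddle::platform::Place b = paddle::platform::GPUPlace(1);
  bool same = (a == b);  // false; dispatches to GPUPlace::operator==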