Commit 57a82568 authored by J Juncheng, committed by Jinhui Yuan

cmake for nccl (#1262)

Parent 4b2c4ef0
@@ -4,6 +4,7 @@ cmake_minimum_required(VERSION 3.5)
 option(BUILD_THIRD_PARTY "Build third party or oneflow" OFF)
 option(BUILD_RDMA "" ON)
 option(BUILD_CUDA "" ON)
+option(BUILD_NCCL "" ON)
 option(RELEASE_VERSION "" OFF)
 # Project
......
@@ -31,7 +31,14 @@ if (BUILD_CUDA)
   list(APPEND CUDA_LIBRARIES ${cuda_lib_dir}/${extra_cuda_lib})
 endforeach()
 find_package(CuDNN REQUIRED)
-find_package(NCCL REQUIRED)
+if (BUILD_NCCL)
+  find_package(NCCL REQUIRED)
+  if (NCCL_VERSION VERSION_LESS 2.0)
+    message(FATAL_ERROR "minimum nccl version required is 2.0")
+  else()
+    add_definitions(-DWITH_NCCL)
+  endif()
+endif()
 endif()
 if (NOT WIN32)
......
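Together these two hunks make NCCL an opt-out dependency: configuring with -DBUILD_NCCL=OFF skips find_package(NCCL) entirely, an installed NCCL older than 2.0 fails the configure step, and a successful check defines WITH_NCCL for every translation unit. The source-level guards in the rest of this commit key on that macro; a minimal sketch of the pattern (function name hypothetical):

// Sketch only: code compiled with -DWITH_NCCL takes the NCCL path,
// any other build falls back to an unimplemented stub.
#ifdef WITH_NCCL
void AllReduceImpl() { /* NCCL-backed implementation */ }
#else
void AllReduceImpl() { UNIMPLEMENTED(); }
#endif // WITH_NCCL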
@@ -6,10 +6,14 @@
 namespace oneflow {
 
 void NcclActor::InitDeviceCtx(const ThreadCtx& thread_ctx) {
+#ifdef WITH_NCCL
   CHECK_EQ(GetDeviceType(), DeviceType::kGPU);
   // CHECK_EQ(GetLocalWorkStreamId(), 0);
   mut_device_ctx().reset(new NcclDeviceCtx(
       thread_ctx.g_cuda_stream.get(), Global<NcclCommMgr>::Get()->NcclComm4ActorId(actor_id())));
+#else
+  UNIMPLEMENTED();
+#endif // WITH_NCCL
 }
 
 REGISTER_ACTOR(TaskType::kNcclAllReduce, NcclActor);
......
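REGISTER_ACTOR is not defined in this diff; it presumably binds TaskType::kNcclAllReduce to NcclActor via static registration so the runtime can construct the right actor from a task type alone. A hedged sketch of that common pattern, with all names hypothetical:

#include <functional>
#include <map>

// Minimal stand-in for the actor base class.
class Actor { public: virtual ~Actor() = default; };

// Hypothetical registry mapping a task-type id to an actor factory.
std::map<int, std::function<Actor*()>>& ActorRegistry() {
  static std::map<int, std::function<Actor*()>> registry;
  return registry;
}

struct ActorRegistrar {
  ActorRegistrar(int task_type, std::function<Actor*()> factory) {
    ActorRegistry()[task_type] = std::move(factory);
  }
};

// A macro like REGISTER_ACTOR can then expand to one static registrar object.
#define REGISTER_ACTOR_SKETCH(task_type, ActorClass)  \
  static ActorRegistrar g_##ActorClass##_registrar(    \
      static_cast<int>(task_type), [] { return new ActorClass; })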
@@ -11,7 +11,6 @@
 #include <cuda_runtime.h>
 #include <cudnn.h>
 #include <curand.h>
-#include <nccl.h>
 
 namespace oneflow {
......
@@ -3,6 +3,10 @@
 #include "oneflow/core/device/cuda_util.h"
+
+#ifdef WITH_NCCL
+#include <nccl.h>
+#endif // WITH_NCCL
 
 namespace oneflow {
 
 class DeviceCtx {
@@ -15,7 +19,9 @@ class DeviceCtx {
   virtual const cublasHandle_t& cublas_pmh_handle() const { UNIMPLEMENTED(); }
   virtual const cublasHandle_t& cublas_pmd_handle() const { UNIMPLEMENTED(); }
   virtual const cudnnHandle_t& cudnn_handle() const { UNIMPLEMENTED(); }
+#ifdef WITH_NCCL
   virtual const ncclComm_t& nccl_handle() const { UNIMPLEMENTED(); }
+#endif // WITH_NCCL
 #endif
 
   virtual void AddCallBack(std::function<void()>) const = 0;
......
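In the base DeviceCtx, nccl_handle() joins the other handle accessors as a virtual with an UNIMPLEMENTED() default, so only contexts that actually own a communicator override it and every other context fails loudly if asked for one. A reduced sketch of the idea (simplified types; std::abort stands in for UNIMPLEMENTED):

#include <cstdlib>
#include <nccl.h>

#ifdef WITH_NCCL
class DeviceCtxSketch {
 public:
  virtual ~DeviceCtxSketch() = default;
  // Default: no communicator available, die loudly.
  virtual const ncclComm_t& nccl_handle() const { std::abort(); }
};

class NcclCtxSketch final : public DeviceCtxSketch {
 public:
  explicit NcclCtxSketch(ncclComm_t comm) : comm_(comm) {}
  const ncclComm_t& nccl_handle() const override { return comm_; }

 private:
  ncclComm_t comm_;
};
#endif // WITH_NCCL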
@@ -6,7 +6,7 @@
 namespace oneflow {
 
-#ifdef WITH_CUDA
+#ifdef WITH_NCCL
 
 class NcclDeviceCtx final : public CudaDeviceCtx {
  public:
@@ -21,7 +21,7 @@ class NcclDeviceCtx final : public CudaDeviceCtx {
   ncclComm_t nccl_handler_;
 };
 
-#endif // WITH_CUDA
+#endif // WITH_NCCL
 
 } // namespace oneflow
......
@@ -2,6 +2,8 @@
 namespace oneflow {
 
+#ifdef WITH_NCCL
 void NcclCheck(ncclResult_t error) { CHECK_EQ(error, ncclSuccess) << ncclGetErrorString(error); }
+#endif // WITH_NCCL
 
 } // namespace oneflow
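NcclCheck converts any non-ncclSuccess result into a fatal CHECK failure annotated with ncclGetErrorString, so call sites can wrap every NCCL invocation in one line. For example (communicator assumed already initialized):

// Query how many ranks share this communicator; abort with
// NCCL's error string if the call fails.
int nranks = 0;
NcclCheck(ncclCommCount(comm, &nranks));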
......
 #ifndef ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_
 #define ONEFLOW_CORE_DEVICE_NCCL_UTIL_H_
 
-#include <nccl.h>
 #include "oneflow/core/register/blob.h"
 #include "oneflow/core/common/data_type.pb.h"
 #include "oneflow/core/common/util.h"
+#ifdef WITH_NCCL
+#include <nccl.h>
+#endif // WITH_NCCL
 
 namespace oneflow {
 
+#ifdef WITH_NCCL
 inline ncclDataType_t GetNcclDataType(const DataType& dt) {
   switch (dt) {
 #define NCCL_DATA_TYPE_CASE(dtype) \
@@ -22,29 +26,45 @@ inline ncclDataType_t GetNcclDataType(const DataType& dt) {
 }
 
 void NcclCheck(ncclResult_t error);
+#endif // WITH_NCCL
 
 class NcclUtil final {
  public:
   using NcclReduceMthd = void(DeviceCtx*, Blob*, Blob*);
 
   static void AllReduce(DeviceCtx* ctx, Blob* send_blob, Blob* recv_blob) {
+#ifdef WITH_NCCL
     auto elem_cnt = (size_t)send_blob->shape().elem_cnt();
     NcclCheck(ncclAllReduce(send_blob->dptr(), recv_blob->mut_dptr(), elem_cnt,
                             GetNcclDataType(send_blob->data_type()), ncclSum, ctx->nccl_handle(),
                             ctx->cuda_stream()));
+#else
+    UNIMPLEMENTED();
+#endif // WITH_NCCL
   }
 
   static void ReduceScatter(DeviceCtx* ctx, Blob* send_blob, Blob* recv_blob) {
+#ifdef WITH_NCCL
     auto elem_cnt = (size_t)recv_blob->shape().elem_cnt();
     NcclCheck(ncclReduceScatter(send_blob->dptr(), recv_blob->mut_dptr(), elem_cnt,
                                 GetNcclDataType(send_blob->data_type()), ncclSum,
                                 ctx->nccl_handle(), ctx->cuda_stream()));
+#else
+    UNIMPLEMENTED();
+#endif // WITH_NCCL
   }
 
   static void AllGather(DeviceCtx* ctx, Blob* send_blob, Blob* recv_blob) {
+#ifdef WITH_NCCL
     auto elem_cnt = (size_t)send_blob->shape().elem_cnt();
     NcclCheck(ncclAllGather(send_blob->dptr(), recv_blob->mut_dptr(), elem_cnt,
                             GetNcclDataType(send_blob->data_type()), ctx->nccl_handle(),
                             ctx->cuda_stream()));
+#else
+    UNIMPLEMENTED();
+#endif // WITH_NCCL
   }
 };
......
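One subtlety in these helpers: ncclAllReduce and ncclAllGather take the per-rank send element count, while ncclReduceScatter takes the per-rank receive count, which is why ReduceScatter sizes the call from recv_blob where the other two use send_blob. Stripped of the Blob plumbing, AllReduce reduces to one raw call, reusing NcclCheck from this header (buffers, communicator, and stream assumed valid; element type fixed to float for brevity):

#include <cuda_runtime.h>
#include <nccl.h>

void AllReduceSketch(const float* send, float* recv, size_t elem_cnt,
                     ncclComm_t comm, cudaStream_t stream) {
  // Sum-reduce elem_cnt floats across all ranks; result visible on every rank.
  NcclCheck(ncclAllReduce(send, recv, elem_cnt, ncclFloat, ncclSum, comm, stream));
}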
@@ -113,6 +113,9 @@ JobDesc::JobDesc(const std::string& job_conf_filepath) {
 #ifndef WITH_RDMA
   CHECK_EQ(job_conf_.other().use_rdma(), false) << "Please compile ONEFLOW with RDMA";
 #endif
+#ifndef WITH_NCCL
+  CHECK_EQ(job_conf_.other().enable_nccl(), false) << "Please compile ONEFLOW with NCCL";
+#endif // WITH_NCCL
   int64_t piece_exp = job_conf_.other().piece_num_of_experiment_phase();
   if (job_conf_.other().has_train_conf()) {
     TrainConf* train_conf = job_conf_.mutable_other()->mutable_train_conf();
......
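This mirrors the RDMA guard directly above it: a job configuration may only request NCCL when the binary was built with it, turning a silent feature mismatch into an immediate, explanatory failure. Judging from the accessors, the corresponding job-conf flag would look roughly like this (prototxt layout assumed, not taken from this diff):

other {
  enable_nccl: true
}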
@@ -5,6 +5,8 @@
 #include "oneflow/core/device/nccl_util.h"
 #include "nccl_comm_manager.h"
 
+#ifdef WITH_NCCL
+
 namespace oneflow {
 
 NcclCommMgr::NcclCommMgr(const Plan& plan) {
@@ -108,3 +110,5 @@ void NcclCommMgr::NcclGetUniqueId4Tasks(const std::vector<TaskProto>& tasks,
 }
 
 } // namespace oneflow
+
+#endif // WITH_NCCL
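NcclCommMgr owns one communicator per NCCL task group in the plan, and the NcclGetUniqueId4Tasks helper points at the standard NCCL bootstrap: one participant generates a ncclUniqueId, the id is distributed out of band, and every rank joins with ncclCommInitRank. A hedged sketch of that handshake (id distribution mechanism assumed):

ncclUniqueId id;
if (rank == 0) { NcclCheck(ncclGetUniqueId(&id)); }
// ... distribute `id` to the other ranks over the job's control channel ...
ncclComm_t comm;
NcclCheck(ncclCommInitRank(&comm, nranks, id, rank));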
@@ -3,6 +3,9 @@
 #include "oneflow/core/common/util.h"
 #include "oneflow/core/job/plan.pb.h"
 
+#ifdef WITH_NCCL
+
 #include <nccl.h>
 
 namespace oneflow {
@@ -29,4 +32,6 @@ class NcclCommMgr final {
 } // namespace oneflow
 
+#endif // WITH_NCCL
+
 #endif // ONEFLOW_CORE_JOB_NCCL_COMM_MANAGER_H_
@@ -112,11 +112,15 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) {
   Global<RegstMgr>::New(plan);
   Global<ActorMsgBus>::New();
   Global<ThreadMgr>::New(plan);
+#ifdef WITH_NCCL
   Global<NcclCommMgr>::New(plan);
+#endif // WITH_NCCL
 }
 
 void Runtime::DeleteAllGlobal() {
+#ifdef WITH_NCCL
   Global<NcclCommMgr>::Delete();
+#endif // WITH_NCCL
   Global<ThreadMgr>::Delete();
   Global<ActorMsgBus>::Delete();
   Global<RegstMgr>::Delete();
......
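The communicator manager's lifetime is kept symmetric with the other runtime singletons: constructed after ThreadMgr in NewAllGlobal and destroyed before it in DeleteAllGlobal, since deletion runs in reverse construction order. Global<T> itself is not part of this diff; a plausible minimal shape (assumed, not OneFlow's actual definition):

#include <utility>

// Hypothetical stand-in for Global<T>: one heap-allocated instance per type.
template<typename T>
struct GlobalSketch {
  template<typename... Args>
  static void New(Args&&... args) { ptr = new T(std::forward<Args>(args)...); }
  static T* Get() { return ptr; }
  static void Delete() { delete ptr; ptr = nullptr; }
  static inline T* ptr = nullptr;  // C++17 inline static member
};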