Commit ab23776f authored by ZPaC

GPU supports creating communication groups for auto parallel.

Parent 3bb04abc
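In short, the change makes the NCCL kernels and CommManager group-aware: a communication group is created by carving an MPI sub-communicator out of MPI_COMM_WORLD, broadcasting an ncclUniqueId from the group leader, and initializing a per-group NCCL communicator that is later looked up by group name at kernel launch time. The standalone program below is a minimal sketch of that bootstrap pattern, not MindSpore code; it assumes at least two MPI ranks, one process per GPU on a single node, and an illustrative rank list {0, 1}.

```cpp
// group_bootstrap_sketch.cc -- minimal sketch of the NCCL-over-MPI group bootstrap.
// Assumptions: >= 2 MPI ranks, one process per GPU on a single node.
// Build example: mpicxx group_bootstrap_sketch.cc -lnccl -lcudart
#include <mpi.h>
#include <nccl.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int world_rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
  cudaSetDevice(world_rank);  // assumption: one process per GPU on one node

  // Ranks that make up the new group (illustrative choice).
  std::vector<int> ranks = {0, 1};

  // 1. Carve an MPI group/communicator out of MPI_COMM_WORLD for those ranks.
  MPI_Group world_group, sub_group;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
  MPI_Group_incl(world_group, static_cast<int>(ranks.size()), ranks.data(), &sub_group);
  MPI_Comm sub_comm;
  MPI_Comm_create(MPI_COMM_WORLD, sub_group, &sub_comm);

  if (sub_comm != MPI_COMM_NULL) {  // only group members get a valid communicator
    // 2. The group leader generates the ncclUniqueId and broadcasts it to the members.
    int group_rank = 0;
    MPI_Comm_rank(sub_comm, &group_rank);
    ncclUniqueId id;
    if (group_rank == 0) ncclGetUniqueId(&id);
    MPI_Bcast(&id, static_cast<int>(sizeof(ncclUniqueId)), MPI_BYTE, 0, sub_comm);

    // 3. Every member initializes its NCCL communicator with its rank inside the group.
    ncclComm_t nccl_comm;
    ncclCommInitRank(&nccl_comm, static_cast<int>(ranks.size()), id, group_rank);
    std::printf("world rank %d joined the group as rank %d\n", world_rank, group_rank);
    ncclCommDestroy(nccl_comm);
    MPI_Comm_free(&sub_comm);
  }

  MPI_Group_free(&sub_group);
  MPI_Group_free(&world_group);
  MPI_Finalize();
  return 0;
}
```

MPIWrapper::CreateCommGroup in the diff follows the same three steps, except that it translates the world rank with MPI_Group_translate_ranks and caches the resulting communicator under its group name.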
@@ -57,6 +57,7 @@ if(ENABLE_GPU)
     set_property(SOURCE ${GPU_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     cuda_add_library(gpu_cuda_lib STATIC ${GPU_SRC_LIST})
     set(CMAKE_CXX_FLAGS ${NVCC_TMP_CMAKE_CXX_FLAGS})
+    add_compile_definitions(ENABLE_GPU)
 endif ()
 
 ## make flatuffer files
...
@@ -40,9 +40,11 @@ const std::map<std::string, NcclKernelType> kNcclTypeMap = {
 static std::map<std::string, ncclDataType_t> kNcclDtypeMap = {
   {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}};
 
-typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t);
-typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t);
-typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t);
+typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t,
+                                  const std::string &);
+typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t, const std::string &);
+typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t,
+                                      const std::string &);
 
 template <typename T>
 class NcclGpuKernel : public GpuKernel {
@@ -50,6 +52,7 @@ class NcclGpuKernel : public GpuKernel {
   NcclGpuKernel()
       : nccl_kernel_type_(NCCL_INVALID_TYPE),
         nccl_reduce_type_(ncclSum),
+        group_name_(""),
         input_size_(0),
         output_size_(0),
         collective_handle_(nullptr),
@@ -71,7 +74,7 @@ class NcclGpuKernel : public GpuKernel {
           reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
         MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
         CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
-                                                          nccl_data_type_, nccl_reduce_type_, stream),
+                                                          nccl_data_type_, nccl_reduce_type_, stream, group_name_),
                                    "ncclAllReduce failed");
         break;
       }
@@ -80,7 +83,7 @@ class NcclGpuKernel : public GpuKernel {
           reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
         MS_EXCEPTION_IF_NULL(all_gather_funcptr);
         CHECK_NCCL_RET_WITH_EXCEPT(
-          (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream),
+          (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_),
           "ncclAllGather failed");
         break;
       }
@@ -89,7 +92,7 @@ class NcclGpuKernel : public GpuKernel {
           reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
         MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
         CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
-                                                              nccl_data_type_, nccl_reduce_type_, stream),
+                                                              nccl_data_type_, nccl_reduce_type_, stream, group_name_),
                                    "ncclReduceScatter failed");
         break;
       }
@@ -121,15 +124,18 @@ class NcclGpuKernel : public GpuKernel {
       output_size_list_.push_back(size);
       output_size_ += size;
     }
-    InferCommType(kernel_node);
+
+    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+    MS_EXCEPTION_IF_NULL(collective_handle_);
+
+    InferCommType(kernel_node);
+    group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
+    MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
     auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id");
     if (comm_stream_attr) {
       comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
       MS_EXCEPTION_IF_NULL(comm_stream_);
     }
-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
     return true;
   }
@@ -146,7 +152,7 @@ class NcclGpuKernel : public GpuKernel {
       nccl_kernel_type_ = iter->second;
     }
 
-    auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
+    auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrOp);
     if (reduce_op) {
       std::string type = GetValue<std::string>(reduce_op);
       if (type == "sum") {
@@ -167,6 +173,7 @@ class NcclGpuKernel : public GpuKernel {
   NcclKernelType nccl_kernel_type_;
   ncclRedOp_t nccl_reduce_type_;
   ncclDataType_t nccl_data_type_;
+  std::string group_name_;
   size_t input_size_;
   size_t output_size_;
   std::vector<size_t> input_size_list_;
...
@@ -23,16 +23,17 @@
 namespace mindspore {
 namespace device {
 namespace gpu {
-#define MAX_HOSTNAME_LEN 1024
+constexpr int MAX_HOSTNAME_LEN = 1024;
+constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
 #define CHECK_RET(expression, result, message) \
   { \
     auto ret = (expression); \
     if (ret != result) { \
       std::ostringstream oss; \
-      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << " | GPU collective Error " << message \
+      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << " | GPU collective Error: " << message \
          << " | Error Number " << ret; \
       pybind11::pybind11_fail(oss.str()); \
     } \
   }
 } // namespace gpu
 } // namespace device
...
@@ -18,6 +18,8 @@
 #define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_
 
 #include <dlfcn.h>
+#include <vector>
+#include <string>
 
 namespace mindspore {
 namespace device {
@@ -25,6 +27,10 @@ namespace gpu {
 using InitMPI = void (*)();
 using InitNCCLComm = void (*)();
 using GetLocalRankId = int (*)();
+using CreateCommGroupFunc = bool (*)(const std::string &, const std::vector<unsigned int> &);
+using GetRankIDByGroupFunc = int (*)(const std::string &);
+using GetGroupSizeFunc = int (*)(const std::string &);
+using DestroyGroupFunc = bool (*)(const std::string &);
 
 class CollectiveInitializer {
  public:
...
@@ -20,6 +20,7 @@
 #include <memory>
 #include <string>
 #include <iostream>
+#include <vector>
 #include "runtime/device/gpu/distribution/mpi_wrapper.h"
 #include "runtime/device/gpu/distribution/nccl_wrapper.h"
@@ -36,6 +37,22 @@ extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().lo
 extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
 
+extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
+  return MPIWrapper::instance().CreateCommGroup(group_name, ranks);
+}
+
+extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) {
+  return MPIWrapper::instance().GetRankIDByGroup(group_name);
+}
+
+extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) {
+  return MPIWrapper::instance().GetGroupSize(group_name);
+}
+
+extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) {
+  return MPIWrapper::instance().DestroyGroup(group_name);
+}
+
 extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
                                                  ncclDataType_t data_type, ncclRedOp_t reduce_type,
                                                  cudaStream_t stream) {
...
@@ -15,9 +15,9 @@
  */
 
 #include "runtime/device/gpu/distribution/mpi_wrapper.h"
 #include <cuda_runtime_api.h>
 #include <string>
+#include <vector>
 #include "runtime/device/gpu/distribution/nccl_wrapper.h"
 
 namespace mindspore {
@@ -40,17 +40,82 @@ MPIWrapper &MPIWrapper::instance() {
 int MPIWrapper::local_rank_id() const { return local_rank_id_; }
 
+bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &group_ranks) {
+  std::vector<int> ranks(group_ranks.begin(), group_ranks.end());
+  MPI_Group mpi_group;
+  CHECK_RET(MPI_Group_incl(world_group_, ranks.size(), ranks.data(), &mpi_group), MPI_SUCCESS,
+            "Failed to produce a new group from MPI_COMM_WORLD group for " + group_name);
+  SetGroupNameToMPIGroup(group_name, mpi_group);
+
+  MPI_Comm mpi_group_comm;
+  CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, mpi_group, &mpi_group_comm), MPI_SUCCESS,
+            "Failed to create MPI communicator.");
+  if (mpi_group_comm == MPI_COMM_NULL) {
+    return false;
+  }
+
+  ncclUniqueId group_unique_id;
+  if (rank_id_ == ranks[0]) {
+    group_unique_id = NCCLWrapper::instance().nccl_unique_id();
+  }
+  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm);
+
+  int group_rank[1];
+  int global_rank[1] = {rank_id_};
+  CHECK_RET(MPI_Group_translate_ranks(world_group_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS,
+            "Failed to translate global rank to group rank.");
+  if (group_rank[0] == MPI_UNDEFINED) {
+    return false;
+  }
+
+  ncclComm_t nccl_group_comm;
+  NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]);
+  NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm);
+  return true;
+}
+
+int MPIWrapper::GetRankIDByGroup(const std::string &group_name) {
+  CHECK_RET(group_name_to_mpi_group_map_.count(group_name), 1, "Failed to get MPI group by group name " + group_name);
+  MPI_Group mpi_group = group_name_to_mpi_group_map_[group_name];
+  int rank;
+  CHECK_RET(MPI_Group_rank(mpi_group, &rank), MPI_SUCCESS, "Failed to get rank id by group name " + group_name);
+  return rank;
+}
+
+int MPIWrapper::GetGroupSize(const std::string &group_name) {
+  CHECK_RET(group_name_to_mpi_group_map_.count(group_name), 1, "Failed to get MPI group by group name " + group_name);
+  MPI_Group mpi_group = group_name_to_mpi_group_map_[group_name];
+  int size;
+  CHECK_RET(MPI_Group_size(mpi_group, &size), MPI_SUCCESS, "Failed to get group size by group name " + group_name);
+  return size;
+}
+
+bool MPIWrapper::DestroyGroup(const std::string &group_name) {
+  auto group_iter = group_name_to_mpi_group_map_.find(group_name);
+  if (group_iter == group_name_to_mpi_group_map_.end()) {
+    return false;
+  }
+  MPI_Group mpi_group = group_iter->second;
+  group_name_to_mpi_group_map_.erase(group_iter);
+  CHECK_RET(MPI_Group_free(&mpi_group), MPI_SUCCESS, "Failed to free MPI group for " + group_name);
+  NCCLWrapper::instance().DestroyGroup(group_name);
+  return true;
+}
+
 void MPIWrapper::Init() {
   int initialized;
   CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check mpi initialization status.");
   if (initialized == 0) {
     MPI_Init(nullptr, nullptr);
   }
   CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id.");
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size.");
   NCCLWrapper::instance().set_rank(rank_id_, rank_size_);
-  AssignLocalRankId();
+  AssignLocalRankID();
+  CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD");
+  SetGroupNameToMPIGroup(NCCL_WORLD_GROUP, world_group_);
 
   ncclUniqueId unique_id;
   if (rank_id_ == 0) {
@@ -62,7 +127,7 @@ void MPIWrapper::Init() {
   return;
 }
 
-void MPIWrapper::AssignLocalRankId() {
+void MPIWrapper::AssignLocalRankID() {
   char host_name[MAX_HOSTNAME_LEN] = {0};
   CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), 0, "Getting host name failed.");
   size_t host_hash = std::hash<std::string>()(host_name);
@@ -82,6 +147,10 @@ void MPIWrapper::AssignLocalRankId() {
   }
   return;
 }
+
+void MPIWrapper::SetGroupNameToMPIGroup(const std::string &group_name, const MPI_Group mpi_group) {
+  group_name_to_mpi_group_map_[group_name] = mpi_group;
+}
 } // namespace gpu
 } // namespace device
 } // namespace mindspore
@@ -22,6 +22,9 @@
 #include <unistd.h>
 #include <mpi.h>
 #include <iostream>
+#include <map>
+#include <string>
+#include <vector>
 #include "runtime/device/gpu/distribution/collective_common.h"
 
 namespace mindspore {
@@ -33,16 +36,23 @@ class MPIWrapper {
   MPIWrapper &operator=(const MPIWrapper &) = delete;
   static MPIWrapper &instance();
   int local_rank_id() const;
+  bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks);
+  int GetRankIDByGroup(const std::string &group_name);
+  int GetGroupSize(const std::string &group_name);
+  bool DestroyGroup(const std::string &group_name);
 
  private:
   MPIWrapper();
   ~MPIWrapper();
   void Init();
-  void AssignLocalRankId();
+  void AssignLocalRankID();
+  void SetGroupNameToMPIGroup(const std::string &group_name, const MPI_Group mpi_group);
 
   int rank_id_;
   int rank_size_;
   int local_rank_id_;
+  MPI_Group world_group_;
+  std::map<std::string, MPI_Group> group_name_to_mpi_group_map_;
 };
 } // namespace gpu
 } // namespace device
...
@@ -40,21 +40,51 @@ void NCCLWrapper::set_rank(int rank_id, int rank_size) {
 void NCCLWrapper::InitNCCLComm() {
   CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
             "Failed to init nccl communicator.");
+  group_to_comm_map_[NCCL_WORLD_GROUP] = comm_;
+}
+
+void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) {
+  CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator.");
 }
 
 ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
-                                    ncclRedOp_t reduce_type, cudaStream_t stream) {
-  return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, comm_, stream);
+                                    ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) {
+  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+            "Failed to find NCCL communicator for AllReduce by the group name " + group_name);
+  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
 
 ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
-                                    cudaStream_t stream) {
-  return ncclAllGather(input_addr, output_addr, count, data_type, comm_, stream);
+                                    cudaStream_t stream, const std::string &group_name) {
+  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+            "Failed to find NCCL communicator for AllGather by the group name " + group_name);
+  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream);
 }
 
 ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                        ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream) {
-  return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, comm_, stream);
+                                        ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
+                                        const std::string &group_name) {
+  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+            "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name);
+  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
+}
+
+void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) {
+  group_to_comm_map_[group_name] = comm;
+}
+
+void NCCLWrapper::DestroyGroup(const std::string &group_name) {
+  auto group_iter = group_to_comm_map_.find(group_name);
+  if (group_iter == group_to_comm_map_.end()) {
+    return;
+  }
+  ncclComm_t group_comm = group_iter->second;
+  group_to_comm_map_.erase(group_iter);
+  CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name);
+  return;
 }
 } // namespace gpu
 } // namespace device
...
@@ -20,6 +20,8 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <nccl.h>
+#include <string>
+#include <map>
 #include "runtime/device/gpu/distribution/collective_common.h"
 
 namespace mindspore {
@@ -34,12 +36,15 @@ class NCCLWrapper {
   void set_nccl_unique_id(ncclUniqueId unique_id);
   void set_rank(int rank_id, int rank_size);
   void InitNCCLComm();
+  void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank);
   ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
-                         ncclRedOp_t op, cudaStream_t stream);
+                         ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
-                         cudaStream_t stream);
+                         cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
-                             ncclRedOp_t op, cudaStream_t stream);
+                             ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
+  void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm);
+  void DestroyGroup(const std::string &group_name);
 
  private:
   NCCLWrapper() : rank_id_(-1), rank_size_(0) {}
@@ -50,6 +55,7 @@ class NCCLWrapper {
   int rank_size_;
   ncclUniqueId unique_id_;
   ncclComm_t comm_;
+  std::map<std::string, ncclComm_t> group_to_comm_map_;
 };
 } // namespace gpu
 } // namespace device
...
@@ -16,17 +16,27 @@
 #include "utils/comm_manager.h"
 #include "utils/convert_utils.h"
 
 #ifndef NO_DLIB
 #include "hccl/hcom.h"
 #endif
 
+#if defined(ENABLE_GPU)
+#include "runtime/device/gpu/distribution/collective_init.h"
+using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer;
+using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc;
+using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc;
+using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc;
+using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc;
+#endif
+
 namespace mindspore {
+#ifndef NO_DLIB
 CommManager &CommManager::GetInstance() noexcept {
   static CommManager instance("hccl");
   return instance;
 }
 
-#ifndef NO_DLIB
 #define HCCL_RUN_CHECK(op_name, group, op) \
   do { \
     auto hccl_result = (op); \
@@ -79,7 +89,79 @@ bool CommManager::DestroyGroup(const string &group) const {
   HCCL_RUN_CHECK(string("destroy communicate group"), group, hcom_destroy_group(group.c_str()));
   return true;
 }
+#elif defined(ENABLE_GPU)
+CommManager &CommManager::GetInstance() noexcept {
+  static CommManager instance("nccl");
+  return instance;
+}
+
+bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
+  auto create_comm_group_funcptr =
+    reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
+  MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
+  bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
+  if (!ret) {
+    MS_LOG(ERROR) << "Creating group " << group << " for rank id list " << rank_id_list << " failed.";
+    return ret;
+  }
+  return ret;
+}
+
+bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  auto get_rank_id_funcptr =
+    reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
+  MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
+  int rank = (*get_rank_id_funcptr)(group);
+  *rank_id = static_cast<unsigned int>(rank);
+  MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
+  return true;
+}
+
+bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  auto get_group_size_funcptr =
+    reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
+  MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
+  int size = (*get_group_size_funcptr)(group);
+  *rank_size = static_cast<unsigned int>(size);
+  MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
+  return true;
+}
+
+bool CommManager::DestroyGroup(const string &group) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  auto destroy_group_funcptr =
+    reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
+  MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
+
+  bool ret = (*destroy_group_funcptr)(group);
+  if (!ret) {
+    MS_LOG(ERROR) << "Destroying group " << group << " failed.";
+    return ret;
+  }
+  return ret;
+}
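Note on the GPU branch above: CommManager never links against the collective library directly. CollectiveInitializer holds a handle to the already-loaded library, and each method resolves an extern "C" entry point with dlsym, just as the NCCL kernels do. The following is a minimal sketch of that dispatch pattern; the library file name "libgpu_collective.so" and the group name "group_0" are illustrative assumptions, not values taken from this diff.

```cpp
// dlsym_dispatch_sketch.cc -- sketch of dlopen/dlsym dispatch to a collective library.
// Build example: g++ dlsym_dispatch_sketch.cc -ldl
#include <dlfcn.h>
#include <iostream>
#include <string>
#include <vector>

// Same function-pointer shape as the CreateCommGroupFunc alias in collective_init.h.
using CreateCommGroupFunc = bool (*)(const std::string &, const std::vector<unsigned int> &);

int main() {
  // Assumption: library name is illustrative; MindSpore obtains its handle elsewhere.
  void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << std::endl;
    return 1;
  }
  // The wrapper exports CreateCommGroup with extern "C", so the symbol name is unmangled.
  auto create_group = reinterpret_cast<CreateCommGroupFunc>(dlsym(handle, "CreateCommGroup"));
  if (create_group == nullptr) {
    std::cerr << "dlsym failed: " << dlerror() << std::endl;
    dlclose(handle);
    return 1;
  }
  std::vector<unsigned int> ranks = {0, 1};
  bool ok = create_group("group_0", ranks);  // "group_0" is a hypothetical group name
  std::cout << "CreateCommGroup returned " << std::boolalpha << ok << std::endl;
  dlclose(handle);
  return 0;
}
```

Passing std::string and std::vector<unsigned int> across this boundary works here only because both sides are built in the same project with the same toolchain; extern "C" merely keeps the symbol name unmangled.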
 #else
+CommManager &CommManager::GetInstance() noexcept {
+  static CommManager instance("hccl");
+  return instance;
+}
+
 bool CommManager::CreateGroupSync(const string &, const vector<unsigned int> &) const { return true; }
 
 bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; }
...