From 0bc74f28c56280114b2f5d5e2b9c35904863f335 Mon Sep 17 00:00:00 2001
From: ZPaC
Date: Sun, 19 Jul 2020 10:54:47 +0800
Subject: [PATCH] Enable get rank id and size by group

---
 mindspore/ccsrc/CMakeLists.txt                |  3 +
 .../gpu/math/broadcast_gpu_kernel.cc          |  6 ++
 .../gpu/math/broadcast_gpu_kernel.h           |  7 ++-
 .../gpu/nccl/nccl_gpu_kernel.cc               | 11 ++++
 .../ccsrc/frontend/parallel/group_manager.cc  |  4 +-
 mindspore/ccsrc/runtime/device/CMakeLists.txt |  1 +
 .../gpu/distribution/collective_common.h      |  7 +++
 .../gpu/distribution/collective_wrapper.cc    | 55 ++++++-------------
 .../gpu/distribution/collective_wrapper.h     | 47 ++++++++++++++++
 .../device/gpu/distribution/mpi_wrapper.cc    | 12 ++--
 .../device/gpu/distribution/nccl_wrapper.cc   | 50 ++++++++---------
 .../device/gpu/distribution/nccl_wrapper.h    | 14 ++---
 .../runtime/device/gpu/mpi/mpi_initializer.cc | 29 ++--------
 .../runtime/device/gpu/mpi/mpi_initializer.h  | 14 ++---
 mindspore/communication/_comm_helper.py       | 10 +---
 mindspore/ops/_op_impl/akg/gpu/equal.py       |  1 +
 16 files changed, 145 insertions(+), 126 deletions(-)
 create mode 100644 mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h

diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 34766f443..04cc9f092 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -279,6 +279,9 @@ if (ENABLE_GPU)
                           ${CUDNN_PATH}/lib64/libcudnn.so
                           ${CUDA_PATH}/lib64/libcudart.so
                           ${CUDA_PATH}/lib64/stubs/libcuda.so)
+    if (ENABLE_MPI)
+        set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
+    endif()
 endif ()
 
 if (ENABLE_CPU)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index 41e714732..2881cb125 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -99,5 +99,11 @@ MS_REG_GPU_KERNEL_TWO(
 MS_REG_GPU_KERNEL_TWO(
   Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
 } // namespace kernel
 } // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
index 8dc364db9..b131aef58 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
@@ -96,9 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
 
     static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
-      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
-      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
-      {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
+      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
+      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
+      {"FloorDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
+      {"TensorAdd", BROADCAST_TYPE_ADD},
     };
 
     auto iter = kBroadcastTypeMap.find(kernel_name);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc
index c6e3c4c04..8374914dd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc
@@ -24,17 +24,28 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllReduce,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
+
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllGather,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
+
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceScatter,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 } // namespace kernel
 } // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc
index 93855cd52..98fca25b3 100644
--- a/mindspore/ccsrc/frontend/parallel/group_manager.cc
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc
@@ -70,9 +70,7 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
                                  mindspore::parallel::Group *const group) {
   // it is simple to use size to determine whether it is a world group
   uint32_t world_size = 0;
-  if (world_group_ != NCCL_WORLD_GROUP) {
-    (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
-  }
+  (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
 
   if (devices.size() == world_size) {
     auto it = groups_.find(world_group_);
diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt
index 9c95aee0d..3de02822e 100644
--- a/mindspore/ccsrc/runtime/device/CMakeLists.txt
+++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt
@@ -55,6 +55,7 @@ if (ENABLE_GPU)
                          PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS})
     target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl)
+    target_link_libraries(_ms_mpi PRIVATE gpu_collective)
 endif ()
 
 # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST})
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h
index 394aaf310..71018e8c7 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
 
+#include
 #include
 #include "pybind11/pybind11.h"
 
@@ -25,6 +26,12 @@ namespace device {
 namespace gpu {
 constexpr int MAX_HOSTNAME_LEN = 1024;
 constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+struct NcclGroupInfo {
+  int size;
+  int rank;
+  ncclUniqueId unique_id;
+  ncclComm_t comm;
+};
 #define CHECK_RET(expression, result, message) \
   {                                            \
     auto ret = (expression);                   \
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
index f427905af..d74f1ebea 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
@@ -14,58 +14,37 @@
  * limitations under the License.
  */
 
-#include
-#include
-#include
-#include
 #include
-#include
 #include
-#include "runtime/device/gpu/distribution/mpi_wrapper.h"
-#include "runtime/device/gpu/distribution/nccl_wrapper.h"
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
 
-#ifndef EXPORT_WRAPPER
-#define EXPORT_WRAPPER __attribute__((visibility("default")))
-#endif
+void InitMPI() { MPIWrapper::instance(); }
 
-using MPIWrapper = mindspore::device::gpu::MPIWrapper;
-using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
+int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
 
-extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); }
+void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
 
-extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
-
-extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
-
-extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<int> &ranks) {
+bool CreateCommGroup(const std::string &group_name, const std::vector<int> &ranks) {
   return MPIWrapper::instance().CreateCommGroup(group_name, ranks);
 }
 
-extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) {
-  return MPIWrapper::instance().GetRankIDByGroup(group_name);
-}
+int GetRankIDByGroup(const std::string &group_name) { return MPIWrapper::instance().GetRankIDByGroup(group_name); }
 
-extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) {
-  return MPIWrapper::instance().GetGroupSize(group_name);
-}
+int GetGroupSize(const std::string &group_name) { return MPIWrapper::instance().GetGroupSize(group_name); }
 
-extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) {
-  return MPIWrapper::instance().DestroyGroup(group_name);
-}
+bool DestroyGroup(const std::string &group_name) { return MPIWrapper::instance().DestroyGroup(group_name); }
 
-extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                 cudaStream_t stream) {
-  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream);
+ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }
 
-extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, cudaStream_t stream) {
-  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream);
+ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream, group);
 }
 
-extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                     cudaStream_t stream) {
-  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream);
+ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                           ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h
new file mode 100644
index 000000000..e76ede4d3
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include "runtime/device/gpu/distribution/mpi_wrapper.h"
+#include "runtime/device/gpu/distribution/nccl_wrapper.h"
+
+#ifndef EXPORT_WRAPPER
+#define EXPORT_WRAPPER __attribute__((visibility("default")))
+#endif
+
+using MPIWrapper = mindspore::device::gpu::MPIWrapper;
+using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
+
+extern "C" EXPORT_WRAPPER void InitMPI();
+extern "C" EXPORT_WRAPPER int local_rank_id();
+extern "C" EXPORT_WRAPPER void InitNCCLComm();
+extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<int> &ranks);
+extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name);
+extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
+                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
+                                                     cudaStream_t stream, const std::string &group);
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc
index 08ec320ca..aae35d6c1 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc
@@ -58,7 +58,7 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vector<int> &ranks) {
   if (rank_id_ == ranks[0]) {
     group_unique_id = NCCLWrapper::instance().nccl_unique_id();
   }
-  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm);
+  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, 0, mpi_group_comm);
 
   int group_rank[1];
   int global_rank[1] = {rank_id_};
@@ -68,9 +68,8 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vector<int> &ranks) {
     return false;
   }
 
-  ncclComm_t nccl_group_comm;
-  NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]);
-  NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm);
+  NcclGroupInfo nccl_group = {static_cast<int>(ranks.size()), group_rank[0], group_unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(group_name, &nccl_group);
 
   return true;
 }
@@ -111,7 +110,6 @@ void MPIWrapper::Init() {
   CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id.");
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size.");
-  NCCLWrapper::instance().set_rank(rank_id_, rank_size_);
   AssignLocalRankID();
 
   CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD");
 
@@ -123,7 +121,9 @@ void MPIWrapper::Init() {
   }
   CHECK_RET(MPI_Bcast(reinterpret_cast<void *>(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), MPI_SUCCESS,
            "Failed to broadcast nccl unique id.");
-  NCCLWrapper::instance().set_nccl_unique_id(unique_id);
+
+  NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(NCCL_WORLD_GROUP, &world_group);
   return;
 }
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
index bcba53830..519a29a59 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
@@ -30,60 +30,58 @@ ncclUniqueId NCCLWrapper::nccl_unique_id() const {
   return unique_id;
 }
 
-void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; }
-
-void NCCLWrapper::set_rank(int rank_id, int rank_size) {
-  rank_id_ = rank_id;
-  rank_size_ = rank_size;
-}
-
 void NCCLWrapper::InitNCCLComm() {
-  CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
-            "Failed to init nccl communicator.");
-  group_to_comm_map_[NCCL_WORLD_GROUP] = comm_;
-}
-
-void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) {
-  CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator.");
+  for (auto group : group_info_) {
+    std::string group_name = group.first;
+    NcclGroupInfo group_info = group.second;
+    CHECK_RET(ncclCommInitRank(&(group_info.comm), group_info.size, group_info.unique_id, group_info.rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+    group_info_[group_name].comm = group_info.comm;
+  }
+  comm_init_done_ = true;
 }
 
 ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for AllReduce by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
 
 ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
            "Failed to find NCCL communicator for AllGather by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream);
 }
 
 ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
                                         ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
                                         const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
            "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
 
-void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) {
-  group_to_comm_map_[group_name] = comm;
+void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) {
+  if (comm_init_done_) {
+    CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+  }
+  group_info_[group_name] = *group;
 }
 
 void NCCLWrapper::DestroyGroup(const std::string &group_name) {
-  auto group_iter = group_to_comm_map_.find(group_name);
-  if (group_iter == group_to_comm_map_.end()) {
+  auto group_iter = group_info_.find(group_name);
+  if (group_iter == group_info_.end()) {
     return;
   }
-  group_to_comm_map_.erase(group_iter);
-  ncclComm_t group_comm = group_iter->second;
+  ncclComm_t group_comm = group_iter->second.comm;
   CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name);
+  group_info_.erase(group_iter);
   return;
 }
 } // namespace gpu
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
index 0dc8f790f..94525ebe4 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
@@ -33,29 +33,23 @@ class NCCLWrapper {
   NCCLWrapper &operator=(const NCCLWrapper &) = delete;
   static NCCLWrapper &instance();
   ncclUniqueId nccl_unique_id() const;
-  void set_nccl_unique_id(ncclUniqueId unique_id);
-  void set_rank(int rank_id, int rank_size);
   void InitNCCLComm();
-  void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank);
   ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                              ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
-  void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm);
+  void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group);
   void DestroyGroup(const std::string &group_name);
 
  private:
-  NCCLWrapper() : rank_id_(-1), rank_size_(0) {}
+  NCCLWrapper() : comm_init_done_(false) {}
   ~NCCLWrapper() = default;
 
 private:
-  int rank_id_;
-  int rank_size_;
-  ncclUniqueId unique_id_;
-  ncclComm_t comm_;
-  std::map<std::string, ncclComm_t> group_to_comm_map_;
+  bool comm_init_done_;
+  std::map<std::string, NcclGroupInfo> group_info_;
 };
 } // namespace gpu
 } // namespace device
diff --git a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc
index 4605a0eb4..d34bc5530 100644
--- a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc
@@ -15,45 +15,24 @@
  */
 
 #include "runtime/device/gpu/mpi/mpi_initializer.h"
-
+#include
 #include
 #include
 #include
+#include
 
 namespace mindspore {
 namespace device {
 namespace gpu {
-MPIInitializer::MPIInitializer() {
-  int init_flag = 0;
-  if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    return;
-  }
-  if (init_flag == 0) {
-    auto ret = MPI_Init(nullptr, nullptr);
-    if (ret != MPI_SUCCESS) {
-      return;
-    }
-  }
-  MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
-  MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
-}
-
-MPIInitializer::~MPIInitializer() {
-  int finalized_flag = 0;
-  (void)MPI_Finalized(&finalized_flag);
-  if (finalized_flag == 0) {
-    (void)MPI_Finalize();
-  }
-}
 
 MPIInitializer &MPIInitializer::GetInstance() {
   static MPIInitializer instance;
   return instance;
 }
 
-int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }
+int MPIInitializer::get_rank_id(const std::string &group) { return GetRankIDByGroup(group); }
 
-int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
+int MPIInitializer::get_rank_size(const std::string &group) { return GetGroupSize(group); }
 
 PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
   mpi_initializer.doc() = "mindspore mpi python wrapper";
diff --git a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h
index fc4ad7468..20b0a4fba 100644
--- a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h
+++ b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h
@@ -17,6 +17,9 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
 
+#include
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
+
 namespace mindspore {
 namespace device {
 namespace gpu {
@@ -25,15 +28,12 @@ class MPIInitializer {
   MPIInitializer(MPIInitializer const &) = delete;
   MPIInitializer &operator=(const MPIInitializer &) = delete;
   static MPIInitializer &GetInstance();
-  static int get_rank_id();
-  static int get_rank_size();
+  static int get_rank_id(const std::string &group);
+  static int get_rank_size(const std::string &groups);
 
 private:
-  MPIInitializer();
-  ~MPIInitializer();
-
-  int rank_id_;
-  int rank_size_;
+  MPIInitializer() = default;
+  ~MPIInitializer() = default;
 };
 } // namespace gpu
 } // namespace device
diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py
index 1723fe9c9..920488cee 100644
--- a/mindspore/communication/_comm_helper.py
+++ b/mindspore/communication/_comm_helper.py
@@ -163,10 +163,7 @@ def _get_rank_helper(group, backend):
         else:
             rank_id = hccl.get_rank_id(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            rank_id = mpi.get_rank_id()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_id by user group now.")
+        rank_id = mpi.get_rank_id(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return rank_id
@@ -225,10 +222,7 @@ def _get_size_helper(group, backend):
         else:
             size = hccl.get_rank_size(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            size = mpi.get_rank_size()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_size by user group now.")
+        size = mpi.get_rank_size(group)
    else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return size
diff --git a/mindspore/ops/_op_impl/akg/gpu/equal.py b/mindspore/ops/_op_impl/akg/gpu/equal.py
index 40a3590f6..c63988f20 100644
--- a/mindspore/ops/_op_impl/akg/gpu/equal.py
+++ b/mindspore/ops/_op_impl/akg/gpu/equal.py
@@ -22,6 +22,7 @@ equal_op_info = AkgGpuRegOp("Equal") \
     .output(0, "output") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
     .get_op_info()
-- 
GitLab
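Reviewer note (illustration only, not part of the patch): a minimal Python sketch of what this change enables. It assumes a GPU/NCCL build with ENABLE_MPI and a job launched through mpirun; the sub-group name below is hypothetical.

# Sketch: group-aware rank/size queries on the NCCL backend after this patch.
from mindspore.communication.management import init, get_rank, get_group_size

init("nccl")
print(get_rank(), get_group_size())  # rank and size in the default nccl_world_group
# Before this patch, the NCCL backend raised RuntimeError for any group other than
# the world group; the helpers now forward the group name down to
# _ms_mpi.get_rank_id / get_rank_size, so a group created through the new
# CreateCommGroup path could be queried the same way, e.g.:
# print(get_rank("my_sub_group"), get_group_size("my_sub_group"))  # hypothetical group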