diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 53300acda4ab69b74095bfcd27a35257f0fe5edd..16f038698921fa8b4f5a4fd18b9bafcf40b72e1f 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -57,6 +57,7 @@ if(ENABLE_GPU)
     set_property(SOURCE ${GPU_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     cuda_add_library(gpu_cuda_lib STATIC ${GPU_SRC_LIST})
     set(CMAKE_CXX_FLAGS ${NVCC_TMP_CMAKE_CXX_FLAGS})
+    add_compile_definitions(ENABLE_GPU)
 endif ()
 
 ## make flatbuffer files
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
index 4c3c3189fbe85209b613a3fa4d37cb3515e563d0..9701738bfc7fb232fee7fc616216ad201c4df7b3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
@@ -40,9 +40,11 @@ const std::map<std::string, NcclKernelType> kNcclTypeMap = {
 static std::map<std::string, ncclDataType_t> kNcclDtypeMap = {
   {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}};
 
-typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t);
-typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t);
-typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t);
+typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t,
+                                  const std::string &);
+typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t, const std::string &);
+typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t,
+                                      const std::string &);
 
 template <typename T>
 class NcclGpuKernel : public GpuKernel {
@@ -50,6 +52,7 @@ class NcclGpuKernel : public GpuKernel {
   NcclGpuKernel()
       : nccl_kernel_type_(NCCL_INVALID_TYPE),
         nccl_reduce_type_(ncclSum),
+        group_name_(""),
         input_size_(0),
         output_size_(0),
         collective_handle_(nullptr),
@@ -71,7 +74,7 @@ class NcclGpuKernel : public GpuKernel {
           reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
         MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
         CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
-                                                         nccl_data_type_, nccl_reduce_type_, stream),
+                                                         nccl_data_type_, nccl_reduce_type_, stream, group_name_),
                                    "ncclAllReduce failed");
         break;
       }
@@ -80,7 +83,7 @@ class NcclGpuKernel : public GpuKernel {
           reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
         MS_EXCEPTION_IF_NULL(all_gather_funcptr);
         CHECK_NCCL_RET_WITH_EXCEPT(
-          (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream),
+          (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_),
           "ncclAllGather failed");
         break;
       }
@@ -89,7 +92,7 @@ class NcclGpuKernel : public GpuKernel {
           reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
         MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
         CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
-                                                             nccl_data_type_, nccl_reduce_type_, stream),
+                                                             nccl_data_type_, nccl_reduce_type_, stream, group_name_),
                                    "ncclReduceScatter failed");
         break;
       }
@@ -121,15 +124,18 @@ class NcclGpuKernel : public GpuKernel {
       output_size_list_.push_back(size);
       output_size_ += size;
     }
-    InferCommType(kernel_node);
-    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
-    MS_EXCEPTION_IF_NULL(collective_handle_);
 
+    InferCommType(kernel_node);
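+    // The group attribute selects which NCCL communicator this kernel's collective runs on.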
+    group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
+    MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
     auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id");
     if (comm_stream_attr) {
       comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
       MS_EXCEPTION_IF_NULL(comm_stream_);
     }
+
+    collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+    MS_EXCEPTION_IF_NULL(collective_handle_);
     return true;
   }
 
@@ -146,7 +152,7 @@ class NcclGpuKernel : public GpuKernel {
       nccl_kernel_type_ = iter->second;
     }
 
-    auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
+    auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrOp);
     if (reduce_op) {
       std::string type = GetValue<std::string>(reduce_op);
       if (type == "sum") {
@@ -167,6 +173,7 @@ class NcclGpuKernel : public GpuKernel {
   NcclKernelType nccl_kernel_type_;
   ncclRedOp_t nccl_reduce_type_;
   ncclDataType_t nccl_data_type_;
+  std::string group_name_;
   size_t input_size_;
   size_t output_size_;
   std::vector<size_t> input_size_list_;
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h
index f9564a0c747fc653e7f771351dae332551f1e5e8..5373f21d70ca2990baa20ff95992886d76149564 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h
@@ -23,16 +23,17 @@
 namespace mindspore {
 namespace device {
 namespace gpu {
-#define MAX_HOSTNAME_LEN 1024
-#define CHECK_RET(expression, result, message)                                                                        \
-  {                                                                                                                   \
-    auto ret = (expression);                                                                                          \
-    if (ret != result) {                                                                                              \
-      std::ostringstream oss;                                                                                         \
-      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << " | GPU collective Error " << message \
-          << " | Error Number " << ret;                                                                               \
-      pybind11::pybind11_fail(oss.str());                                                                             \
-    }                                                                                                                 \
+constexpr int MAX_HOSTNAME_LEN = 1024;
+constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+#define CHECK_RET(expression, result, message)                                                                         \
+  {                                                                                                                    \
+    auto ret = (expression);                                                                                           \
+    if (ret != result) {                                                                                               \
+      std::ostringstream oss;                                                                                          \
+      oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << " | GPU collective Error: " << message \
+          << " | Error Number " << ret;                                                                                \
+      pybind11::pybind11_fail(oss.str());                                                                              \
+    }                                                                                                                  \
   }
 }  // namespace gpu
 }  // namespace device
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h
index 424abcf47008136eb3f2f92b59ccfc7cb6efac6a..464492d50f8faf28d7c8e3772c93b207b311d5d6 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h
@@ -18,6 +18,8 @@
 #define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_
 
 #include <dlfcn.h>
+#include <vector>
+#include <string>
 
 namespace mindspore {
 namespace device {
@@ -25,6 +27,10 @@ namespace gpu {
 using InitMPI = void (*)();
 using InitNCCLComm = void (*)();
 using GetLocalRankId = int (*)();
+using CreateCommGroupFunc = bool (*)(const std::string &, const std::vector<unsigned int> &);
+using GetRankIDByGroupFunc = int (*)(const std::string &);
+using GetGroupSizeFunc = int (*)(const std::string &);
+using DestroyGroupFunc = bool (*)(const std::string &);
 
 class CollectiveInitializer {
  public:
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
index 927c93cfafdd3407907c99557ff163c7f97ecd64..f427905afa17b1323067292f907ad9c08adc4ea9 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc
@@ -20,6 +20,7 @@
 #include <memory>
 #include <string>
 #include <iostream>
+#include <vector>
 #include "runtime/device/gpu/distribution/mpi_wrapper.h"
 #include "runtime/device/gpu/distribution/nccl_wrapper.h"
 
@@ -36,6 +37,22 @@ extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().lo
 
 extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
 
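+// The following group management symbols are resolved at runtime via dlsym.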
+extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
+  return MPIWrapper::instance().CreateCommGroup(group_name, ranks);
+}
+
+extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) {
+  return MPIWrapper::instance().GetRankIDByGroup(group_name);
+}
+
+extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) {
+  return MPIWrapper::instance().GetGroupSize(group_name);
+}
+
+extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) {
+  return MPIWrapper::instance().DestroyGroup(group_name);
+}
+
 extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
                                                  ncclDataType_t data_type, ncclRedOp_t reduce_type,
                                                  cudaStream_t stream) {
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc
index ed768fbbe5d1d37584fa47bd8c0df3b3c37984d7..08ec320cab8726a71201c41b305ad7b08bc44ce6 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc
@@ -15,9 +15,9 @@
  */
 
 #include "runtime/device/gpu/distribution/mpi_wrapper.h"
-
 #include <cuda_runtime_api.h>
 #include <string>
+#include <vector>
 #include "runtime/device/gpu/distribution/nccl_wrapper.h"
 
 namespace mindspore {
@@ -40,17 +40,82 @@ MPIWrapper &MPIWrapper::instance() {
 
 int MPIWrapper::local_rank_id() const { return local_rank_id_; }
 
+bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &group_ranks) {
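+  // Build an MPI group containing only the requested world ranks.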
+  std::vector<int> ranks(group_ranks.begin(), group_ranks.end());
+  MPI_Group mpi_group;
+  CHECK_RET(MPI_Group_incl(world_group_, ranks.size(), ranks.data(), &mpi_group), MPI_SUCCESS,
+            "Failed to produce a new group from MPI_COMM_WORLD group for " + group_name);
+  SetGroupNameToMPIGroup(group_name, mpi_group);
+
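+  // Create a communicator for the new group; processes outside the group receive MPI_COMM_NULL.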
+  MPI_Comm mpi_group_comm;
+  CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, mpi_group, &mpi_group_comm), MPI_SUCCESS,
+            "Failed to create MPI communicator.");
+  if (mpi_group_comm == MPI_COMM_NULL) {
+    return false;
+  }
+
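+  // The first rank in the group obtains the NCCL unique id and broadcasts it to the other members;
+  // within the new communicator that rank is root 0.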
+  ncclUniqueId group_unique_id;
+  if (rank_id_ == ranks[0]) {
+    group_unique_id = NCCLWrapper::instance().nccl_unique_id();
+  }
+  CHECK_RET(MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, 0, mpi_group_comm), MPI_SUCCESS,
+            "Failed to broadcast NCCL unique id for " + group_name);
+
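+  // Translate this process's world rank into its rank inside the new group.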
+  int group_rank[1];
+  int global_rank[1] = {rank_id_};
+  CHECK_RET(MPI_Group_translate_ranks(world_group_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS,
+            "Failed to translate global rank to group rank.");
+  if (group_rank[0] == MPI_UNDEFINED) {
+    return false;
+  }
+
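+  // Create a dedicated NCCL communicator for the group and register it under the group name.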
+  ncclComm_t nccl_group_comm;
+  NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]);
+  NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm);
+  return true;
+}
+
+int MPIWrapper::GetRankIDByGroup(const std::string &group_name) {
+  CHECK_RET(group_name_to_mpi_group_map_.count(group_name), 1, "Failed to get MPI group by group name " + group_name);
+  MPI_Group mpi_group = group_name_to_mpi_group_map_[group_name];
+  int rank;
+  CHECK_RET(MPI_Group_rank(mpi_group, &rank), MPI_SUCCESS, "Failed to get rank id by group name " + group_name);
+  return rank;
+}
+
+int MPIWrapper::GetGroupSize(const std::string &group_name) {
+  CHECK_RET(group_name_to_mpi_group_map_.count(group_name), 1, "Failed to get MPI group by group name " + group_name);
+  MPI_Group mpi_group = group_name_to_mpi_group_map_[group_name];
+  int size;
+  CHECK_RET(MPI_Group_size(mpi_group, &size), MPI_SUCCESS, "Failed to get group size by group name " + group_name);
+  return size;
+}
+
+bool MPIWrapper::DestroyGroup(const std::string &group_name) {
+  auto group_iter = group_name_to_mpi_group_map_.find(group_name);
+  if (group_iter == group_name_to_mpi_group_map_.end()) {
+    return false;
+  }
+  MPI_Group mpi_group = group_iter->second;
+  group_name_to_mpi_group_map_.erase(group_iter);
+  CHECK_RET(MPI_Group_free(&mpi_group), MPI_SUCCESS, "Failed to free MPI group for " + group_name);
+  NCCLWrapper::instance().DestroyGroup(group_name);
+  return true;
+}
+
 void MPIWrapper::Init() {
   int initialized;
   CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check mpi initialization status.");
-
   if (initialized == 0) {
     MPI_Init(nullptr, nullptr);
   }
+
   CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id.");
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size.");
   NCCLWrapper::instance().set_rank(rank_id_, rank_size_);
-  AssignLocalRankId();
+  AssignLocalRankID();
+
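+  // Keep the group of MPI_COMM_WORLD so sub-groups can later be carved out of it.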
+  CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD.");
+  SetGroupNameToMPIGroup(NCCL_WORLD_GROUP, world_group_);
 
   ncclUniqueId unique_id;
   if (rank_id_ == 0) {
@@ -62,7 +127,7 @@ void MPIWrapper::Init() {
   return;
 }
 
-void MPIWrapper::AssignLocalRankId() {
+void MPIWrapper::AssignLocalRankID() {
   char host_name[MAX_HOSTNAME_LEN] = {0};
   CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), 0, "Getting host name failed.");
   size_t host_hash = std::hash<std::string>()(host_name);
@@ -82,6 +147,10 @@ void MPIWrapper::AssignLocalRankId() {
   }
   return;
 }
+
+void MPIWrapper::SetGroupNameToMPIGroup(const std::string &group_name, const MPI_Group mpi_group) {
+  group_name_to_mpi_group_map_[group_name] = mpi_group;
+}
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h
index 3d54b376cf94f7f0163d518b00579160e1c90959..19d06b32d321a240ebf6238ea1b6515dfd630e92 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h
@@ -22,6 +22,9 @@
 #include <unistd.h>
 #include <mpi.h>
 #include <iostream>
+#include <map>
+#include <string>
+#include <vector>
 #include "runtime/device/gpu/distribution/collective_common.h"
 
 namespace mindspore {
@@ -33,16 +36,23 @@ class MPIWrapper {
   MPIWrapper &operator=(const MPIWrapper &) = delete;
   static MPIWrapper &instance();
   int local_rank_id() const;
+  bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks);
+  int GetRankIDByGroup(const std::string &group_name);
+  int GetGroupSize(const std::string &group_name);
+  bool DestroyGroup(const std::string &group_name);
 
  private:
   MPIWrapper();
   ~MPIWrapper();
   void Init();
-  void AssignLocalRankId();
+  void AssignLocalRankID();
+  void SetGroupNameToMPIGroup(const std::string &group_name, const MPI_Group mpi_group);
 
   int rank_id_;
   int rank_size_;
   int local_rank_id_;
+  MPI_Group world_group_;
+  std::map<std::string, MPI_Group> group_name_to_mpi_group_map_;
 };
 }  // namespace gpu
 }  // namespace device
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
index adf0b2f6fb70fd8d8b3eec64773fc94d88e9449a..bcba5383094d784e7f14739e7f30090084f1b401 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc
@@ -40,21 +40,51 @@ void NCCLWrapper::set_rank(int rank_id, int rank_size) {
 void NCCLWrapper::InitNCCLComm() {
   CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
             "Failed to init nccl communicator.");
+  group_to_comm_map_[NCCL_WORLD_GROUP] = comm_;
+}
+
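+// Overload used to initialize a communicator for a user-defined group rather than the default world group.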
+void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) {
+  CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator.");
 }
 
 ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
-                                    ncclRedOp_t reduce_type, cudaStream_t stream) {
-  return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, comm_, stream);
+                                    ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) {
+  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+            "Failed to find NCCL communicator for AllReduce by the group name " + group_name);
+  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
 
 ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
-                                    cudaStream_t stream) {
-  return ncclAllGather(input_addr, output_addr, count, data_type, comm_, stream);
+                                    cudaStream_t stream, const std::string &group_name) {
+  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+            "Failed to find NCCL communicator for AllGather by the group name " + group_name);
+  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream);
 }
 
 ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                        ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream) {
-  return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, comm_, stream);
+                                        ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
+                                        const std::string &group_name) {
+  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+            "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name);
+  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
+}
+
+void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) {
+  group_to_comm_map_[group_name] = comm;
+}
+
+void NCCLWrapper::DestroyGroup(const std::string &group_name) {
+  auto group_iter = group_to_comm_map_.find(group_name);
+  if (group_iter == group_to_comm_map_.end()) {
+    return;
+  }
+  ncclComm_t group_comm = group_iter->second;
+  group_to_comm_map_.erase(group_iter);
+  CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name);
+  return;
 }
 }  // namespace gpu
 }  // namespace device
diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
index fb09efc0858acb83a38c803d5d5af66ea68726b9..9cea338c413f726f570b8c3e9b548b272f935561 100644
--- a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
+++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h
@@ -20,6 +20,8 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <nccl.h>
+#include <string>
+#include <map>
 #include "runtime/device/gpu/distribution/collective_common.h"
 
 namespace mindspore {
@@ -34,12 +36,15 @@ class NCCLWrapper {
   void set_nccl_unique_id(ncclUniqueId unique_id);
   void set_rank(int rank_id, int rank_size);
   void InitNCCLComm();
+  void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank);
   ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
-                         ncclRedOp_t op, cudaStream_t stream);
+                         ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
-                         cudaStream_t stream);
+                         cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
-                             ncclRedOp_t op, cudaStream_t stream);
+                             ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
+  void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm);
+  void DestroyGroup(const std::string &group_name);
 
  private:
   NCCLWrapper() : rank_id_(-1), rank_size_(0) {}
@@ -50,6 +55,7 @@ class NCCLWrapper {
   int rank_size_;
   ncclUniqueId unique_id_;
   ncclComm_t comm_;
+  std::map<std::string, ncclComm_t> group_to_comm_map_;
 };
 }  // namespace gpu
 }  // namespace device
diff --git a/mindspore/ccsrc/utils/comm_manager.cc b/mindspore/ccsrc/utils/comm_manager.cc
index 70adfb7467f37188d27020b72092dd18fdb57d15..de165c4aac681626ebfcdc304ac73dc0a0945e55 100644
--- a/mindspore/ccsrc/utils/comm_manager.cc
+++ b/mindspore/ccsrc/utils/comm_manager.cc
@@ -16,17 +16,27 @@
 
 #include "utils/comm_manager.h"
 #include "utils/convert_utils.h"
+
 #ifndef NO_DLIB
 #include "hccl/hcom.h"
 #endif
 
+#if defined(ENABLE_GPU)
+#include "runtime/device/gpu/distribution/collective_init.h"
+using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer;
+using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc;
+using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc;
+using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc;
+using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc;
+#endif
+
 namespace mindspore {
+#ifndef NO_DLIB
 CommManager &CommManager::GetInstance() noexcept {
   static CommManager instance("hccl");
   return instance;
 }
 
-#ifndef NO_DLIB
 #define HCCL_RUN_CHECK(op_name, group, op)                      \
   do {                                                          \
     auto hccl_result = (op);                                    \
@@ -79,7 +89,79 @@ bool CommManager::DestroyGroup(const string &group) const {
   HCCL_RUN_CHECK(string("destroy communicate group"), group, hcom_destroy_group(group.c_str()));
   return true;
 }
+#elif defined(ENABLE_GPU)
+CommManager &CommManager::GetInstance() noexcept {
+  static CommManager instance("nccl");
+  return instance;
+}
+
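+// On GPU, group operations are dispatched through symbols exported by the collective library and loaded with dlsym.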
+bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
+  auto create_comm_group_funcptr =
+    reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
+  MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
+  bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
+  if (!ret) {
+    MS_LOG(ERROR) << "Creating group " << group << " for rank id list " << rank_id_list << " failed.";
+    return ret;
+  }
+  return ret;
+}
+
+bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  auto get_rank_id_funcptr =
+    reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
+  MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
+  int rank = (*get_rank_id_funcptr)(group);
+  *rank_id = static_cast<unsigned int>(rank);
+  MS_LOG(INFO) << "This process's rank id is " << *rank_id << " in group " << group;
+  return true;
+}
+
+bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  auto get_group_size_funcptr =
+    reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
+  MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
+  int size = (*get_group_size_funcptr)(group);
+  *rank_size = static_cast<unsigned int>(size);
+  MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
+  return true;
+}
+
+bool CommManager::DestroyGroup(const string &group) const {
+  const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+  if (!collective_handle_) {
+    MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
+  }
+  auto destroy_group_funcptr =
+    reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
+  MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
+
+  bool ret = (*destroy_group_funcptr)(group);
+  if (!ret) {
+    MS_LOG(ERROR) << "Destroying group " << group << " failed.";
+    return ret;
+  }
+  return ret;
+}
 #else
+CommManager &CommManager::GetInstance() noexcept {
+  static CommManager instance("hccl");
+  return instance;
+}
+
 bool CommManager::CreateGroupSync(const string &, const vector<unsigned int> &) const { return true; }
 
 bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; }