From d28f6f7b822c2e98a8b910e66d2d5bfd03152082 Mon Sep 17 00:00:00 2001
From: maxhuiy <1508399706@qq.com>
Date: Sun, 30 Jan 2022 10:47:17 +0800
Subject: [PATCH] feat(cncl_mlu): add cncl dev for mlu distributed backend (#39294)

---
 CMakeLists.txt                                |   8 +
 cmake/neuware.cmake                           |   9 +-
 paddle/fluid/platform/CMakeLists.txt          |   4 +
 paddle/fluid/platform/collective_helper.h     | 100 ++++++++++
 .../fluid/platform/device/mlu/CMakeLists.txt  |   1 +
 .../fluid/platform/device/mlu/cncl_helper.h   |  57 ++++++
 .../platform/device/mlu/device_context.h      |  15 ++
 paddle/fluid/platform/device/mlu/enforce.h    |  14 ++
 .../fluid/platform/device/mlu/enforce_test.cc |  10 +
 .../device/mlu/mlu_collective_helper.cc       | 179 ++++++++++++++++++
 paddle/fluid/platform/device/mlu/mlu_info.h   |   6 +
 11 files changed, 402 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/platform/device/mlu/cncl_helper.h
 create mode 100644 paddle/fluid/platform/device/mlu/mlu_collective_helper.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 549ed9d8543..cd131e2d708 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -230,6 +230,7 @@ option(WITH_INFRT "Compile PaddlePaddle with INFRT" OFF)
 option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
 option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
 option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
+option(WITH_CNCL "Compile PaddlePaddle with CNCL support" OFF)
 option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON)
 option(WITH_ARM "Compile PaddlePaddle with arm support" OFF)
 option(WITH_SW "Compile PaddlePaddle with sw support" OFF)
@@ -292,6 +293,13 @@ if (NOT WITH_XPU AND WITH_XPU_BKCL)
         "Disable BKCL when compiling without XPU" FORCE)
 endif()
 
+if (NOT WITH_MLU AND WITH_CNCL)  # CNCL is only usable with the MLU backend
+    MESSAGE(WARNING
+        "Disable CNCL when compiling without MLU. Force WITH_CNCL=OFF.")
+    set(WITH_CNCL OFF CACHE STRING
+        "Disable CNCL when compiling without MLU" FORCE)
+endif()
+
 if(WITH_NCCL)
     add_definitions("-DPADDLE_WITH_NCCL")
     include(nccl)
diff --git a/cmake/neuware.cmake b/cmake/neuware.cmake
index 7219f5f7259..811c8d664a0 100644
--- a/cmake/neuware.cmake
+++ b/cmake/neuware.cmake
@@ -19,4 +19,11 @@ set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
 set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so)
 
 generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake")
-TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
+if(WITH_CNCL)
+  MESSAGE(STATUS "Compile with CNCL!")
+  ADD_DEFINITIONS(-DPADDLE_WITH_CNCL)
+  set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so)
+  TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
+else()
+  TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
+endif()
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 7f54903e69c..e35b586dc90 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -133,6 +133,10 @@ if(WITH_ASCEND_CL)
   target_link_libraries(collective_helper npu_collective_helper)
 endif()
 
+if(WITH_CNCL)
+  target_link_libraries(collective_helper mlu_collective_helper)
+endif()
+
 if(WITH_GPU OR WITH_ROCM)
   target_link_libraries(device_context gpu_resource_pool)
 endif()
diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h
index 62a07669259..2c0067bb152 100644
--- a/paddle/fluid/platform/collective_helper.h
+++ b/paddle/fluid/platform/collective_helper.h
@@ -24,6 +24,9 @@
 #include "paddle/fluid/platform/device/npu/dynload/hccl.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
+#if defined(PADDLE_WITH_CNCL)
+#include "paddle/fluid/platform/device/mlu/device_context.h"
+#endif
 
 namespace paddle {
 namespace platform {
@@ -333,5 +336,102 @@ class BKCLCommContext {
 };
#endif
 
+#if defined(PADDLE_WITH_CNCL)
+// In order to apply hierarchical communication with CNCL, we need
+// a communication ring that contains CNCL communicators associated with a
+// global cnclCliqueId. E.g. for a hierarchical case,
+//
+//    11 - 12   21 - 22
+//     |    |    |    |
+//    13 - 14 - 23 - 24
+//          |    |
+//    31 - 32 - 41 - 42
+//     |    |    |    |
+//    33 - 34   43 - 44
+//
+// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
+// (31,32,33,34), (41,42,43,44) as bottoms respectively.
+//
+// We could also use a single communication ring for the flat case.
+//
+// The CNCLComm instances are created and reserved in the CNCLCommContext
+// singleton, keyed by a user-specified global ring id.
+class MLUDeviceContext;
+
+class CNCLComm {
+ public:
+  virtual int ring_id() const = 0;
+  virtual int nranks() const = 0;
+  virtual int rank() const = 0;
+  virtual int device_id() const = 0;
+  virtual cnclComm_t comm() const = 0;
+  virtual mluStream stream() const = 0;
+  virtual MLUDeviceContext* dev_context() const = 0;
+  virtual ~CNCLComm() = default;
+};
+
+// A singleton context that keeps the CNCL communicators of every ring id
+class CNCLCommContext {
+ public:
+  static CNCLCommContext& Instance() {
+    static CNCLCommContext comm_ctx;
+    return comm_ctx;
+  }
+
+  CNCLComm* CreateComm(cnclCliqueId* cncl_id, int nranks, int rank, int dev_id,
+                       int ring_id = 0);
+  void CreateAllCNCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
+
+  // a later comm with the same dev_id and the same ring_id
+  // will override the former
+  CNCLComm* AssignCNCLComm(cnclComm_t comm, int nranks, int rank, int dev_id,
+                           int ring_id = 0);
+
+  // retrieve a communicator by the ring id in multiprocessing mode
+  CNCLComm* Get(int ring_id) const {
+    PADDLE_ENFORCE_GT(
+        comm_map_.count(ring_id), 0,
+        platform::errors::InvalidArgument(
+            "Communicator in ring id %d has not been initialized.", ring_id));
+    PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1,
+                      platform::errors::InvalidArgument(
+                          "One device id should be specified to retrieve from "
+                          "multiple communicators."));
+    return comm_map_.at(ring_id).begin()->second.get();
+  }
+
+  // retrieve a communicator by the ring id and the device id
+  CNCLComm* Get(int ring_id, int dev_id) const {
+    PADDLE_ENFORCE_GT(
+        comm_map_.count(ring_id), 0,
+        platform::errors::InvalidArgument(
+            "Communicator of ring id %d has not been initialized.", ring_id));
+    PADDLE_ENFORCE_GT(
+        comm_map_.at(ring_id).count(dev_id), 0,
+        platform::errors::InvalidArgument(
+            "Communicator at device id %d has not been initialized in ring %d.",
+            dev_id, ring_id));
+    return comm_map_.at(ring_id).at(dev_id).get();
+  }
+
+  // retrieve a communicator by the ring id and place
+  CNCLComm* Get(int ring_id, Place place) const {
+    return Get(ring_id, place.device);
+  }
+
+ private:
+  std::once_flag once_flag_;
+  std::mutex comm_map_mutex_;
+  // ring id -> (device id -> owned CNCLComm)
+  std::map<int, std::map<int, std::unique_ptr<CNCLComm>>> comm_map_;
+
+  void ReleaseCNCLComms();
+
+  CNCLCommContext() = default;
+  DISABLE_COPY_AND_ASSIGN(CNCLCommContext);
+};
+
+#endif
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device/mlu/CMakeLists.txt b/paddle/fluid/platform/device/mlu/CMakeLists.txt
index e8b794a03e3..724776bfad2 100644
--- a/paddle/fluid/platform/device/mlu/CMakeLists.txt
+++ b/paddle/fluid/platform/device/mlu/CMakeLists.txt
@@ -8,3 +8,4 @@ cc_library(mlu_info SRCS mlu_info.cc DEPS enforce glog monitor neuware_lib)
 cc_library(mlu_stream SRCS mlu_stream.cc DEPS boost mlu_info stream_callback_manager eigen3 ${MKLDNN_CTX_DEPS})
 cc_library(mlu_device_context SRCS device_context.cc DEPS mlu_stream)
 cc_test(mlu_device_context_test SRCS device_context_test.cc DEPS mlu_device_context)
+cc_library(mlu_collective_helper SRCS mlu_collective_helper.cc DEPS mlu_stream mlu_info)
diff --git a/paddle/fluid/platform/device/mlu/cncl_helper.h b/paddle/fluid/platform/device/mlu/cncl_helper.h
new file mode 100644
index 00000000000..2f9bed01426
--- /dev/null
+++ 
b/paddle/fluid/platform/device/mlu/cncl_helper.h
@@ -0,0 +1,57 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#ifdef PADDLE_WITH_CNCL
+#include <cncl.h>
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+#include <thread>  // NOLINT
+#include <typeindex>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/mlu/enforce.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace platform {
+
+// Converts a framework variable type into the corresponding CNCL data type.
+// Supported inputs are FP32, FP16, INT32, INT16, INT8 and UINT8; every other
+// type is rejected with an Unimplemented error raised through PADDLE_THROW.
+inline cnclDataType_t ToCNCLDataType(framework::proto::VarType::Type type) {
+  // clang-format off
+  switch (type) {
+    case framework::proto::VarType::FP32: return cnclFloat32;
+    case framework::proto::VarType::FP16: return cnclFloat16;
+    case framework::proto::VarType::INT32: return cnclInt32;
+    case framework::proto::VarType::INT16: return cnclInt16;
+    case framework::proto::VarType::INT8: return cnclInt8;
+    case framework::proto::VarType::UINT8: return cnclUint8;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "This datatype in cncl is not supported."));
+  }
+  // clang-format on
+}
+
+}  // namespace platform
+}  // namespace paddle
+#endif
diff --git a/paddle/fluid/platform/device/mlu/device_context.h b/paddle/fluid/platform/device/mlu/device_context.h
index 2692f3a248a..a3f3bda17c8 100644
--- a/paddle/fluid/platform/device/mlu/device_context.h
+++ 
b/paddle/fluid/platform/device/mlu/device_context.h
@@ -15,6 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/mlu/enforce.h"
 #include "paddle/fluid/platform/device/mlu/mlu_stream.h"
 #include "paddle/fluid/platform/device_context.h"
+#ifdef PADDLE_WITH_CNCL
+#include <cncl.h>
+#endif
 
 namespace Eigen {
 struct DefaultDevice;
@@ -88,6 +91,14 @@ class MLUDeviceContext : public DeviceContext {
   /*! \brief Return mlu stream in the device context. */
   mluStream stream() const;
 
+#ifdef PADDLE_WITH_CNCL
+  /*! \brief Return the cncl communicator stored in this context. */
+  cnclComm_t cncl_comm() const { return cncl_comm_; }
+
+  /*! \brief Store the cncl communicator to use with this context. */
+  void set_cncl_comm(cnclComm_t comm) { cncl_comm_ = comm; }
+#endif
+
   template <typename Callback>
   void RecordEvent(mluEventHandle ev, Callback callback) const {
     return context()->Stream()->RecordEvent(ev, callback);
@@ -132,6 +143,10 @@ class MLUDeviceContext : public DeviceContext {
       thread_ctx_;
   static thread_local std::mutex ctx_mtx_;
 
+#ifdef PADDLE_WITH_CNCL
+  cnclComm_t cncl_comm_{nullptr};  // null until set_cncl_comm() is called
+#endif
+
   DISABLE_COPY_AND_ASSIGN(MLUDeviceContext);
 };
 
diff --git a/paddle/fluid/platform/device/mlu/enforce.h b/paddle/fluid/platform/device/mlu/enforce.h
index eecbad53cab..5c9871d7bce 100644
--- a/paddle/fluid/platform/device/mlu/enforce.h
+++ b/paddle/fluid/platform/device/mlu/enforce.h
@@ -42,6 +42,9 @@ struct MLUStatusType {};
 DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT);
 DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL);
 DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN);
+#ifdef PADDLE_WITH_CNCL
+DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL);
+#endif
 
 }  // namespace details
 
@@ -80,6 +83,17 @@ inline std::string build_mlu_error_msg(cnStatus stat) {
   return sout.str();
 }
 
+/*************** CNCL ERROR ***************/
+#ifdef PADDLE_WITH_CNCL
+inline bool is_error(cnclStatus e) { return e != CNCL_RET_SUCCESS; }
+
+inline std::string build_mlu_error_msg(cnclStatus e) {
+  std::ostringstream sout;
+  sout << "MLU CNCL error(" << e << "), " << cnclGetErrorStr(e) << ". ";
+  return sout.str();
+}
+#endif
+
 #define PADDLE_ENFORCE_MLU_SUCCESS(COND)                       \
   do {                                                         \
     auto __cond__ = (COND);                                    \
diff --git a/paddle/fluid/platform/device/mlu/enforce_test.cc b/paddle/fluid/platform/device/mlu/enforce_test.cc
index 7241afba6aa..4ff7b12c446 100644
--- a/paddle/fluid/platform/device/mlu/enforce_test.cc
+++ b/paddle/fluid/platform/device/mlu/enforce_test.cc
@@ -58,5 +58,15 @@ TEST(mlu_enforce, mlu_success) {
       CheckMluStatusFailure(CN_ERROR_INVALID_VALUE, "invalid argument"));
   EXPECT_TRUE(CheckMluStatusFailure(CN_MEMORY_ERROR_OUT_OF_MEMORY,
                                     "device has no memory to alloc"));
+#ifdef PADDLE_WITH_CNCL
+  EXPECT_TRUE(CheckMluStatusSuccess(CNCL_RET_SUCCESS));
+  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_INTERNAL, "CNCL error"));
+  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_NULL_POINTER, "CNCL error"));
+  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_INIT, "CNCL error"));
+  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_NOT_INIT, "CNCL error"));
+  EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_REINIT, "CNCL error"));
+  EXPECT_TRUE(
+      CheckMluStatusFailure(CNCL_RET_ERR_INVALID_VERSION, "CNCL error"));
+#endif
 }
 #endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_collective_helper.cc b/paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
new file mode 100644
index 00000000000..7708267c1bc
--- /dev/null
+++ b/paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
@@ -0,0 +1,179 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#if defined(PADDLE_WITH_CNCL)
+#include <utility>
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/mlu/enforce.h"
+
+namespace paddle {
+namespace platform {
+
+// Concrete CNCLComm that owns one CNCL communicator together with the MLU
+// device context it uses. The communicator is freed on destruction.
+class CNCLCommImpl : public CNCLComm {
+ public:
+  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
+  int ring_id() const override { return ring_id_; }
+
+  void set_nranks(int nranks) { nranks_ = nranks; }
+  int nranks() const override { return nranks_; }
+
+  void set_rank(int rank) { rank_ = rank; }
+  int rank() const override { return rank_; }
+
+  int device_id() const override { return dev_ctx_->GetPlace().device; }
+
+  void set_comm(cnclComm_t comm) { comm_ = comm; }
+  cnclComm_t comm() const override { return comm_; }
+
+  mluStream stream() const override { return dev_ctx_->stream(); }
+
+  void set_dev_ctx(std::unique_ptr<MLUDeviceContext>&& dev_ctx) {
+    dev_ctx_ = std::move(dev_ctx);
+  }
+  MLUDeviceContext* dev_context() const override { return dev_ctx_.get(); }
+
+  ~CNCLCommImpl() override {
+    if (comm_) {
+      PADDLE_ENFORCE_MLU_SUCCESS(cnclFreeComm(comm_));
+    }
+  }
+
+ private:
+  // Defaults keep the destructor from reading an indeterminate handle when
+  // set_comm() was never called.
+  int ring_id_{0};
+  int nranks_{0};
+  int rank_{0};
+  cnclComm_t comm_{nullptr};
+  std::unique_ptr<MLUDeviceContext> dev_ctx_;
+};
+
+// Creates the communicator of `rank` (out of `nranks`) on MLU `dev_id` from
+// the given clique id and registers it under `ring_id`.
+// Returns the registered wrapper, which stays owned by this context.
+CNCLComm* CNCLCommContext::CreateComm(cnclCliqueId* cncl_id, int nranks,
+                                      int rank, int dev_id, int ring_id) {
+  PADDLE_ENFORCE_NOT_NULL(cncl_id,
+                          platform::errors::InvalidArgument(
+                              "The cncl unique id should not be null."));
+  PADDLE_ENFORCE_GT(
+      nranks, 1,
+      platform::errors::InvalidArgument(
+          "Expected nranks > 1. But received nranks is %d.", nranks));
+  PADDLE_ENFORCE_GE(rank, 0,
+                    platform::errors::InvalidArgument(
+                        "Expected rank >= 0. But received rank is %d.", rank));
+  PADDLE_ENFORCE_LT(
+      rank, nranks,
+      platform::errors::InvalidArgument(
+          "Expected rank < nranks. But received rank is %d, nranks is %d.",
+          rank, nranks));
+  PADDLE_ENFORCE_GE(
+      dev_id, 0,
+      platform::errors::InvalidArgument(
+          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
+
+  cnclComm_t comm = nullptr;
+  int dev_list[] = {dev_id};
+  int rank_list[] = {rank};
+  SetMLUDeviceId(dev_id);
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnclInitComms(&comm, 1, dev_list, rank_list, nranks, cncl_id));
+
+  auto* comm_wrapper = AssignCNCLComm(comm, nranks, rank, dev_id, ring_id);
+
+  VLOG(1) << "cncl communicator of rank " << rank << " in ring " << ring_id
+          << " has been created on device " << dev_id;
+
+  // Register the process-exit cleanup exactly once.
+  std::call_once(once_flag_, []() {
+    std::atexit([]() { CNCLCommContext::Instance().ReleaseCNCLComms(); });
+  });
+
+  return comm_wrapper;
+}
+
+// Creates one communicator per entry of `dev_ids` (rank i runs on dev_ids[i])
+// and registers them under `ring_id`, which must not be in use yet.
+void CNCLCommContext::CreateAllCNCLComms(const std::vector<int>& dev_ids,
+                                         int ring_id) {
+  PADDLE_ENFORCE_GT(
+      dev_ids.size(), 0,
+      platform::errors::InvalidArgument("Expected the size of dev_ids > 0. But "
+                                        "received the size of dev_ids is %d.",
+                                        dev_ids.size()));
+  // Reject a ring id that is already in use before any communicator is
+  // created, so a failed check cannot leak freshly initialized communicators.
+  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0,
+                    platform::errors::InvalidArgument(
+                        "Expected comm_map_.count(ring_id) = 0. But received "
+                        "comm_map_.count(ring_id) is %d.",
+                        comm_map_.count(ring_id)));
+
+  const int kDevices = static_cast<int>(dev_ids.size());
+  // std::vector replaces the variable-length array and the manual new[], so
+  // both buffers are released even when an enforce below throws.
+  std::vector<cnclComm_t> comms(kDevices);
+  std::vector<int> rank_list(kDevices);
+  for (int i = 0; i < kDevices; ++i) {
+    rank_list[i] = i;
+  }
+  cnclCliqueId clique_id;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&clique_id));
+  PADDLE_ENFORCE_MLU_SUCCESS(cnclInitComms(comms.data(), kDevices,
+                                           dev_ids.data(), rank_list.data(),
+                                           kDevices, &clique_id));
+
+  for (int i = 0; i < kDevices; ++i) {
+    AssignCNCLComm(comms[i], kDevices, i, dev_ids[i], ring_id);
+    VLOG(1) << "cncl communicator of rank " << i << " in ring " << ring_id
+            << " has been created on device " << dev_ids[i];
+  }
+
+  // Register the process-exit cleanup exactly once.
+  std::call_once(once_flag_, []() {
+    std::atexit([]() { CNCLCommContext::Instance().ReleaseCNCLComms(); });
+  });
+}
+
+// Wraps `comm` in a CNCLCommImpl with its own MLUDeviceContext and stores it
+// under (ring_id, dev_id). A later registration for the same pair replaces
+// the former one. For ring 0 the communicator is also published to the
+// pooled device context of `dev_id`.
+CNCLComm* CNCLCommContext::AssignCNCLComm(cnclComm_t comm, int nranks, int rank,
+                                          int dev_id, int ring_id) {
+  std::unique_ptr<MLUDeviceContext> dev_ctx(
+      new MLUDeviceContext(MLUPlace(dev_id)));
+
+  std::unique_ptr<CNCLCommImpl> c(new CNCLCommImpl);
+  c->set_ring_id(ring_id);
+  c->set_nranks(nranks);
+  c->set_rank(rank);
+  c->set_comm(comm);
+  c->set_dev_ctx(std::move(dev_ctx));
+
+  CNCLComm* registered = nullptr;
+  {
+    std::lock_guard<std::mutex> guard(comm_map_mutex_);
+    // operator[] creates the ring entry on demand. Assignment, unlike
+    // emplace, replaces an existing communicator; emplace would destroy `c`
+    // (freeing `comm`) and keep the stale entry.
+    auto& slot = comm_map_[ring_id][dev_id];
+    slot = std::move(c);
+    registered = slot.get();
+  }
+
+  if (ring_id == 0) {
+    auto* pool_ctx = static_cast<MLUDeviceContext*>(
+        platform::DeviceContextPool::Instance().Get(
+            platform::MLUPlace(dev_id)));
+    pool_ctx->set_cncl_comm(comm);
+  }
+
+  return registered;
+}
+
+// Destroys every registered communicator; installed as an atexit handler by
+// CreateComm / CreateAllCNCLComms.
+void CNCLCommContext::ReleaseCNCLComms() {
+  for (auto& p : comm_map_) {
+    for (auto& q : p.second) {
+      q.second.reset();
+    }
+  }
+}
+
+}  // namespace platform
+}  // namespace paddle
+#endif
diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h
index 4588dd66677..fcf06cb4f1c 100644
--- a/paddle/fluid/platform/device/mlu/mlu_info.h
+++ b/paddle/fluid/platform/device/mlu/mlu_info.h
@@ -18,6 +18,9 @@ limitations under the License. 
*/
 #include <cn_api.h>
 #include <cnnl.h>
 #include <cnrt.h>
+#ifdef PADDLE_WITH_CNCL
+#include <cncl.h>
+#endif
 #include <vector>
 
 namespace paddle {
@@ -25,6 +28,9 @@ namespace paddle {
 using cnStatus = CNresult;
 using cnrtStatus = cnrtRet_t;
 using cnnlStatus = cnnlStatus_t;
+#ifdef PADDLE_WITH_CNCL
+using cnclStatus = cnclResult_t;  // status type returned by CNCL calls
+#endif
 using mluStream = cnrtQueue_t;
 using mluCnnlHandle = cnnlHandle_t;
 using mluEventHandle = CNnotifier;
-- 
GitLab