Unverified commit d28f6f7b authored by mhhhh1, committed by GitHub

feat(cncl_mlu): add cncl dev for mlu distributed backend (#39294)

Parent eefe5feb
......@@ -230,6 +230,7 @@ option(WITH_INFRT "Compile PaddlePaddle with INFRT" OFF)
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
option(WITH_CNCL "Compile PaddlePaddle with CNCL support" OFF)
option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON)
option(WITH_ARM "Compile PaddlePaddle with arm support" OFF)
option(WITH_SW "Compile PaddlePaddle with sw support" OFF)
......@@ -292,6 +293,13 @@ if (NOT WITH_XPU AND WITH_XPU_BKCL)
"Disable BKCL when compiling without XPU" FORCE)
endif()
if (NOT WITH_MLU AND WITH_CNCL)
    MESSAGE(WARNING
        "Disable CNCL when compiling without MLU. Force WITH_CNCL=OFF.")
    set(WITH_CNCL OFF CACHE STRING
        "Disable CNCL when compiling without MLU" FORCE)
endif()
if(WITH_NCCL)
add_definitions("-DPADDLE_WITH_NCCL")
include(nccl)
......
......@@ -19,4 +19,11 @@ set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so)
generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake")
if(WITH_CNCL)
MESSAGE(STATUS "Compile with CNCL!")
ADD_DEFINITIONS(-DPADDLE_WITH_CNCL)
set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so)
TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
else()
TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
endif()
......@@ -133,6 +133,10 @@ if(WITH_ASCEND_CL)
target_link_libraries(collective_helper npu_collective_helper)
endif()
if(WITH_CNCL)
target_link_libraries(collective_helper mlu_collective_helper)
endif()
if(WITH_GPU OR WITH_ROCM)
target_link_libraries(device_context gpu_resource_pool)
endif()
......
......@@ -24,6 +24,9 @@
#include "paddle/fluid/platform/device/npu/dynload/hccl.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/device/mlu/device_context.h"
#endif
namespace paddle {
namespace platform {
......@@ -333,5 +336,102 @@ class BKCLCommContext {
};
#endif
#if defined(PADDLE_WITH_CNCL)
// In order to apply hierarchical communication with CNCL, we need
// a communication ring containing CNCL communicators associated to a global
// cnclCliqueId. E.g. for a hierarchical case,
//
//    11 - 12   21 - 22
//     |    |    |    |
//    13 - 14 - 23 - 24
//          |    |
//    31 - 32 - 41 - 42
//     |    |    |    |
//    33 - 34   43 - 44
//
// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
// (31,32,33,34), (41,42,43,44) as bottoms respectively.
//
// We could also use a single communication ring for the flattened case.
//
// The CNCLComm instance is created and reserved in the CNCLCommContext
// singleton with a global user-specified group id.
class MLUDeviceContext;
class CNCLComm {
public:
virtual int ring_id() const = 0;
virtual int nranks() const = 0;
virtual int rank() const = 0;
virtual int device_id() const = 0;
virtual cnclComm_t comm() const = 0;
virtual mluStream stream() const = 0;
virtual MLUDeviceContext* dev_context() const = 0;
virtual ~CNCLComm() = default;
};
// A singleton CNCL communicator context reserves communication ring ids
class CNCLCommContext {
public:
static CNCLCommContext& Instance() {
static CNCLCommContext comm_ctx;
return comm_ctx;
}
CNCLComm* CreateComm(cnclCliqueId* cncl_id, int nranks, int rank, int dev_id,
int ring_id = 0);
void CreateAllCNCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
// a later comm with the same dev_id and the same ring_id
// will override the former
CNCLComm* AssignCNCLComm(cnclComm_t comm, int nranks, int rank, int dev_id,
int ring_id = 0);
// retrieve a communicator by the ring id in multiprocessing mode
CNCLComm* Get(int ring_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Communicator in ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1,
platform::errors::InvalidArgument(
"One device id should be specified to retrieve from "
"multiple communicators."));
return comm_map_.at(ring_id).begin()->second.get();
}
// retrieve a communicator by the ring id and the device id
CNCLComm* Get(int ring_id, int dev_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Communicator of ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_GT(
comm_map_.at(ring_id).count(dev_id), 0,
platform::errors::InvalidArgument(
"Communicator at device id %d has not been initialized in ring %d.",
dev_id, ring_id));
return comm_map_.at(ring_id).at(dev_id).get();
}
// retrieve a communicator by the ring id and place
CNCLComm* Get(int ring_id, Place place) const {
return Get(ring_id, place.device);
}
private:
std::once_flag once_flag_;
std::mutex comm_map_mutex_;
// ring id to dev-CNCLComm
std::map<int, std::map<int, std::unique_ptr<CNCLComm>>> comm_map_;
void ReleaseCNCLComms();
CNCLCommContext() = default;
DISABLE_COPY_AND_ASSIGN(CNCLCommContext);
};
#endif
} // namespace platform
} // namespace paddle
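
For orientation, here is a minimal sketch (not part of the patch) of how a caller might bootstrap a communicator through this context. It assumes PADDLE_WITH_CNCL builds and that cncl_id was produced by cnclGetCliqueId on rank 0 and shared with the other ranks through some out-of-band channel; InitCnclForRank is a hypothetical helper name.

// Sketch only: the clique-id exchange between processes is assumed.
void InitCnclForRank(cnclCliqueId* cncl_id, int nranks, int rank,
                     int dev_id) {
  namespace plat = paddle::platform;
  // Create a communicator on ring 0; a later call with the same dev_id
  // and ring_id would override this one.
  plat::CNCLComm* comm = plat::CNCLCommContext::Instance().CreateComm(
      cncl_id, nranks, rank, dev_id, /*ring_id=*/0);
  // Elsewhere, retrieve the same communicator by ring id alone
  // (valid when the ring holds exactly one communicator per process).
  plat::CNCLComm* same = plat::CNCLCommContext::Instance().Get(/*ring_id=*/0);
  VLOG(1) << "rank " << same->rank() << " uses device " << comm->device_id();
}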
......@@ -8,3 +8,4 @@ cc_library(mlu_info SRCS mlu_info.cc DEPS enforce glog monitor neuware_lib)
cc_library(mlu_stream SRCS mlu_stream.cc DEPS boost mlu_info stream_callback_manager eigen3 ${MKLDNN_CTX_DEPS})
cc_library(mlu_device_context SRCS device_context.cc DEPS mlu_stream)
cc_test(mlu_device_context_test SRCS device_context_test.cc DEPS mlu_device_context)
cc_library(mlu_collective_helper SRCS mlu_collective_helper.cc DEPS mlu_stream mlu_info)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#include <stdio.h>
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace platform {
inline cnclDataType_t ToCNCLDataType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP32) {
return cnclFloat32;
} else if (type == framework::proto::VarType::FP16) {
return cnclFloat16;
} else if (type == framework::proto::VarType::INT32) {
return cnclInt32;
} else if (type == framework::proto::VarType::INT16) {
return cnclInt16;
} else if (type == framework::proto::VarType::INT8) {
return cnclInt8;
} else if (type == framework::proto::VarType::UINT8) {
return cnclUint8;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype is not supported by CNCL."));
}
}
} // namespace platform
} // namespace paddle
#endif
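
As a quick illustration of the mapping above (a sketch, not from the patch): converting a framework proto type before handing a buffer to CNCL. Types outside the list, e.g. FP64, would throw via PADDLE_THROW.

// Sketch: pick the CNCL dtype matching a framework proto type.
cnclDataType_t dt = paddle::platform::ToCNCLDataType(
    paddle::framework::proto::VarType::FP32);
// dt == cnclFloat32 here.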
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/device/mlu/mlu_stream.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#endif
namespace Eigen {
struct DefaultDevice;
......@@ -88,6 +91,14 @@ class MLUDeviceContext : public DeviceContext {
/*! \brief Return mlu stream in the device context. */
mluStream stream() const;
#ifdef PADDLE_WITH_CNCL
/*! \brief Return the CNCL communicator. */
cnclComm_t cncl_comm() const { return cncl_comm_; }
/*! \brief Set the CNCL communicator. */
void set_cncl_comm(cnclComm_t comm) { cncl_comm_ = comm; }
#endif
template <typename Callback>
void RecordEvent(mluEventHandle ev, Callback callback) const {
return context()->Stream()->RecordEvent(ev, callback);
......@@ -132,6 +143,10 @@ class MLUDeviceContext : public DeviceContext {
thread_ctx_;
static thread_local std::mutex ctx_mtx_;
#ifdef PADDLE_WITH_CNCL
cnclComm_t cncl_comm_{nullptr};
#endif
DISABLE_COPY_AND_ASSIGN(MLUDeviceContext);
};
......
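
The cached communicator is the one AssignCNCLComm installs for ring 0 (see mlu_collective_helper.cc below). A consumer could read it back from the pooled device context roughly like this sketch, where dev_id is an illustrative device index:

// Sketch: fetch the CNCL communicator cached on the MLU device context.
auto* dev_ctx = static_cast<paddle::platform::MLUDeviceContext*>(
    paddle::platform::DeviceContextPool::Instance().Get(
        paddle::platform::MLUPlace(dev_id)));
cnclComm_t comm = dev_ctx->cncl_comm();  // nullptr until ring 0 is set up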
......@@ -42,6 +42,9 @@ struct MLUStatusType {};
DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT);
DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL);
DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN);
#ifdef PADDLE_WITH_CNCL
DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL);
#endif
} // namespace details
......@@ -80,6 +83,17 @@ inline std::string build_mlu_error_msg(cnStatus stat) {
return sout.str();
}
/*************** CNCL ERROR ***************/
#ifdef PADDLE_WITH_CNCL
inline bool is_error(cnclStatus e) { return e != CNCL_RET_SUCCESS; }
inline std::string build_mlu_error_msg(cnclStatus e) {
std::ostringstream sout;
sout << "MLU CNCL error(" << e << "), " << cnclGetErrorStr(e) << ". ";
return sout.str();
}
#endif
#define PADDLE_ENFORCE_MLU_SUCCESS(COND) \
do { \
auto __cond__ = (COND); \
......
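
With the cnclStatus specialization in place, CNCL calls participate in the usual enforce machinery; a minimal sketch:

// Sketch: any CNCL API returning cnclStatus can be wrapped directly.
cnclCliqueId clique_id;
PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&clique_id));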
......@@ -58,5 +58,15 @@ TEST(mlu_enforce, mlu_success) {
CheckMluStatusFailure(CN_ERROR_INVALID_VALUE, "invalid argument"));
EXPECT_TRUE(CheckMluStatusFailure(CN_MEMORY_ERROR_OUT_OF_MEMORY,
"device has no memory to alloc"));
#ifdef PADDLE_WITH_CNCL
EXPECT_TRUE(CheckMluStatusSuccess(CNCL_RET_SUCCESS));
EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_INTERNAL, "CNCL error"));
EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_NULL_POINTER, "CNCL error"));
EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_INIT, "CNCL error"));
EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_NOT_INIT, "CNCL error"));
EXPECT_TRUE(CheckMluStatusFailure(CNCL_RET_ERR_REINIT, "CNCL error"));
EXPECT_TRUE(
CheckMluStatusFailure(CNCL_RET_ERR_INVALID_VERSION, "CNCL error"));
#endif
}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CNCL)
#include <utility>
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/enforce.h"
namespace paddle {
namespace platform {
class CNCLCommImpl : public CNCLComm {
public:
void set_ring_id(int ring_id) { ring_id_ = ring_id; }
int ring_id() const override { return ring_id_; }
void set_nranks(int nranks) { nranks_ = nranks; }
int nranks() const override { return nranks_; }
void set_rank(int rank) { rank_ = rank; }
int rank() const override { return rank_; }
int device_id() const override { return dev_ctx_->GetPlace().device; }
void set_comm(cnclComm_t comm) { comm_ = comm; }
cnclComm_t comm() const override { return comm_; }
mluStream stream() const override { return dev_ctx_->stream(); }
void set_dev_ctx(std::unique_ptr<MLUDeviceContext>&& dev_ctx) {
dev_ctx_ = std::move(dev_ctx);
}
MLUDeviceContext* dev_context() const override { return dev_ctx_.get(); }
~CNCLCommImpl() {
if (comm_) {
PADDLE_ENFORCE_MLU_SUCCESS(cnclFreeComm(comm_));
}
}
private:
int ring_id_;
int nranks_;
int rank_;
cnclComm_t comm_;
std::unique_ptr<MLUDeviceContext> dev_ctx_;
};
CNCLComm* CNCLCommContext::CreateComm(cnclCliqueId* cncl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(cncl_id,
platform::errors::InvalidArgument(
"The cncl unique id should not be null."));
PADDLE_ENFORCE_GT(
nranks, 1,
platform::errors::InvalidArgument(
"Expected nranks > 1. But received nranks is %d.", nranks));
PADDLE_ENFORCE_GE(rank, 0,
platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_LT(
rank, nranks,
platform::errors::InvalidArgument(
"Expected rank < nranks. But received rank is %d, nranks is %d.",
rank, nranks));
PADDLE_ENFORCE_GE(
dev_id, 0,
platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", dev_id));
cnclComm_t comm;
int dev_list[] = {dev_id};
int rank_list[] = {rank};
SetMLUDeviceId(dev_id);
PADDLE_ENFORCE_MLU_SUCCESS(
cnclInitComms(&comm, 1, dev_list, rank_list, nranks, cncl_id));
auto* comm_wrapper = AssignCNCLComm(comm, nranks, rank, dev_id, ring_id);
VLOG(1) << "cncl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id;
std::call_once(once_flag_, []() {
std::atexit([]() { CNCLCommContext::Instance().ReleaseCNCLComms(); });
});
return comm_wrapper;
}
void CNCLCommContext::CreateAllCNCLComms(const std::vector<int>& dev_ids,
int ring_id) {
PADDLE_ENFORCE_GT(
dev_ids.size(), 0,
platform::errors::InvalidArgument("Expected the size of dev_ids > 0. But "
"received the size of dev_ids is %d.",
dev_ids.size()));
const int kDevices = dev_ids.size();
// Use std::vector instead of a variable-length array and raw new[]/delete[]
// so the buffers are released even if an enforce below throws.
std::vector<cnclComm_t> comms(kDevices);
std::vector<int> rank_list(kDevices);
for (int i = 0; i < kDevices; i++) {
rank_list[i] = i;
}
cnclCliqueId clique_id;
PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCliqueId(&clique_id));
PADDLE_ENFORCE_MLU_SUCCESS(cnclInitComms(comms.data(), kDevices,
dev_ids.data(), rank_list.data(),
kDevices, &clique_id));
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0,
platform::errors::InvalidArgument(
"Expected comm_map_.count(ring_id) = 0. But received "
"comm_map_.count(ring_id) is %d.",
comm_map_.count(ring_id)));
for (size_t i = 0; i < dev_ids.size(); ++i) {
AssignCNCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id);
VLOG(1) << "cncl communicator of rank " << i << " in ring " << ring_id
<< " has been created on device " << dev_ids[i];
}
std::call_once(once_flag_, []() {
std::atexit([]() { CNCLCommContext::Instance().ReleaseCNCLComms(); });
});
}
CNCLComm* CNCLCommContext::AssignCNCLComm(cnclComm_t comm, int nranks, int rank,
int dev_id, int ring_id) {
std::unique_ptr<MLUDeviceContext> dev_ctx(
new MLUDeviceContext(MLUPlace(dev_id)));
CNCLCommImpl* c = new CNCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<CNCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<CNCLComm>(c));
comm_map_mutex_.unlock();
if (ring_id == 0) {
auto* dev_ctx = static_cast<platform::MLUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::MLUPlace(dev_id)));
dev_ctx->set_cncl_comm(comm);
}
return comm_map_[ring_id][dev_id].get();
}
void CNCLCommContext::ReleaseCNCLComms() {
for (auto& p : comm_map_) {
for (auto& q : p.second) {
q.second.reset();
}
}
}
} // namespace platform
} // namespace paddle
#endif
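
For the single-process multi-card path, CreateAllCNCLComms builds one communicator per device on a ring. A usage sketch, with illustrative device ids:

// Sketch: one communicator per local device on ring 0.
std::vector<int> dev_ids = {0, 1, 2, 3};  // illustrative device ids
paddle::platform::CNCLCommContext::Instance().CreateAllCNCLComms(
    dev_ids, /*ring_id=*/0);
// Retrieve the communicator bound to device 2 on ring 0.
auto* comm = paddle::platform::CNCLCommContext::Instance().Get(0, 2);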
......@@ -18,6 +18,9 @@ limitations under the License. */
#include <cn_api.h>
#include <cnnl.h>
#include <cnrt.h>
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#endif
#include <vector>
namespace paddle {
......@@ -25,6 +28,9 @@ namespace paddle {
using cnStatus = CNresult;
using cnrtStatus = cnrtRet_t;
using cnnlStatus = cnnlStatus_t;
#ifdef PADDLE_WITH_CNCL
using cnclStatus = cnclResult_t;
#endif
using mluStream = cnrtQueue_t;
using mluCnnlHandle = cnnlHandle_t;
using mluEventHandle = CNnotifier;
......