提交 1bec737d 编写于 作者: M Megvii Engine Team

feat(distributed): support distributed opr for rocm

GitOrigin-RevId: 4840100d07dbaa2b7d8e3e113b444ddf81eeea51
上级 a31b7c6e
...@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT) ...@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
endif() endif()
if(NOT MGE_WITH_CUDA) if(NOT ${MGE_WITH_CUDA} AND NOT ${MGE_WITH_ROCM})
message(STATUS "Disable distributed support, as CUDA is not enabled.") message(STATUS "Disable distributed support, as both CUDA and ROCm are disabled.")
set(MGE_WITH_DISTRIBUTED OFF) set(MGE_WITH_DISTRIBUTED OFF)
endif() endif()
...@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) ...@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
endif() endif()
if(MGE_WITH_DISTRIBUTED) if(MGE_WITH_DISTRIBUTED)
set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
endif() endif()
......
...@@ -79,6 +79,8 @@ namespace { ...@@ -79,6 +79,8 @@ namespace {
if (g_unspec_locator_type == DT::UNSPEC) { if (g_unspec_locator_type == DT::UNSPEC) {
if (CudaCompNode::available()) { if (CudaCompNode::available()) {
g_unspec_locator_type = DT::CUDA; g_unspec_locator_type = DT::CUDA;
} else if (ROCmCompNode::available()) {
g_unspec_locator_type = DT::ROCM;
} else { } else {
g_unspec_locator_type = DT::CPU; g_unspec_locator_type = DT::CPU;
} }
......
...@@ -217,6 +217,11 @@ public: ...@@ -217,6 +217,11 @@ public:
Locator locator() override { return m_locator; } Locator locator() override { return m_locator; }
Locator locator_logical() override { return m_locator_logical; } Locator locator_logical() override { return m_locator_logical; }
uint64_t get_uid() override { return m_uid; }
private:
uint64_t m_uid;
}; };
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl); MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl);
...@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator, ...@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator,
m_locator_logical = locator_logical; m_locator_logical = locator_logical;
m_initialized = true; m_initialized = true;
#if defined(__linux__) || defined(TARGET_OS_MAC)
FILE *fp;
fp = fopen("/dev/urandom", "r");
mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1);
fclose(fp);
#else
m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch()
).count();
#endif
auto on_succ = [this](hipStream_t stream) { auto on_succ = [this](hipStream_t stream) {
auto locator = m_locator; auto locator = m_locator;
log_comp_node_created(locator, m_locator_logical); log_comp_node_created(locator, m_locator_logical);
......
...@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) { ...@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
} }
} }
cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
} // anonymous namespace } // anonymous namespace
/* ================= ModeTrait ================= */ /* ================= ModeTrait ================= */
...@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm( ...@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
// add input // add input
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size()); mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
if (inputs.size() > 0) { if (inputs.size() > 0) {
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
add_input({inputs[0]}); add_input({inputs[0]});
} }
...@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm( ...@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
const auto& cns = config.comp_node(); const auto& cns = config.comp_node();
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size()); mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
if (cns.size() > 0) { if (cns.size() > 0) {
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
output(0)->comp_node(cns[0]); output(0)->comp_node(cns[0]);
} else { } else {
output(0)->comp_node(inputs[0]->comp_node()); output(0)->comp_node(inputs[0]->comp_node());
...@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() { ...@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
reg_info.hash, m_key, m_nr_devices, m_rank, reg_info.hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client); get_megray_backend(m_backend), m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); m_megray_ctx = get_megray_context(output(0)->comp_node());
m_init = true; m_init = true;
} }
......
...@@ -18,10 +18,6 @@ ...@@ -18,10 +18,6 @@
using namespace mgb; using namespace mgb;
using namespace opr; using namespace opr;
cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
/* ===================== RemoteSend ===================== */ /* ===================== RemoteSend ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
...@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() { ...@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm( m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); m_megray_ctx = get_megray_context(output(0)->comp_node());
m_init = true; m_init = true;
} }
...@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() { ...@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm( m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); m_megray_ctx = get_megray_context(output(0)->comp_node());
m_init = true; m_init = true;
} }
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
*/ */
#include "megbrain/opr/megray_helper.h" #include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
using namespace mgb; using namespace mgb;
using namespace opr; using namespace opr;
...@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) { ...@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
if (backend == "nccl") { if (backend == "nccl") {
return MegRay::MEGRAY_NCCL; return MegRay::MEGRAY_NCCL;
} else if (backend == "rccl") {
return MegRay::MEGRAY_RCCL;
} else if (backend == "ucx") { } else if (backend == "ucx") {
return MegRay::MEGRAY_UCX; return MegRay::MEGRAY_UCX;
} else { } else {
...@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { ...@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
} }
} }
//! Build a MegRay execution context bound to the compute stream of
//! \p comp_node so collective communication is ordered with the other
//! work queued on that node.
//!
//! The original implementation chose the backend with
//! `#if MGB_CUDA / #elif MGB_ROCM`, i.e. at preprocessing time: in a
//! build with both CUDA and ROCm enabled, a ROCm comp node would have
//! been handed a MegRay::CudaContext. Dispatch on the runtime device
//! type instead, while still failing the build when neither backend
//! is available (matching the previous `#error`).
#if !MGB_CUDA && !MGB_ROCM
#error "neither CUDA nor ROCm is enabled"
#endif
std::shared_ptr<MegRay::Context> mgb::opr::get_megray_context(CompNode comp_node) {
#if MGB_CUDA
    if (comp_node.device_type() == CompNode::DeviceType::CUDA) {
        return MegRay::CudaContext::make(
                CompNodeEnv::from_comp_node(comp_node).cuda_env().stream);
    }
#endif
#if MGB_ROCM
    if (comp_node.device_type() == CompNode::DeviceType::ROCM) {
        return MegRay::HipContext::make(
                CompNodeEnv::from_comp_node(comp_node).rocm_env().stream);
    }
#endif
    // Reached only for comp nodes (e.g. CPU) that no MegRay backend covers.
    mgb_throw(MegBrainError, "get_megray_context: unsupported comp node");
}
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx); std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash); auto it = m_megray_comms.find(hash);
......
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_CUDA
#include "megbrain/opr/zmq_rpc.h" #include "megbrain/opr/zmq_rpc.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/exception.h" #include "megbrain/exception.h"
...@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) { ...@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) {
DISCARD_RETVAL(client->recv(reply)); DISCARD_RETVAL(client->recv(reply));
add_socket(client); add_socket(client);
} }
#endif // MGB_CUDA
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
#pragma once #pragma once
#include <mutex> #include <mutex>
#include <memory>
#include "megbrain/comp_node.h"
#include "megbrain/opr/group_manager.h" #include "megbrain/opr/group_manager.h"
#include "megray.h" #include "megray.h"
...@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType); ...@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType);
MegRay::Backend get_megray_backend(const std::string& backend); MegRay::Backend get_megray_backend(const std::string& backend);
std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node);
/*! /*!
* gather MegRay unique ids and build communicator, use hash for deduplication * gather MegRay unique ids and build communicator, use hash for deduplication
*/ */
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#if MGB_CUDA
#include <unistd.h> #include <unistd.h>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
...@@ -101,4 +100,3 @@ private: ...@@ -101,4 +100,3 @@ private:
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets; std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets;
}; };
} // namespace ZmqRpc } // namespace ZmqRpc
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册