Commit 1bec737d authored by Megvii Engine Team

feat(distributed): support distributed opr for rocm

GitOrigin-RevId: 4840100d07dbaa2b7d8e3e113b444ddf81eeea51
Parent a31b7c6e
......@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
set(CMAKE_CXX_STANDARD 17)
endif()
if(NOT MGE_WITH_CUDA)
message(STATUS "Disable distributed support, as CUDA is not enabled.")
if(NOT ${MGE_WITH_CUDA} AND NOT ${MGE_WITH_ROCM})
message(STATUS "Disable distributed support, as both CUDA and ROCm are disabled.")
set(MGE_WITH_DISTRIBUTED OFF)
endif()
......@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
endif()
if(MGE_WITH_DISTRIBUTED)
set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
endif()
......
......@@ -79,6 +79,8 @@ namespace {
if (g_unspec_locator_type == DT::UNSPEC) {
if (CudaCompNode::available()) {
g_unspec_locator_type = DT::CUDA;
} else if (ROCmCompNode::available()) {
g_unspec_locator_type = DT::ROCM;
} else {
g_unspec_locator_type = DT::CPU;
}
......
......@@ -217,6 +217,11 @@ public:
Locator locator() override { return m_locator; }
Locator locator_logical() override { return m_locator_logical; }
uint64_t get_uid() override { return m_uid; }
private:
uint64_t m_uid;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROCmCompNode::CompNodeImpl);
......@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator,
m_locator_logical = locator_logical;
m_initialized = true;
#if defined(__linux__) || defined(TARGET_OS_MAC)
FILE *fp;
fp = fopen("/dev/urandom", "r");
mgb_assert(fread(&m_uid, sizeof(m_uid), 1, fp) == 1);
fclose(fp);
#else
m_uid = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch()
).count();
#endif
auto on_succ = [this](hipStream_t stream) {
auto locator = m_locator;
log_comp_node_created(locator, m_locator_logical);
......
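
Note: the init() hunk above seeds a per-device uid from /dev/urandom on Linux/macOS and falls back to a wall-clock timestamp elsewhere. Below is a minimal standalone sketch of the same pattern; the helper name generate_comp_node_uid is hypothetical (not part of this commit), and it adds a NULL-check on fopen that the hunk handles via mgb_assert instead.

    #include <chrono>
    #include <cstdint>
    #include <cstdio>

    // Sketch only (hypothetical helper, not part of this commit): prefer the OS
    // entropy pool for the uid, fall back to a nanosecond timestamp otherwise.
    static uint64_t generate_comp_node_uid() {
        uint64_t uid = 0;
    #if defined(__linux__) || defined(TARGET_OS_MAC)
        if (FILE* fp = std::fopen("/dev/urandom", "r")) {
            size_t nread = std::fread(&uid, sizeof(uid), 1, fp);
            std::fclose(fp);
            if (nread == 1)
                return uid;
        }
    #endif
        return std::chrono::duration_cast<std::chrono::nanoseconds>(
                       std::chrono::system_clock::now().time_since_epoch())
                .count();
    }
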
......@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
}
}
cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
} // anonymous namespace
/* ================= ModeTrait ================= */
......@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
// add input
mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
if (inputs.size() > 0) {
mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
add_input({inputs[0]});
}
......@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
const auto& cns = config.comp_node();
mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
if (cns.size() > 0) {
mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
"CollectiveComm currectly only supports CUDA");
output(0)->comp_node(cns[0]);
} else {
output(0)->comp_node(inputs[0]->comp_node());
......@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
reg_info.hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_megray_ctx = get_megray_context(output(0)->comp_node());
m_init = true;
}
......
......@@ -18,10 +18,6 @@
using namespace mgb;
using namespace opr;
cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
/* ===================== RemoteSend ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
......@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_megray_ctx = get_megray_context(output(0)->comp_node());
m_init = true;
}
......@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_megray_ctx = get_megray_context(output(0)->comp_node());
m_init = true;
}
......
......@@ -10,6 +10,7 @@
*/
#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
using namespace mgb;
using namespace opr;
......@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
if (backend == "nccl") {
return MegRay::MEGRAY_NCCL;
} else if (backend == "rccl") {
return MegRay::MEGRAY_RCCL;
} else if (backend == "ucx") {
return MegRay::MEGRAY_UCX;
} else {
......@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
}
}
std::shared_ptr<MegRay::Context> mgb::opr::get_megray_context(CompNode comp_node){
#if MGB_CUDA
return MegRay::CudaContext::make(CompNodeEnv::from_comp_node(comp_node).cuda_env().stream);
#elif MGB_ROCM
return MegRay::HipContext::make(CompNodeEnv::from_comp_node(comp_node).rocm_env().stream);
#else
#error "neither CUDA nor ROCm is enabled"
#endif
}
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash);
......
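
Note: the new get_megray_context() selects the MegRay context type at compile time, which is what lets collective_comm.cpp and io_remote.cpp drop their CUDA-only get_stream() helpers and pass a comp node instead. A rough standalone illustration of the same #if/#elif/#error guard follows; the MGB_CUDA/MGB_ROCM macros come from the hunk above, the function and strings are placeholders, and it only compiles when one of the macros is defined (e.g. -DMGB_CUDA=1), which is exactly the point of the guard.

    #include <cstdio>

    // Standalone illustration of the compile-time backend guard used by
    // get_megray_context(): exactly one branch survives preprocessing, and a
    // build with neither backend enabled fails loudly instead of silently
    // producing a useless context.
    #if MGB_CUDA
    static const char* megray_backend_name() { return "CUDA stream context"; }
    #elif MGB_ROCM
    static const char* megray_backend_name() { return "HIP stream context"; }
    #else
    #error "neither CUDA nor ROCm is enabled"
    #endif

    int main() {
        std::printf("MegRay context kind: %s\n", megray_backend_name());
        return 0;
    }
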
#include "megbrain_build_config.h"
#if MGB_CUDA
#include "megbrain/opr/zmq_rpc.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
......@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) {
DISCARD_RETVAL(client->recv(reply));
add_socket(client);
}
#endif // MGB_CUDA
......@@ -12,7 +12,9 @@
#pragma once
#include <mutex>
#include <memory>
#include "megbrain/comp_node.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
......@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType);
MegRay::Backend get_megray_backend(const std::string& backend);
std::shared_ptr<MegRay::Context> get_megray_context(CompNode comp_node);
/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
*/
......
......@@ -2,7 +2,6 @@
#include "megbrain_build_config.h"
#if MGB_CUDA
#include <unistd.h>
#include <cassert>
#include <iostream>
......@@ -101,4 +100,3 @@ private:
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets;
};
} // namespace ZmqRpc
#endif