提交 116eee52 编写于 作者: M Megvii Engine Team

build(third_party): update megray

GitOrigin-RevId: da5e05f82b5112474d51f9eab78318b1d6432742
上级 e507228e
...@@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) { ...@@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
} }
} }
/*!
 * Translate a megdnn dtype into the matching MegRay dtype.
 *
 * Handles Int8, Int32, Float32 and, unless MEGDNN_DISABLE_FLOAT16 is
 * defined, Float16. Every other dtype is reported through mgb_throw as a
 * MegBrainError.
 */
MegRay::DType get_megray_dtype(megdnn::DType dtype) {
    const auto type_tag = dtype.enumv();
    if (type_tag == DTypeEnum::Int8) {
        return MegRay::DType::MEGRAY_INT8;
    }
    if (type_tag == DTypeEnum::Int32) {
        return MegRay::DType::MEGRAY_INT32;
    }
    if (type_tag == DTypeEnum::Float32) {
        return MegRay::DType::MEGRAY_FLOAT32;
    }
#ifndef MEGDNN_DISABLE_FLOAT16
    if (type_tag == DTypeEnum::Float16) {
        return MegRay::DType::MEGRAY_FLOAT16;
    }
#endif
    // No MegRay counterpart for this dtype.
    mgb_throw(MegBrainError, "bad CollectiveComm dtype");
}
/*!
 * Translate a backend name into the matching MegRay backend enum.
 *
 * The comparison is exact: "nccl" selects MEGRAY_NCCL and "ucx" selects
 * MEGRAY_UCX. Any other string is reported through mgb_throw as a
 * MegBrainError.
 */
MegRay::Backend get_megray_backend(const std::string& backend) {
    if (backend == "nccl") {
        return MegRay::MEGRAY_NCCL;
    }
    if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    }
    // Message previously read "back CollectiveComm backend" (typo for "bad").
    mgb_throw(MegBrainError, "bad CollectiveComm backend");
}
cudaStream_t get_stream(VarNode* var) { cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream; return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
} }
......
...@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() { ...@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() {
for (size_t i = 0; i < ishp.ndim; i++) { for (size_t i = 0; i < ishp.ndim; i++) {
data_size *= ishp[i]; data_size *= ishp[i];
} }
data_size *= tensor.dtype().size(); auto status = m_megray_comm->send(tensor.raw_ptr(), data_size,
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx); get_megray_dtype(tensor.dtype()),
1, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
if (m_is_grad) { if (m_is_grad) {
...@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() { ...@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() {
for (size_t i = 0; i < ishp.ndim; i++) { for (size_t i = 0; i < ishp.ndim; i++) {
data_size *= ishp[i]; data_size *= ishp[i];
} }
data_size *= tensor.dtype().size(); auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size,
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx); get_megray_dtype(tensor.dtype()),
0, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
} }
......
...@@ -14,6 +14,33 @@ ...@@ -14,6 +14,33 @@
using namespace mgb; using namespace mgb;
using namespace opr; using namespace opr;
/*!
 * Translate a megdnn dtype into the matching MegRay dtype.
 *
 * Handles Int8, Int32, Float32 and, unless MEGDNN_DISABLE_FLOAT16 is
 * defined, Float16. Every other dtype is reported through mgb_throw as a
 * MegBrainError.
 */
MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
    const auto type_tag = dtype.enumv();
    if (type_tag == DTypeEnum::Int8) {
        return MegRay::DType::MEGRAY_INT8;
    }
    if (type_tag == DTypeEnum::Int32) {
        return MegRay::DType::MEGRAY_INT32;
    }
    if (type_tag == DTypeEnum::Float32) {
        return MegRay::DType::MEGRAY_FLOAT32;
    }
#ifndef MEGDNN_DISABLE_FLOAT16
    if (type_tag == DTypeEnum::Float16) {
        return MegRay::DType::MEGRAY_FLOAT16;
    }
#endif
    // No MegRay counterpart for this dtype.
    mgb_throw(MegBrainError, "bad CollectiveComm dtype");
}
/*!
 * Translate a backend name into the matching MegRay backend enum.
 *
 * The comparison is exact: "nccl" selects MEGRAY_NCCL and "ucx" selects
 * MEGRAY_UCX. Any other string is reported through mgb_throw as a
 * MegBrainError.
 */
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
    if (backend == "nccl") {
        return MegRay::MEGRAY_NCCL;
    }
    if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    }
    // Message previously read "back CollectiveComm backend" (typo for "bad").
    mgb_throw(MegBrainError, "bad CollectiveComm backend");
}
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) { bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx); std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash); auto it = m_megray_comms.find(hash);
......
...@@ -13,13 +13,16 @@ ...@@ -13,13 +13,16 @@
#include <mutex> #include <mutex>
#include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h" #include "megbrain/opr/group_manager.h"
#include "megray.h" #include "megray.h"
namespace mgb { namespace mgb {
namespace opr { namespace opr {
/*!
 * convert a megdnn dtype to the corresponding MegRay dtype; dtypes without a
 * MegRay counterpart are reported via mgb_throw(MegBrainError, ...)
 */
MegRay::DType get_megray_dtype(megdnn::DType);
/*!
 * convert a backend name ("nccl" or "ucx", matched exactly) to the
 * corresponding MegRay backend; any other name is reported via
 * mgb_throw(MegBrainError, ...)
 */
MegRay::Backend get_megray_backend(const std::string& backend);
/*! /*!
* gather MegRay unique ids and build communicator, use hash for deduplication * gather MegRay unique ids and build communicator, use hash for deduplication
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册