提交 116eee52 编写于 作者: M Megvii Engine Team

build(third_party): update megray

GitOrigin-RevId: da5e05f82b5112474d51f9eab78318b1d6432742
上级 e507228e
......@@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
}
}
MegRay::DType get_megray_dtype(megdnn::DType dtype) {
switch(dtype.enumv()) {
case DTypeEnum::Int8:
return MegRay::DType::MEGRAY_INT8;
case DTypeEnum::Int32:
return MegRay::DType::MEGRAY_INT32;
case DTypeEnum::Float32:
return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16:
return MegRay::DType::MEGRAY_FLOAT16;
#endif
default:
mgb_throw(MegBrainError, "bad CollectiveComm dtype");
}
}
MegRay::Backend get_megray_backend(const std::string& backend) {
if (backend == "nccl") {
return MegRay::MEGRAY_NCCL;
} else if (backend == "ucx") {
return MegRay::MEGRAY_UCX;
} else {
mgb_throw(MegBrainError, "back CollectiveComm backend");
}
}
cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
......
......@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() {
for (size_t i = 0; i < ishp.ndim; i++) {
data_size *= ishp[i];
}
data_size *= tensor.dtype().size();
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size,
get_megray_dtype(tensor.dtype()),
1, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
if (m_is_grad) {
......@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() {
for (size_t i = 0; i < ishp.ndim; i++) {
data_size *= ishp[i];
}
data_size *= tensor.dtype().size();
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx);
auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size,
get_megray_dtype(tensor.dtype()),
0, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
}
......
......@@ -14,6 +14,33 @@
using namespace mgb;
using namespace opr;
MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
switch(dtype.enumv()) {
case DTypeEnum::Int8:
return MegRay::DType::MEGRAY_INT8;
case DTypeEnum::Int32:
return MegRay::DType::MEGRAY_INT32;
case DTypeEnum::Float32:
return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16:
return MegRay::DType::MEGRAY_FLOAT16;
#endif
default:
mgb_throw(MegBrainError, "bad CollectiveComm dtype");
}
}
MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
if (backend == "nccl") {
return MegRay::MEGRAY_NCCL;
} else if (backend == "ucx") {
return MegRay::MEGRAY_UCX;
} else {
mgb_throw(MegBrainError, "back CollectiveComm backend");
}
}
bool MegRayCommBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash);
......
......@@ -13,13 +13,16 @@
#include <mutex>
#include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
namespace mgb {
namespace opr {
MegRay::DType get_megray_dtype(megdnn::DType);
MegRay::Backend get_megray_backend(const std::string& backend);
/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册