Commit 7d07e17f authored by caifubi

Support Common HCCL Ops

1. Support the Broadcast op
2. Support communication ops as graph outputs
3. Optimize communication op memory allocation
4. Support HCCL multi-group
Parent 93f6fc0a
@@ -9,11 +9,11 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
+include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
-# for CPU/GPU mode, find c_sec and slog from local prebuild
+# for CPU/GPU mode, find slog from local prebuild
 if (NOT ENABLE_D)
     set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
-    find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
     find_library(slog libslog.so ${GE_PREBUILD_PATH})
 elseif (DEFINED ENV{D_LINK_PATH})
     set(GE_LIB_PATH $ENV{D_LINK_PATH})
@@ -28,7 +28,6 @@ elseif (DEFINED ENV{D_LINK_PATH})
         message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
     endif()
     set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
-    find_library(c_sec libc_sec.so ${GE_LIB_PATH})
     find_library(slog libslog.so ${GE_LIB_PATH})
     find_library(mmpa libmmpa.so ${GE_LIB_PATH})
     find_library(runtime libruntime.so ${GE_LIB_PATH})
...
@@ -153,7 +153,7 @@ if (NOT ENABLE_GE)
         FILES
             ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
             ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
-            ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libc_sec.so
+            ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
         DESTINATION ${INSTALL_LIB_DIR}
         COMPONENT mindspore
     )
...
graphengine @ 579dcb75
-Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
+Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
@@ -333,8 +333,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
                                                                            model_iter->second, listener);
   if (!status) {
-    MS_LOG(ERROR) << "load task failed";
-    return false;
+    MS_LOG(EXCEPTION) << "Load Task Failed";
   }
   if (ProfilingManager::GetInstance().IsProfiling()) {
     auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
...
@@ -29,6 +29,7 @@ class GraphDescReporter : public DescReporter {
  public:
   GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
       : DescReporter(device_id, file_name, std::move(cnode_list)) {}
+  ~GraphDescReporter() override = default;
   void ReportData() override;
 };
 }  // namespace ascend
...
@@ -60,7 +60,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
     ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                          static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
-                         static_cast<u32>(task_info->root_id()), nullptr, stream);
+                         static_cast<u32>(task_info->root_id()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
       return false;
@@ -70,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
     ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                           reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
-                          static_cast<hcclDataType_t>(task_info->data_type()), nullptr, stream);
+                          static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
       return false;
@@ -81,7 +81,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                           reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
                           static_cast<hcclDataType_t>(task_info->data_type()),
-                          static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                          static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
       return false;
@@ -93,7 +93,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
     ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                               reinterpret_cast<void *>(task_info->output_data_addr()),
                               static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
-                              static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                              static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
       return false;
...
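This hunk is the heart of the multi-group support: every `hcom_*` collective used to receive `nullptr` as its group argument, so all ops ran on the default group; now the group name recorded on `HcclTaskInfo` is passed through. Below is a minimal standalone sketch of that dispatch pattern; `TaskInfo`, the group value, and `mock_hcom_broadcast` are hypothetical stand-ins (the real `hcom_broadcast` signature, with dtype and stream, lives in the HCCL headers):

```cpp
// Sketch: threading a per-op group name through a collective dispatch.
// TaskInfo and mock_hcom_broadcast are hypothetical stand-ins for
// HcclTaskInfo and the real hcom_broadcast (dtype and stream omitted).
#include <cstdint>
#include <iostream>
#include <string>

struct TaskInfo {
  void *input_data_addr;
  uint64_t count;
  uint32_t root_id;
  std::string group;  // empty means "use the default (world) group" in this sketch
};

int mock_hcom_broadcast(const char *tag, void *ptr, uint64_t count,
                        uint32_t root, const char *group) {
  std::cout << tag << ": count " << count << ", root " << root << ", group "
            << (group ? group : "<default>") << "\n";
  return 0;
}

int main() {
  float buf[4] = {0};
  TaskInfo task{buf, 4, 0, "hccl_sub_group_0"};  // hypothetical group name
  // Before this commit the group argument was always nullptr; now the
  // per-op group carried on the task info is forwarded to the collective.
  const char *group = task.group.empty() ? nullptr : task.group.c_str();
  return mock_hcom_broadcast("broadcast_0_0", task.input_data_addr,
                             task.count, task.root_id, group);
}
```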
@@ -15,6 +15,7 @@
  */
 #include "device/kernel_runtime.h"
+#include <vector>
 #include <utility>
 #include <numeric>
 #include <functional>
@@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
   mem_manager_->ResetDynamicMemory();
   AssignStaticMemory(graph);
   AssignDynamicMemory(graph);
   UpdateRefNodeOutputMem(graph);
 }
 
 void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                       session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  // assign memory for input nodes
   RunOpAssignInputMemory(input_tensors, graph);
   AssignStaticMemoryValueNode(graph);
   for (const auto &cnode : graph->execution_order()) {
-    // assign memory for output nodes
     RunOpAssignOutputMemory(cnode);
-    // assign memory for workspace
     RunOpAssignWorkSpaceMemory(cnode);
   }
   UpdateRefNodeOutputMem(graph);
@@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
 void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
+  std::vector<session::KernelWithIndex> non_communication_op;
+  // Assign communication op memory first.
   for (const auto &node : nodes) {
     auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
     MS_EXCEPTION_IF_NULL(item_with_index.first);
     if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
       continue;
     }
+    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
+      AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
+    } else {
+      non_communication_op.emplace_back(item_with_index);
+    }
+  }
+  for (const auto &item_with_index : non_communication_op) {
     AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
   }
 }
@@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
   }
 }
 
+void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
+  AssignCommunicationNodeInputMem(node);
+  AssignCommunicationNodeOutputMem(flag, node);
+}
+
 void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
@@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   size_t total_size = 0;
+  size_t output_index = 0;
   std::vector<size_t> align_size_list;
   for (uint64_t mem_size : output_sizes) {
+    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
+      MS_LOG(INFO) << "communication op addr exist";
+      continue;
+    }
     if (context_ptr->enable_hccl()) {
       mem_size = mem_manager_->GetCommonAlignSize(mem_size);
     }
@@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
   }
 }
 
-void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
+DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
+  auto output_sizes = kernel_mod->GetOutputSizeList();
+  if (output_sizes.size() <= index) {
+    MS_LOG(EXCEPTION) << "Previous node output size < node index";
+  }
+  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
+  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
+  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
+  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
+  return address;
+}
+
+void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);
@@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
   size_t total_size = 0;
   std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
-    auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
-    MS_EXCEPTION_IF_NULL(address);
-    auto mem_size = address->size();
-    if (context_ptr->enable_hccl()) {
-      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
+    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
+    auto input_node = input_node_with_index.first;
+    DeviceAddressPtr address = nullptr;
+    if (input_node->isa<CNode>()) {
+      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
+    } else {
+      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
     }
+    MS_EXCEPTION_IF_NULL(address);
+    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
@@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
 void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (AnfAlgo::IsCommunicationOp(node)) {
-    UpdateCommunicationOpInputMem(node);
-    AssignCommunicationNodeOutputMem(flag, node);
-    return;
-  }
   if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
     MS_LOG(INFO) << "GetNext disable mem_reuse";
     flag = kDynamicMem;
@@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
     mem_manager_->MallocReusedDynamicMem(graph);
     mem_flag = kReuseDynamicMem;
   }
-  auto &kernels = graph->execution_order();
-  for (auto &kernel : kernels) {
-    AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, kernel);
+  auto &execution_nodes = graph->execution_order();
+  std::vector<CNodePtr> compute_nodes;
+  // communication nodes first
+  for (auto &node : execution_nodes) {
+    if (AnfAlgo::IsCommunicationOp(node)) {
+      // skip if the memory is already allocated
+      AssignCommunicationNodeMem(mem_flag, node);
+    } else {
+      compute_nodes.emplace_back(node);
+    }
+  }
+  // then compute nodes
+  for (auto &node : compute_nodes) {
+    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
+    AssignWorkSpaceMem(mem_flag, node);
   }
 }
...
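`AssignDynamicMemory` (like `AssignStaticMemoryOutput` above) now makes two passes over the execution order: communication nodes get their contiguous input/output blocks first, then ordinary compute nodes are assigned. A self-contained sketch of that two-pass ordering follows; `Node` and the `assign_*` helpers are hypothetical placeholders, not the real MindSpore API:

```cpp
// Sketch: "communication nodes first" memory-assignment order.
// Node and the assign_* functions are hypothetical placeholders.
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_communication;
};

void assign_communication_node_mem(const Node &n) {
  // In the real pass this assigns contiguous input and output buffers
  // (AssignCommunicationNodeInputMem/OutputMem) for the collective.
  std::cout << "comm mem:    " << n.name << "\n";
}

void assign_compute_node_mem(const Node &n) {
  std::cout << "compute mem: " << n.name << "\n";
}

int main() {
  std::vector<Node> execution_order = {
      {"Conv2D", false}, {"AllReduce", true}, {"ReLU", false}};
  std::vector<Node> compute_nodes;
  for (const auto &node : execution_order) {  // pass 1: communication ops
    if (node.is_communication) {
      assign_communication_node_mem(node);
    } else {
      compute_nodes.push_back(node);
    }
  }
  for (const auto &node : compute_nodes) {  // pass 2: everything else
    assign_compute_node_mem(node);
  }
  return 0;
}
```

Assigning communication nodes first also explains the new `OutputAddrExist` guard in `AssignCommunicationNodeOutputMem`: when a later pass reaches an output whose address was already pre-assigned, it simply skips it.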
@@ -73,9 +73,12 @@ class KernelRuntime {
   void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
   void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
   void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
-  void UpdateCommunicationOpInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
 #endif
@@ -91,6 +94,7 @@ class KernelRuntime {
   void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
   void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
+  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
 
  protected:
   uint32_t device_id_{0};
...
@@ -90,6 +90,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
       return false;
     }
   }
+  HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
   anf_node_ = anf_node;
   return true;
 }
@@ -147,7 +148,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
   HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
       stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
-      hccl_count_, root_id_, op_type_, data_type, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
+      hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
       RuntimeUtils::HcomDistribute);
   MS_EXCEPTION_IF_NULL(task_info_ptr);
   return {task_info_ptr};
...
@@ -54,6 +54,7 @@ class HcclKernel : public AscendKernelMod {
   mutable std::vector<size_t> workspace_size_list_;
   AnfNodePtr anf_node_;
   std::string op_name_;
+  std::string group_;
 };
 using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
...
@@ -176,11 +176,22 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
   MS_EXCEPTION_IF_NULL(primitive);
   if (primitive->GetAttr("root_rank") != nullptr) {
-    *root_id = GetValue<const vector<uint32_t>>(primitive->GetAttr("root_rank"))[0];
+    *root_id = (uint32_t)GetValue<int>(primitive->GetAttr("root_rank"));
   } else {
     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
     return false;
   }
   return true;
 }
 
+void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto attr = primitive->GetAttr("group");
+  if (attr != nullptr) {
+    *group = GetValue<std::string>(attr);
+  } else {
+    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
+  }
+}
 }  // namespace mindspore
@@ -23,6 +23,7 @@
 #include <memory>
 #include "ir/dtype.h"
 #include "hccl/base.h"
+#include "utils/contract.h"
 
 namespace mindspore {
 using std::map;
@@ -61,6 +62,7 @@ class HcomUtil {
                            const vector<vector<size_t>> &shape_list, uint64_t *total_count);
   static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type);
   static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
+  static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
 };
 }  // namespace mindspore
...
@@ -66,8 +66,7 @@ const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const A
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
-  if (op_name != kAllReduceOpName && op_name != kAllGatherOpName && op_name != kReduceScatterOpName) {
+  if (!AnfAlgo::IsCommunicationOp(node)) {
     return nullptr;
   }
   return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode);
...
@@ -173,6 +173,19 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
   return VectorRef({V, Xs});
 }
 
+void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
+    auto input_size = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t i = 0; i < input_size; ++i) {
+      auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      auto input_node = input_node_with_index.first;
+      MS_EXCEPTION_IF_NULL(input_node);
+      MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
+      AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
+    }
+  }
+}
+
 const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
   if (node == nullptr || !node->isa<CNode>()) {
@@ -184,6 +197,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
   if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
     return nullptr;
   }
+  DealBroadCastAsRef(graph, cnode);
+
   auto op_name = AnfAlgo::GetCNodeName(cnode);
   auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
   if (op_info == nullptr || !op_info->is_ref()) {
...
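`DealBroadCastAsRef` registers each Broadcast input/output pair as a ref pair on the kernel graph, since Broadcast writes its result in place: output `i` aliases the memory feeding input `i`. A small illustrative sketch of that bookkeeping, using a plain map rather than the real `AddRefPairToKernelGraph` API:

```cpp
// Sketch: ref-pair bookkeeping for an in-place op, where output i of the
// node aliases input i. A plain map, not the real KernelGraph API.
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>

using NodeOutput = std::pair<std::string, size_t>;  // (node name, output index)

int main() {
  std::map<NodeOutput, NodeOutput> ref_pairs;
  const std::string broadcast = "Broadcast";
  const size_t input_size = 3;  // hypothetical input count
  for (size_t i = 0; i < input_size; ++i) {
    // Output i of Broadcast reuses the memory of the node feeding input i,
    // so downstream passes treat the two addresses as one.
    ref_pairs[{broadcast, i}] = {"input_node_" + std::to_string(i), 0};
  }
  for (const auto &kv : ref_pairs) {
    std::cout << kv.first.first << "#" << kv.first.second << " -> "
              << kv.second.first << "#" << kv.second.second << "\n";
  }
  return 0;
}
```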