提交 7d07e17f 编写于 作者: C caifubi

Support Common Hccl Op

1.Support Broadcast op
2.Support communication op as graph output
3.Optimize Communication op memory allocation
4.support hccl multi-group
上级 93f6fc0a
......@@ -9,11 +9,11 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
# for CPU/GPU mode, find c_sec and slog from local prebuild
# for CPU/GPU mode, find slog from local prebuild
if (NOT ENABLE_D)
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
find_library(slog libslog.so ${GE_PREBUILD_PATH})
elseif (DEFINED ENV{D_LINK_PATH})
set(GE_LIB_PATH $ENV{D_LINK_PATH})
......@@ -28,7 +28,6 @@ elseif (DEFINED ENV{D_LINK_PATH})
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
find_library(c_sec libc_sec.so ${GE_LIB_PATH})
find_library(slog libslog.so ${GE_LIB_PATH})
find_library(mmpa libmmpa.so ${GE_LIB_PATH})
find_library(runtime libruntime.so ${GE_LIB_PATH})
......
......@@ -153,7 +153,7 @@ if (NOT ENABLE_GE)
FILES
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libc_sec.so
${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
......
graphengine @ 579dcb75
Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
......@@ -333,8 +333,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
model_iter->second, listener);
if (!status) {
MS_LOG(ERROR) << "load task failed";
return false;
MS_LOG(EXCEPTION) << "Load Task Failed";
}
if (ProfilingManager::GetInstance().IsProfiling()) {
auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
......
......@@ -29,6 +29,7 @@ class GraphDescReporter : public DescReporter {
public:
GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
: DescReporter(device_id, file_name, std::move(cnode_list)) {}
~GraphDescReporter() override = default;
void ReportData() override;
};
} // namespace ascend
......
......@@ -60,7 +60,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
static_cast<u32>(task_info->root_id()), nullptr, stream);
static_cast<u32>(task_info->root_id()), task_info->group().c_str(), stream);
if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
return false;
......@@ -70,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
static_cast<hcclDataType_t>(task_info->data_type()), nullptr, stream);
static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
return false;
......@@ -81,7 +81,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
static_cast<hcclDataType_t>(task_info->data_type()),
static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
return false;
......@@ -93,7 +93,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
reinterpret_cast<void *>(task_info->output_data_addr()),
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
if (ret != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
return false;
......
......@@ -15,6 +15,7 @@
*/
#include "device/kernel_runtime.h"
#include <vector>
#include <utility>
#include <numeric>
#include <functional>
......@@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
mem_manager_->ResetDynamicMemory();
AssignStaticMemory(graph);
AssignDynamicMemory(graph);
UpdateRefNodeOutputMem(graph);
}
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
// assign memory for input nodes
RunOpAssignInputMemory(input_tensors, graph);
AssignStaticMemoryValueNode(graph);
for (const auto &cnode : graph->execution_order()) {
// assign memory for output nodes
RunOpAssignOutputMemory(cnode);
// assign memory for workspace
RunOpAssignWorkSpaceMemory(cnode);
}
UpdateRefNodeOutputMem(graph);
......@@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
// Assign static device memory for every real-kernel output of the graph.
// Communication (HCCL) ops are handled first via AssignCommunicationNodeMem;
// all remaining real kernels are collected and assigned afterwards.
// @param graph kernel graph whose outputs receive static memory (non-null)
void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto output_nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  // Outputs that are not communication ops are deferred to a second pass.
  std::vector<session::KernelWithIndex> deferred_outputs;
  for (const auto &output : output_nodes) {
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, true);
    const auto &real_node = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_node);
    // Only real CNode kernels get memory here; parameters/values are skipped.
    if (!real_node->isa<CNode>() || !AnfAlgo::IsRealKernel(real_node)) {
      continue;
    }
    if (AnfAlgo::IsCommunicationOp(real_node)) {
      AssignCommunicationNodeMem(kStaticMem, real_node);
    } else {
      deferred_outputs.emplace_back(kernel_with_index);
    }
  }
  // Second pass: ordinary (non-communication) output kernels.
  for (const auto &kernel_with_index : deferred_outputs) {
    AssignNodeOutputMem(kStaticMem, kernel_with_index.first, SizeToInt(kernel_with_index.second));
  }
}
......@@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
}
}
// Assign device memory for a communication (HCCL) node: first its inputs,
// then its outputs.
// @param mem_flag            memory type flag (e.g. kStaticMem, as used by
//                            AssignStaticMemoryOutput)
// @param communication_node  the communication op to assign memory for
void KernelRuntime::AssignCommunicationNodeMem(int mem_flag, const AnfNodePtr &communication_node) {
  AssignCommunicationNodeInputMem(communication_node);
  AssignCommunicationNodeOutputMem(mem_flag, communication_node);
}
void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mem_manager_);
......@@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
size_t total_size = 0;
size_t output_index = 0;
std::vector<size_t> align_size_list;
for (uint64_t mem_size : output_sizes) {
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
MS_LOG(INFO) << "communication op addr exist";
continue;
}
if (context_ptr->enable_hccl()) {
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
}
......@@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
}
}
void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
// Pre-create an output DeviceAddress for the given output index of a CNode and
// register it on the node. The address is created with a null device pointer
// (see CreateDeviceAddress(nullptr, ...)), so the caller is expected to back it
// with actual device memory later.
// @param anf_node  node whose output address is pre-assigned (must be non-null)
// @param index     output index into the node's kernel-mod output size list
// @return the newly created device address (also set on the node)
// @throws via MS_LOG(EXCEPTION) if the node has no kernel mod or the index is
//         out of range of the output size list
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  // Fix: guard against a null kernel mod before dereferencing, consistent with
  // the MS_EXCEPTION_IF_NULL checks used elsewhere in this file.
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.size() <= index) {
    MS_LOG(EXCEPTION) << "Previous node output size < node index";
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(node);
......@@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
size_t total_size = 0;
std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
MS_EXCEPTION_IF_NULL(address);
auto mem_size = address->size();
if (context_ptr->enable_hccl()) {
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
auto input_node = input_node_with_index.first;
DeviceAddressPtr address = nullptr;
if (input_node->isa<CNode>()) {
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
} else {
MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
}
MS_EXCEPTION_IF_NULL(address);
auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
total_size += mem_size;
addr_size.emplace_back(address.get(), mem_size);
}
......@@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(mem_manager_);
if (AnfAlgo::IsCommunicationOp(node)) {
UpdateCommunicationOpInputMem(node);
AssignCommunicationNodeOutputMem(flag, node);
return;
}
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
MS_LOG(INFO) << "GetNext disable mem_reuse";
flag = kDynamicMem;
......@@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
mem_manager_->MallocReusedDynamicMem(graph);
mem_flag = kReuseDynamicMem;
}
auto &kernels = graph->execution_order();
for (auto &kernel : kernels) {
AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
AssignWorkSpaceMem(mem_flag, kernel);
auto &execution_nodes = graph->execution_order();
std::vector<CNodePtr> compute_nodes;
// communication nodes first
for (auto &node : execution_nodes) {
if (AnfAlgo::IsCommunicationOp(node)) {
// skip if the memory is already allocated
AssignCommunicationNodeMem(mem_flag, node);
} else {
compute_nodes.emplace_back(node);
}
}
// then compute nodes
for (auto &node : compute_nodes) {
AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
AssignWorkSpaceMem(mem_flag, node);
}
}
......
......@@ -73,9 +73,12 @@ class KernelRuntime {
void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
void UpdateCommunicationOpInputMem(const AnfNodePtr &node);
void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
#ifdef ENABLE_DUMP_E2E
bool SetDumpConf();
#endif
......@@ -91,6 +94,7 @@ class KernelRuntime {
void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
protected:
uint32_t device_id_{0};
......
......@@ -90,6 +90,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
return false;
}
}
HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
anf_node_ = anf_node;
return true;
}
......@@ -147,7 +148,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
hccl_count_, root_id_, op_type_, data_type, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
RuntimeUtils::HcomDistribute);
MS_EXCEPTION_IF_NULL(task_info_ptr);
return {task_info_ptr};
......
......@@ -54,6 +54,7 @@ class HcclKernel : public AscendKernelMod {
mutable std::vector<size_t> workspace_size_list_;
AnfNodePtr anf_node_;
std::string op_name_;
std::string group_;
};
using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
......
......@@ -176,11 +176,22 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("root_rank") != nullptr) {
*root_id = GetValue<const vector<uint32_t>>(primitive->GetAttr("root_rank"))[0];
*root_id = (uint32_t)GetValue<int>(primitive->GetAttr("root_rank"));
} else {
MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
return false;
}
return true;
}
// Read the "group" attribute from the op's primitive into *group.
// @param anf_node  communication op node (non-null by NotNull contract)
// @param group     output: the HCCL group name
// @throws via MS_LOG(EXCEPTION) when the op carries no "group" attribute
void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  const auto group_attr = primitive->GetAttr("group");
  if (group_attr == nullptr) {
    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
  }
  *group = GetValue<std::string>(group_attr);
}
} // namespace mindspore
......@@ -23,6 +23,7 @@
#include <memory>
#include "ir/dtype.h"
#include "hccl/base.h"
#include "utils/contract.h"
namespace mindspore {
using std::map;
......@@ -61,6 +62,7 @@ class HcomUtil {
const vector<vector<size_t>> &shape_list, uint64_t *total_count);
static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type);
static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
};
} // namespace mindspore
......
......@@ -66,8 +66,7 @@ const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const A
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name != kAllReduceOpName && op_name != kAllGatherOpName && op_name != kReduceScatterOpName) {
if (!AnfAlgo::IsCommunicationOp(node)) {
return nullptr;
}
return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode);
......
......@@ -173,6 +173,19 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
return VectorRef({V, Xs});
}
void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; ++i) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
auto input_node = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(input_node);
MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
}
}
}
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
......@@ -184,6 +197,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
return nullptr;
}
DealBroadCastAsRef(graph, cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
if (op_info == nullptr || !op_info->is_ref()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册