From 7d07e17f5ad261d0995d02f1d7978237fb1e5898 Mon Sep 17 00:00:00 2001
From: caifubi
Date: Thu, 14 May 2020 14:08:26 +0800
Subject: [PATCH] Support common HCCL ops

1. Support Broadcast op
2. Support communication op as graph output
3. Optimize communication op memory allocation
4. Support HCCL multi-group
---
 cmake/dependency_graphengine.cmake            |  5 +-
 cmake/package.cmake                           |  2 +-
 graphengine                                   |  2 +-
 .../device/ascend/ascend_kernel_runtime.cc    |  3 +-
 .../profiling/reporter/graph_desc_reporter.h  |  1 +
 .../device/ascend/tasksink/runtime_utils.cc   |  8 +-
 mindspore/ccsrc/device/kernel_runtime.cc      | 80 ++++++++++++++-----
 mindspore/ccsrc/device/kernel_runtime.h       |  8 +-
 mindspore/ccsrc/kernel/hccl/hccl_kernel.cc    |  3 +-
 mindspore/ccsrc/kernel/hccl/hccl_kernel.h     |  1 +
 mindspore/ccsrc/kernel/hccl/hcom_util.cc      | 13 ++-
 mindspore/ccsrc/kernel/hccl/hcom_util.h       |  2 +
 .../ascend/enhancer/add_memcpy_async.cc       |  3 +-
 .../format_type/deal_ref_trans_and_cast.cc    | 16 ++++
 14 files changed, 111 insertions(+), 36 deletions(-)

diff --git a/cmake/dependency_graphengine.cmake b/cmake/dependency_graphengine.cmake
index 533f9f824..991eb2a24 100644
--- a/cmake/dependency_graphengine.cmake
+++ b/cmake/dependency_graphengine.cmake
@@ -9,11 +9,11 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
 include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
+include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
 
-# for CPU/GPU mode, find c_sec and slog from local prebuild
+# for CPU/GPU mode, find slog from local prebuild
 if (NOT ENABLE_D)
     set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
-    find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
     find_library(slog libslog.so ${GE_PREBUILD_PATH})
 elseif (DEFINED ENV{D_LINK_PATH})
     set(GE_LIB_PATH $ENV{D_LINK_PATH})
@@ -28,7 +28,6 @@ elseif (DEFINED ENV{D_LINK_PATH})
         message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
     endif()
     set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
-    find_library(c_sec libc_sec.so ${GE_LIB_PATH})
     find_library(slog libslog.so ${GE_LIB_PATH})
     find_library(mmpa libmmpa.so ${GE_LIB_PATH})
     find_library(runtime libruntime.so ${GE_LIB_PATH})
diff --git a/cmake/package.cmake b/cmake/package.cmake
index 338ece1f4..875ba5217 100644
--- a/cmake/package.cmake
+++ b/cmake/package.cmake
@@ -153,7 +153,7 @@ if (NOT ENABLE_GE)
         FILES
             ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
             ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
-            ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libc_sec.so
+            ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
         DESTINATION ${INSTALL_LIB_DIR}
         COMPONENT mindspore
     )
diff --git a/graphengine b/graphengine
index 995b6dadc..579dcb75a 160000
--- a/graphengine
+++ b/graphengine
@@ -1 +1 @@
-Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
+Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
index d78c2f920..fa5fb6d67 100644
--- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
@@ -333,8 +333,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
                                                                            model_iter->second, listener);
   if (!status) {
-    MS_LOG(ERROR) << "load task failed";
-    return false;
+    MS_LOG(EXCEPTION) << "Load task failed";
   }
   if (ProfilingManager::GetInstance().IsProfiling()) {
     auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h
index 0365472ae..3c48a90ef 100644
--- a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h
+++ b/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h
@@ -29,6 +29,7 @@ class GraphDescReporter : public DescReporter {
  public:
   GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
       : DescReporter(device_id, file_name, std::move(cnode_list)) {}
+  ~GraphDescReporter() override = default;
   void ReportData() override;
 };
 }  // namespace ascend
diff --git a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc b/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc
index 7a95ddc84..20084c092 100644
--- a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc
+++ b/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc
@@ -60,7 +60,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) {
     const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
     ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                          static_cast<uint64_t>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
-                         static_cast<uint32_t>(task_info->root_id()), nullptr, stream);
+                         static_cast<uint32_t>(task_info->root_id()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
       return false;
@@ -70,7 +70,7 @@
     const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
     ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                           reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<uint64_t>(task_info->count()),
-                          static_cast<hcclDataType_t>(task_info->data_type()), nullptr, stream);
+                          static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
       return false;
@@ -81,7 +81,7 @@
     ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                           reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<uint64_t>(task_info->count()),
                           static_cast<hcclDataType_t>(task_info->data_type()),
-                          static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                          static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
       return false;
@@ -93,7 +93,7 @@
     ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
                               reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<uint64_t>(task_info->count()),
                               static_cast<hcclDataType_t>(task_info->data_type()),
-                              static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
+                              static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
     if (ret != HCCL_SUCCESS) {
       MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
       return false;
diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc
index 8daf309cb..42f56af8d 100644
--- a/mindspore/ccsrc/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/device/kernel_runtime.cc
@@ -15,6 +15,7 @@
  */
 
 #include "device/kernel_runtime.h"
+#include
 #include
 #include
 #include
@@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
   mem_manager_->ResetDynamicMemory();
   AssignStaticMemory(graph);
   AssignDynamicMemory(graph);
-  UpdateRefNodeOutputMem(graph);
 }
 
 void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  // assign memory for input nodes
   RunOpAssignInputMemory(input_tensors, graph);
   AssignStaticMemoryValueNode(graph);
   for (const auto &cnode : graph->execution_order()) {
-    // assign memory for output nodes
     RunOpAssignOutputMemory(cnode);
-    // assign memory for workspace
     RunOpAssignWorkSpaceMemory(cnode);
   }
   UpdateRefNodeOutputMem(graph);
 }
@@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
 void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
+  std::vector<session::KernelWithIndex> non_communication_op;
+  // Assign communication op memory first.
   for (const auto &node : nodes) {
     auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
     MS_EXCEPTION_IF_NULL(item_with_index.first);
     if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
       continue;
     }
+    if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
+      AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
+    } else {
+      non_communication_op.emplace_back(item_with_index);
+    }
+  }
+
+  for (const auto &item_with_index : non_communication_op) {
     AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
   }
 }
@@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
   }
 }
 
+void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
+  AssignCommunicationNodeInputMem(node);
+  AssignCommunicationNodeOutputMem(flag, node);
+}
+
 void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
@@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   size_t total_size = 0;
+  size_t output_index = 0;
   std::vector<size_t> align_size_list;
   for (uint64_t mem_size : output_sizes) {
+    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
+      MS_LOG(INFO) << "Communication op output address already exists";
+      continue;
+    }
     if (context_ptr->enable_hccl()) {
       mem_size = mem_manager_->GetCommonAlignSize(mem_size);
     }
@@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
   }
 }
 
-void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
+DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
+  auto output_sizes = kernel_mod->GetOutputSizeList();
+  if (output_sizes.size() <= index) {
+    MS_LOG(EXCEPTION) << "Previous node output size <= node index";
+  }
+  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
+  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
+  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
+  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
+  return address;
+}
+
+void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);
@@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
   size_t total_size = 0;
   std::vector<std::pair<DeviceAddress *, size_t>> addr_size;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
-    auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
-    MS_EXCEPTION_IF_NULL(address);
-    auto mem_size = address->size();
-    if (context_ptr->enable_hccl()) {
-      mem_size = mem_manager_->GetCommonAlignSize(mem_size);
+    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
+    auto input_node = input_node_with_index.first;
+    DeviceAddressPtr address = nullptr;
+    if (input_node->isa<CNode>()) {
+      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
+    } else {
+      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
     }
+    MS_EXCEPTION_IF_NULL(address);
+    auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
@@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
 void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(mem_manager_);
-  if (AnfAlgo::IsCommunicationOp(node)) {
-    UpdateCommunicationOpInputMem(node);
-    AssignCommunicationNodeOutputMem(flag, node);
-    return;
-  }
   if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
     MS_LOG(INFO) << "GetNext disable mem_reuse";
     flag = kDynamicMem;
@@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
     mem_manager_->MallocReusedDynamicMem(graph);
     mem_flag = kReuseDynamicMem;
   }
-  auto &kernels = graph->execution_order();
-  for (auto &kernel : kernels) {
-    AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, kernel);
+  auto &execution_nodes = graph->execution_order();
+  std::vector<CNodePtr> compute_nodes;
+  // Assign communication nodes first.
+  for (auto &node : execution_nodes) {
+    if (AnfAlgo::IsCommunicationOp(node)) {
+      // skip the outputs whose memory is already allocated
+      AssignCommunicationNodeMem(mem_flag, node);
+    } else {
+      compute_nodes.emplace_back(node);
+    }
+  }
+
+  // Then assign the remaining compute nodes.
+  for (auto &node : compute_nodes) {
+    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
+    AssignWorkSpaceMem(mem_flag, node);
   }
 }
diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h
index b15cb31e1..bf44698b8 100644
--- a/mindspore/ccsrc/device/kernel_runtime.h
+++ b/mindspore/ccsrc/device/kernel_runtime.h
@@ -73,9 +73,12 @@ class KernelRuntime {
   void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
   void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
   void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
+
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
-  void UpdateCommunicationOpInputMem(const AnfNodePtr &node);
+
+  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();
 #endif
@@ -91,6 +94,7 @@ class KernelRuntime {
   void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
   void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
+  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
 
  protected:
   uint32_t device_id_{0};
diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc
index 5421b301a..493998c16 100644
--- a/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc
+++ b/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc
@@ -90,6 +90,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
       return false;
     }
   }
+  HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
   anf_node_ = anf_node;
   return true;
 }
@@ -147,7 +148,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
 
   HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
     stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
-    hccl_count_, root_id_, op_type_, data_type, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
+    hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
     RuntimeUtils::HcomDistribute);
   MS_EXCEPTION_IF_NULL(task_info_ptr);
   return {task_info_ptr};
diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel.h
index b6c1fcfff..72e202591 100644
--- a/mindspore/ccsrc/kernel/hccl/hccl_kernel.h
+++ b/mindspore/ccsrc/kernel/hccl/hccl_kernel.h
@@ -54,6 +54,7 @@ class HcclKernel : public AscendKernelMod {
   mutable std::vector<size_t> workspace_size_list_;
   AnfNodePtr anf_node_;
   std::string op_name_;
+  std::string group_;
 };
 
 using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.cc b/mindspore/ccsrc/kernel/hccl/hcom_util.cc
index 5665475c8..f2d35878d 100644
--- a/mindspore/ccsrc/kernel/hccl/hcom_util.cc
+++ b/mindspore/ccsrc/kernel/hccl/hcom_util.cc
@@ -176,11 +176,22 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
   MS_EXCEPTION_IF_NULL(primitive);
   if (primitive->GetAttr("root_rank") != nullptr) {
-    *root_id = GetValue<std::vector<int>>(primitive->GetAttr("root_rank"))[0];
+    *root_id = (uint32_t)GetValue<int>(primitive->GetAttr("root_rank"));
   } else {
     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
     return false;
   }
   return true;
 }
+
+void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto attr = primitive->GetAttr("group");
+  if (attr != nullptr) {
+    *group = GetValue<std::string>(attr);
+  } else {
+    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op: " << anf_node->fullname_with_scope() << " failed";
+  }
+}
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.h b/mindspore/ccsrc/kernel/hccl/hcom_util.h
index e0524b806..dc9596cf5 100644
--- a/mindspore/ccsrc/kernel/hccl/hcom_util.h
+++ b/mindspore/ccsrc/kernel/hccl/hcom_util.h
@@ -23,6 +23,7 @@
 #include
 #include "ir/dtype.h"
 #include "hccl/base.h"
+#include "utils/contract.h"
 
 namespace mindspore {
 using std::map;
@@ -61,6 +62,7 @@ class HcomUtil {
                            const vector<vector<size_t>> &shape_list, uint64_t *total_count);
   static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type);
   static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
+  static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
 };
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc
index 446a67d2f..62316f4f6 100644
--- a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc
@@ -66,8 +66,7 @@ const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
-  if (op_name != kAllReduceOpName && op_name != kAllGatherOpName && op_name != kReduceScatterOpName) {
+  if (!AnfAlgo::IsCommunicationOp(node)) {
     return nullptr;
   }
   return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode);
diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc
index 83a44029a..a9196c5c4 100644
--- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc
@@ -173,6 +173,19 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
   return VectorRef({V, Xs});
 }
 
+void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
+    auto input_size = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t i = 0; i < input_size; ++i) {
+      auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      auto input_node = input_node_with_index.first;
+      MS_EXCEPTION_IF_NULL(input_node);
+      MS_LOG(INFO) << "origin node: " << input_node->fullname_with_scope();
+      AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
+    }
+  }
+}
+
 const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
   if (node == nullptr || !node->isa<CNode>()) {
@@ -184,6 +197,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
   if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
     return nullptr;
   }
+
+  DealBroadCastAsRef(graph, cnode);
+
   auto op_name = AnfAlgo::GetCNodeName(cnode);
   auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
   if (op_info == nullptr || !op_info->is_ref()) {
--
GitLab
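
Note on the multi-group plumbing: HcclKernel::Init reads the op's "group" attribute through the new HcomUtil::GetHcomGroup, GenTask carries it in the HcclTaskInfo, and RuntimeUtils::HcomDistribute hands task_info->group().c_str() to the hcom_* calls where nullptr (the default communication domain) used to be. The standalone sketch below condenses that path; MockPrimitive, GetHcomGroup's exception style, and DistributeBroadcast are illustrative stand-ins, not MindSpore or HCCL APIs.

// Sketch of the group-propagation path added by this patch (mock types only).
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Stand-in for the primitive attribute map; the real op carries a "group"
// string attribute set when the communication group is created.
struct MockPrimitive {
  std::map<std::string, std::string> attrs;
  const std::string *GetAttr(const std::string &key) const {
    auto it = attrs.find(key);
    return it == attrs.end() ? nullptr : &it->second;
  }
};

// Mirrors HcomUtil::GetHcomGroup: the attribute is mandatory, so a missing
// "group" is a hard error rather than a silent fallback to the world group.
std::string GetHcomGroup(const MockPrimitive &prim) {
  const std::string *attr = prim.GetAttr("group");
  if (attr == nullptr) {
    throw std::runtime_error("Get Hcom Group Attr failed");
  }
  return *attr;
}

// Mirrors the RuntimeUtils change: the stored group name, not nullptr,
// is what reaches hcom_broadcast/hcom_all_reduce/... for the op.
void DistributeBroadcast(const std::string &group) {
  std::cout << "hcom_broadcast(tag, addr, count, dtype, root, \"" << group << "\", stream)\n";
}

int main() {
  MockPrimitive prim;
  prim.attrs["group"] = "hccl_world_group";
  DistributeBroadcast(GetHcomGroup(prim));
  return 0;
}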
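The memory changes in kernel_runtime.cc share one shape: AssignDynamicMemory and AssignStaticMemoryOutput both walk the graph twice, letting communication nodes claim aligned, contiguous blocks first (inputs and outputs together, via AssignCommunicationNodeMem) and deferring every other node to a second pass. Below is a minimal ordering sketch under stated assumptions: a made-up Node type and a 512-byte alignment standing in for the memory manager's GetCommonAlignSize.

// Two-pass assignment order: communication nodes first, compute nodes after.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_communication_op;
  std::vector<size_t> output_sizes;
};

constexpr size_t kAlign = 512;  // assumed alignment, not MindSpore's constant
size_t AlignUp(size_t n) { return (n + kAlign - 1) / kAlign * kAlign; }

int main() {
  std::vector<Node> execution_order = {
      {"Conv2D", false, {4096}},
      {"AllReduce", true, {1000, 3000}},
      {"ReLU", false, {4096}},
  };

  size_t offset = 0;
  std::vector<const Node *> deferred;

  // Pass 1: communication nodes, their outputs packed back to back so one
  // HCCL task can treat them as a single contiguous region.
  for (const auto &node : execution_order) {
    if (!node.is_communication_op) {
      deferred.push_back(&node);
      continue;
    }
    for (size_t size : node.output_sizes) {
      std::cout << node.name << " output at offset " << offset << "\n";
      offset += AlignUp(size);
    }
  }

  // Pass 2: everything else.
  for (const Node *node : deferred) {
    for (size_t size : node->output_sizes) {
      std::cout << node->name << " output at offset " << offset << "\n";
      offset += AlignUp(size);
    }
  }
  return 0;
}

Running communication nodes ahead of the reuse-based pass is also why AssignCommunicationNodeOutputMem gains the OutputAddrExist guard: a node may already have received its slot in an earlier pass when it is visited again, for example as a graph output.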