diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index 1cf1b85e9b03515e75a9f6c28ddc8db2667c2739..c809618b339851c18cab672eeb58d2d91b622f81 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -614,11 +614,6 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> parse_info_list; - if (AnfAlgo::GetCNodeName(kernel_node) == kTopKOpName && AnfAlgo::GetNodeAttr(kernel_node, "sorted") == false) { - MS_LOG(INFO) << "will select aicpu topk."; - return; - } - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); if (op_info_ptr == nullptr) { diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 64f5ba0cf6c6782f8d2ecc4c49e6cd702a82fd40..a0ff6e0b8aad5e7b553e2c08dc19af1bc1377865 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -77,6 +77,9 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) { kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetKernelType(TBE_KERNEL); + builder.SetFusionType(kernel::OPAQUE); + builder.SetProcessor(kernel::AICORE); builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); @@ -129,10 +132,12 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod new_cnode->add_input(indices_const); MS_EXCEPTION_IF_NULL(supported_checker_); if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { + MS_LOG(INFO) << "split topk failed, check to aicpu."; return nullptr; } if (kernel_graph != nullptr) { + MS_LOG(INFO) << "split topk success. use tbe aicore."; kernel_graph->AddValueNodeToGraph(indices_const); }