diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index a0ff6e0b8aad5e7b553e2c08dc19af1bc1377865..aac9166c646ecc31660b828c8ffb45381b02f12a 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "pre_activate/ascend/ir_fission/topk_split.h" +#include #include #include #include @@ -102,6 +103,11 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod // set value node as topk's input auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); + auto input_names_vec = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); + if (input_names_vec.size() < kTopkIndexK + 1) { + MS_LOG(INFO) << "The input k of topk has been converted to attr"; + return nullptr; + } // Copy a new node to check supported. std::vector new_inputs{NewValueNode(std::make_shared(kTopKOpName))}; new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 1a3440e78064c6990412bf373e8040b7ba7badc9..1c0454a56d8e966e32042e2ba272bad8a8ee010b 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -19,6 +19,7 @@ #include "device/kernel_info.h" #include "pre_activate/pass/convert_const_input_to_attr.h" #include "debug/anf_ir_dump.h" +#include "session/anf_runtime_algorithm.h" #define private public #define protected public #include "pre_activate/ascend/ir_fission/topk_split.h" @@ -32,6 +33,21 @@ class TestHWTopKSplit : public BackendCommon { TestHWTopKSplit() : get_py_fun_("gtest_input.pre_activate.topk_split_test", true) {} ~TestHWTopKSplit() override = default; + CNodePtr GetTopkCNodeFromKernelGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto ret = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto make_tuple = ret->input(1); + MS_EXCEPTION_IF_NULL(make_tuple); + auto tuple_getitem = make_tuple->cast()->input(1); + MS_EXCEPTION_IF_NULL(tuple_getitem); + auto topk = tuple_getitem->cast()->input(1); + MS_EXCEPTION_IF_NULL(topk); + auto topk_cnode = topk->cast(); + MS_EXCEPTION_IF_NULL(topk_cnode); + return topk_cnode; + } + UT::PyFuncGraphFetcher get_py_fun_; }; @@ -39,7 +55,8 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAiCoreSupported(const AnfNodePtr &anf_node, + const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } }; // namespace opt @@ -66,14 +83,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) { optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); - auto ret = new_graph->get_return(); - EXPECT_NE(ret, nullptr); - auto make_tuple = ret->input(1); - EXPECT_NE(make_tuple, nullptr); - auto tuple_getitem = make_tuple->cast()->input(1); - EXPECT_NE(tuple_getitem, nullptr); - auto topk = tuple_getitem->cast()->input(1); - auto topk_cnode = topk->cast(); + auto topk_cnode = GetTopkCNodeFromKernelGraph(new_graph); EXPECT_EQ(topk_cnode->inputs().size(), 3); EXPECT_TRUE(topk_cnode->input(2)->isa()); auto value_node = topk_cnode->input(2)->cast(); @@ -82,5 +92,39 @@ TEST_F(TestHWTopKSplit, test_topk_split) { EXPECT_EQ(tensor->shape().size(), 1); EXPECT_EQ(tensor->shape()[0], 4); } + +TEST_F(TestHWTopKSplit, test_topk_no_split) { + /* + * def before(input): + * topk = TopKSplit(input) + * output = tuple_getitem(topk, 0) + * return output + */ + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before"); + std::vector shp{4, 4}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kernel_graph = GetKernelGraph(g, args_spec_list); + + CNodePtr topk_cnode = GetTopkCNodeFromKernelGraph(kernel_graph); + EXPECT_EQ(topk_cnode->inputs().size(), 3); + auto input_names_vec = AnfAlgo::GetNodeAttr>(topk_cnode, kAttrInputNames); + EXPECT_EQ(input_names_vec.size(), 2); + std::unordered_set attr_index{1}; + ConstInputToAttr(topk_cnode, attr_index); + EXPECT_EQ(topk_cnode->inputs().size(), 2); + input_names_vec = AnfAlgo::GetNodeAttr>(topk_cnode, kAttrInputNames); + EXPECT_EQ(input_names_vec.size(), 1); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + auto topk_split = std::make_shared(); + topk_split->supported_checker_ = std::make_shared(); + pm->AddPass(topk_split); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); + EXPECT_EQ(topk_cnode, GetTopkCNodeFromKernelGraph(new_graph)); +} } // namespace opt } // namespace mindspore