diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index 66e3f2ad330fb69d8be86426581afa50440ad96f..ee0d837cee438b4d9bd1007d8f48f78a293593de 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -37,11 +37,11 @@ class SupportedChecker { public: SupportedChecker() = default; virtual ~SupportedChecker() = default; - virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, + virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); } - virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, + virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); } diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc index 5b5bf7e4fcb63c2ce23eee8c99c0824b5a52b03f..cfa4e4234241bfb1e70683499bd0422310738f0a 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph return nullptr; } auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); - if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { - return node; - } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { + if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { + return nullptr; + } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { auto builder = std::make_shared(kernel_builder_info); builder->SetKernelType(AICPU_KERNEL); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); @@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" << node->DebugString() << "]"; } - return node; + return nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 9abef8fa703d1f55adf16f70992bdc1b7c3c8391..95bcb9f210b0c27dcc1f726423f84b36a29440ca 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod auto indices_const = CreateValueNode(new_cnode); new_cnode->add_input(indices_const); MS_EXCEPTION_IF_NULL(supported_checker_); - if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { + if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { MS_LOG(INFO) << "split topk failed, check to aicpu."; return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc index 16517187032c07989fb356bfeb42d9729431926d..e45fc2637fe0cc3e1a5e92d4400a4c799f3edabc 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); auto new_fusion_transdata = std::make_shared(kTransDataOpName); - if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { + if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) { std::vector inputs = {NewValueNode(new_fusion_transdata), utils::cast((*equiv)[input_varptr_])}; auto new_node = func_graph->NewCNode(inputs); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 6cc68457e5d03a33cc2b9a6510e7b6aa7f14881b..09ea32becba9d75d7ef1b44d8d111d66506a6188 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { } MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); } + +bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} + +bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 10ae5282e0a4a7a8e798eabb39439a1b4b114cbf..bab867a3ef8f0b360c394a69dffefd16ad4f3504 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm { static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); static bool IsSwitchCall(const CNodePtr &call_node); + static bool IsScalarInput(const CNodePtr &cnode, size_t index); + static bool IsScalarOutput(const CNodePtr &cnode, size_t index); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index ff2ba05c841e1dbc1e65619c3fcb970f522cef51..b2771f4b9b7896e8ddce4000e0b4d50ed3fc7096 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other"; // some size const size_t kShape4dDims = 4; +const size_t kShape2dDims = 2; const size_t kShape5dDims = 5; +const size_t kShape1dDims = 1; const size_t kCubeSize = 16; const size_t kMemAlignSize = 512; const int kParameterDataTensorMask = 0; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 4cee3577ed4fa1213c9d8af720d21fde38b40eac..b09268aa662f44d31397c9e4045f769356aa99e6 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } }; // namespace opt diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 8bb9de7c7d47ffa5220e469d05e7a5771652741a..98dc9e9efc37ac0b6e736a670f79c9810d0dea7c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } };