提交 beb714d2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1911 add a function to charge the node input and output is a scalar

Merge pull request !1911 from lianliguang/add-a-function-to-charge-the-node-input-or-output-if-is-a-scalar
...@@ -37,11 +37,11 @@ class SupportedChecker { ...@@ -37,11 +37,11 @@ class SupportedChecker {
public: public:
SupportedChecker() = default; SupportedChecker() = default;
virtual ~SupportedChecker() = default; virtual ~SupportedChecker() = default;
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) { const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
} }
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) { const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
} }
......
...@@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph ...@@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
return nullptr; return nullptr;
} }
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) {
return node; return nullptr;
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) {
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info); auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
builder->SetKernelType(AICPU_KERNEL); builder->SetKernelType(AICPU_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
...@@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph ...@@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
<< node->DebugString() << "]"; << node->DebugString() << "]";
} }
return node; return nullptr;
} }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
...@@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod ...@@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
auto indices_const = CreateValueNode(new_cnode); auto indices_const = CreateValueNode(new_cnode);
new_cnode->add_input(indices_const); new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_); MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
MS_LOG(INFO) << "split topk failed, check to aicpu."; MS_LOG(INFO) << "split topk failed, check to aicpu.";
return nullptr; return nullptr;
} }
......
...@@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap ...@@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName); auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata), std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs); auto new_node = func_graph->NewCNode(inputs);
......
...@@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { ...@@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
} }
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
} }
bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (shape.empty()) {
return true;
}
return shape.size() == kShape1dDims && shape[0] == 1;
}
bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (shape.empty()) {
return true;
}
return shape.size() == kShape1dDims && shape[0] == 1;
}
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore
...@@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm { ...@@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm {
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static bool IsSwitchCall(const CNodePtr &call_node); static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
}; };
} // namespace session } // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm; using AnfAlgo = session::AnfRuntimeAlgorithm;
......
...@@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other"; ...@@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other";
// some size // some size
const size_t kShape4dDims = 4; const size_t kShape4dDims = 4;
const size_t kShape2dDims = 2;
const size_t kShape5dDims = 5; const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16; const size_t kCubeSize = 16;
const size_t kMemAlignSize = 512; const size_t kMemAlignSize = 512;
const int kParameterDataTensorMask = 0; const int kParameterDataTensorMask = 0;
......
...@@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker { ...@@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker {
public: public:
MockSupportedChecker() = default; MockSupportedChecker() = default;
~MockSupportedChecker() override = default; ~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true; return true;
} }
}; // namespace opt }; // namespace opt
......
...@@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker { ...@@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker {
public: public:
MockSupportedChecker() = default; MockSupportedChecker() = default;
~MockSupportedChecker() override = default; ~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true; return true;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册