提交 beb714d2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1911 add a function to charge the node input and output is a scalar

Merge pull request !1911 from lianliguang/add-a-function-to-charge-the-node-input-or-output-if-is-a-scalar
......@@ -37,11 +37,11 @@ class SupportedChecker {
public:
SupportedChecker() = default;
virtual ~SupportedChecker() = default;
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
}
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
}
......
......@@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
return nullptr;
}
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
return node;
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) {
return nullptr;
} else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) {
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
builder->SetKernelType(AICPU_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
......@@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
<< node->DebugString() << "]";
}
return node;
return nullptr;
}
} // namespace opt
} // namespace mindspore
......@@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
auto indices_const = CreateValueNode(new_cnode);
new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
MS_LOG(INFO) << "split topk failed, check to aicpu.";
return nullptr;
}
......
......@@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
......
......@@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
}
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}
bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (shape.empty()) {
return true;
}
return shape.size() == kShape1dDims && shape[0] == 1;
}
bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (shape.empty()) {
return true;
}
return shape.size() == kShape1dDims && shape[0] == 1;
}
} // namespace session
} // namespace mindspore
......@@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm {
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;
......
......@@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other";
// some size
const size_t kShape4dDims = 4;
const size_t kShape2dDims = 2;
const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16;
const size_t kMemAlignSize = 512;
const int kParameterDataTensorMask = 0;
......
......@@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
}; // namespace opt
......
......@@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册