提交 7e2e7ad5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1367 modify topk split pass

Merge pull request !1367 from jjfeing/master
...@@ -614,11 +614,6 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke ...@@ -614,11 +614,6 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
if (AnfAlgo::GetCNodeName(kernel_node) == kTopKOpName && AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted") == false) {
MS_LOG(INFO) << "will select aicpu topk.";
return;
}
std::string op_name = AnfAlgo::GetCNodeName(kernel_node); std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
if (op_info_ptr == nullptr) { if (op_info_ptr == nullptr) {
......
...@@ -77,6 +77,9 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) { ...@@ -77,6 +77,9 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetKernelType(TBE_KERNEL);
builder.SetFusionType(kernel::OPAQUE);
builder.SetProcessor(kernel::AICORE);
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
...@@ -129,10 +132,12 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod ...@@ -129,10 +132,12 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
new_cnode->add_input(indices_const); new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_); MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
MS_LOG(INFO) << "split topk failed, check to aicpu.";
return nullptr; return nullptr;
} }
if (kernel_graph != nullptr) { if (kernel_graph != nullptr) {
MS_LOG(INFO) << "split topk success. use tbe aicore.";
kernel_graph->AddValueNodeToGraph(indices_const); kernel_graph->AddValueNodeToGraph(indices_const);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册