提交 23a586b6 编写于 作者: E Etone.Chan

set RefInfo of Buffer Fusion kernel

上级 2e2e7a28
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include <iterator>
#include "kernel/kernel_fusion.h" #include "kernel/kernel_fusion.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
...@@ -461,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, ...@@ -461,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
} }
} }
// Propagates ref-output records from the fused-away output nodes onto the new fusion kernel.
//
// For every entry `outputs_list[idx]` whose (node, output index) pair is present in the kernel
// graph's ref-output map, the origin pair it corresponds to is looked up and re-registered
// under (fusion_kernel, idx).
//
// Parameters:
//   kernel_graph  - graph holding the ref-output map; must be non-null and must have a manager.
//   outputs_list  - outputs of the fusion scope; position `idx` in this list is used as the
//                   output index of `fusion_kernel`. No entry may be null.
//   fusion_kernel - node that replaces the fused nodes; must be non-null.
//
// Raises (through MS_EXCEPTION_IF_NULL) when any of the pointers above is null.
void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
                         const AnfNodePtr &fusion_kernel) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Reject a null fusion kernel up front so it is never stored as a key in the ref map.
  MS_EXCEPTION_IF_NULL(fusion_kernel);
  // The manager itself is not used below; the check is kept so that a graph without a manager
  // is still rejected exactly as before.
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
    // Bind by reference: copying the shared pointer on every iteration is unnecessary.
    const auto &output = outputs_list[idx];
    // `output` is dereferenced right below, so fail with a diagnostic rather than crash.
    MS_EXCEPTION_IF_NULL(output);

    // Default case: the node itself with output index 0.
    session::AnfWithOutIndex out_pair(output, 0);
    if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
      // TupleGetItem case: key the lookup on the node returned by VisitKernel and on the index
      // stored in the TupleGetItem's input(2).
      // NOTE(review): presumably VisitKernel resolves to the real kernel behind the
      // TupleGetItem - confirm against AnfAlgo.
      auto real_output = AnfAlgo::VisitKernel(output, 0);
      auto output_cnode = output->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(output_cnode);
      auto input2 = output_cnode->input(2);
      MS_EXCEPTION_IF_NULL(input2);
      auto output_idx = GetValue<int>(GetValueNode(input2));
      out_pair = session::AnfWithOutIndex(real_output.first, output_idx);
    }

    if (!kernel_graph->IsInRefOutputMap(out_pair)) {
      continue;
    }
    // Re-register the origin pair under the fusion kernel's output at the same list position.
    auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
    session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
    kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
  }
}
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) { std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
...@@ -708,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_ ...@@ -708,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_
} }
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
// replace node SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion);
ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册