提交 38e0d98e 编写于 作者: E etone-chan

refactor fusion id implement of buffer fusion

上级 67013077
......@@ -32,6 +32,7 @@
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"
#include "pre_activate/common/fusion_id_allocator.h"
namespace mindspore {
namespace opt {
......@@ -79,20 +80,6 @@ void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
}
#endif
void SetAnfNodeFusionId(const FusedNodeRecord &record_node) {
MS_LOG(DEBUG) << "Size of opt vector to be fused is " << record_node.size();
int32_t id = 1;
for (auto &record : record_node) {
MS_LOG(DEBUG) << "No" << id << ", opt vector to be fused contain " << record.size() << " opt.";
for (const auto &candidate : record) {
ValuePtr fusion_id_v = MakeValue(id);
AnfAlgo::SetNodeAttr(kOpAttrFusionId, fusion_id_v, candidate);
MS_LOG(DEBUG) << "No " << id << ": " << candidate->DebugString();
}
id++;
}
}
bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(record);
......@@ -482,11 +469,18 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
}
}
}
} // namespace
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
void BufferFusion::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
auto id = fusion_id_allocator.AllocateFusionId();
for (auto node : record) {
fusion_id_allocator.SetFusionId(node, id);
}
}
void BufferFusion::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(fused_set);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
......@@ -496,14 +490,13 @@ void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
std::unordered_set<AnfNodePtr> record{cnode, conv};
candidate_fusion->push_back(record);
fused_set->insert(record.begin(), record.end());
SetRecordFusionId(record);
}
}
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
void BufferFusion::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(fused_set);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
......@@ -520,14 +513,13 @@ void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, cons
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
candidate_fusion->push_back(record);
fused_set->insert(record.begin(), record.end());
SetRecordFusionId(record);
}
}
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(fused_set);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
......@@ -548,41 +540,37 @@ void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, c
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
candidate_fusion->push_back(record);
fused_set->insert(record.begin(), record.end());
SetRecordFusionId(record);
}
}
}
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(fused_set);
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fused_set->find(node) != fused_set->end()) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
MatchConvBnreduce(cnode, kernel_graph, fused_set, candidate_fusion);
MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
} else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) {
auto relu_input = cnode->input(1);
if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) {
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
MatchBnupdateRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
}
}
}
}
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
FusedNodeRecord *candidate_fusion) {
void BufferFusion::MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(fused_set);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto return_node = kernel_graph.get_return();
......@@ -599,7 +587,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
MS_EXCEPTION_IF_NULL(node);
todo.pop_front();
std::unordered_set<AnfNodePtr> record;
if (visited_set.find(node) != visited_set.end() || fused_set->find(node) != fused_set->end()) {
if (visited_set.find(node) != visited_set.end() || fusion_id_allocator.HasFusionIdAttr(node)) {
continue;
}
// Only fuse real cnode
......@@ -616,7 +604,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode);
if (record.size() >= MIN_PATTERN_SIZE && record.size() <= MAX_PATTERN_SIZE) {
candidate_fusion->push_back(record);
fused_set->insert(record.begin(), record.end());
SetRecordFusionId(record);
}
if (record.find(cnode) == record.end()) {
todo.push_back(cnode);
......@@ -628,7 +616,6 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
(void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
}
}
} // namespace
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
......@@ -684,7 +671,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
return change;
}
bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const {
bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
auto return_node = kernel_graph.get_return();
......@@ -694,14 +681,11 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g
}
MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
FusedNodeRecord candidate_fusion;
std::unordered_set<AnfNodePtr> fused_set;
MatchOpNamePattern(kernel_graph, &fused_set, &candidate_fusion);
MatchFusionTypePattern(kernel_graph, &fused_set, &candidate_fusion);
MatchOpNamePattern(kernel_graph, &candidate_fusion);
MatchFusionTypePattern(kernel_graph, &candidate_fusion);
if (!candidate_fusion.empty()) {
SetAnfNodeFusionId(candidate_fusion);
} else {
if (candidate_fusion.empty()) {
return false;
}
MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
......@@ -741,13 +725,14 @@ bool BufferFusion::Run(const FuncGraphPtr &graph) {
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
fusion_id_allocator.Init();
if (MatchBufferFusionPattern(*kernel_graph)) {
changed = FuseBufferFusionPattern(kernel_graph.get());
}
// clear fusion_id attr
for (auto &node : graph->nodes()) {
if (node != nullptr && node->isa<CNode>()) {
AnfAlgo::EraseNodeAttr(kOpAttrFusionId, node);
AnfAlgo::EraseNodeAttr(kAttrFusionId, node);
}
}
return changed;
......
......@@ -21,6 +21,7 @@
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
#include "pre_activate/common/fusion_id_allocator.h"
#include "device/kernel_info.h"
#include "kernel/kernel.h"
#include "session/kernel_graph.h"
......@@ -43,12 +44,24 @@ class BufferFusion : public Pass {
bool Run(const FuncGraphPtr &graph) override;
private:
void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
bool ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const;
bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const;
bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph);
bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const;
FusionIdAllocator fusion_id_allocator;
};
} // namespace opt
} // namespace mindspore
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/common/fusion_id_allocator.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
FusionIdAllocator::FusionIdAllocator() { fusion_id = 0; }
FusionIdAllocator::~FusionIdAllocator() {}
void FusionIdAllocator::Init() { fusion_id = 0; }
int32_t FusionIdAllocator::AllocateFusionId() {
fusion_id++;
return fusion_id;
}
bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { return AnfAlgo::HasNodeAttr(kAttrFusionId, node); }
int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) {
if (HasFusionIdAttr(node)) {
return AnfAlgo::GetNodeAttr<int32_t>(node, kAttrFusionId);
}
return -1;
}
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) {
ValuePtr fusion_id_v = MakeValue(id);
AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
#include "ir/base.h"
namespace mindspore {
namespace opt {
class FusionIdAllocator {
public:
FusionIdAllocator();
virtual ~FusionIdAllocator();
FusionIdAllocator(const FusionIdAllocator &in) = delete;
FusionIdAllocator &operator=(const FusionIdAllocator &in) = delete;
void Init();
int32_t AllocateFusionId();
bool HasFusionIdAttr(const AnfNodePtr &node);
int32_t GetFusionId(const AnfNodePtr &node);
void SetFusionId(const AnfNodePtr &node, int32_t id);
private:
int32_t fusion_id;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
......@@ -165,6 +165,7 @@ constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrOp = "op";
constexpr auto kAttrIsTraining = "is_training";
constexpr auto kAttrFusionId = "fusion_id";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册