diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index bdefc9bf7c2271bbc91e722f5a6df2ece76c45d5..5824f46e302d98e76084748ccd0a0bf86d4e2a58 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -62,7 +62,6 @@ #include "pre_activate/pass/common_subexpression_elimination.h" #include "pre_activate/ascend/format_type/merge_cast_to_op.h" #include "pre_activate/ascend/format_type/check_consistency.h" -#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" #include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" #include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" @@ -314,14 +313,14 @@ void AscendBackendOptimization(const std::shared_ptr &kern optimizer->AddPassManager(other_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); + // buffer fusion + AscendBackendUBFusionOptimization(kernel_graph); if (save_graphs) { std::string file_path = save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; DumpIR(file_path, kernel_graph, true); DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); } - // buffer fusion - AscendBackendUBFusionOptimization(kernel_graph); } void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph) { diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc deleted file mode 100644 index e44e3dff5d02e40e1914938dde40c2f4c4fd7d06..0000000000000000000000000000000000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc +++ /dev/null @@ -1,800 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -namespace { -const int8_t MAX_PATTERN_SIZE = 7; -const int8_t MIN_PATTERN_SIZE = 2; -const int8_t ELTWISE_INPUT_SIZE = 2; -const int8_t ELTWISE_USE = 1; -const int8_t MULTI_ELTWISE_USE = 2; -const int8_t MAX_MULTI_ELTWISE_SIZE = 4; -const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; -constexpr auto kOpAttrFusionId = "fusion_id"; - -#ifdef DEBUG -std::string GetFusionTypeName(const kernel::FusionType &type) { - switch (type) { - case kernel::FusionType::COMMREDUCE: - return "COMMREDUCE"; - case kernel::FusionType::SEGMENT: - return "SEGMENT"; - case kernel::FusionType::ELEMWISE: - return "ELEMWISE"; - case kernel::FusionType::CONVLUTION: - return "CONVLUTION"; - case kernel::FusionType::OPAQUE: - return "OPAQUE"; - default: - return "OPAQUE"; - } -} - -void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { - MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id; - for (auto &node : info.input_nodes) { - MS_LOG(INFO) << "=== Input: " << node->DebugString(); - } - for (auto &node : info.output_nodes) { - MS_LOG(INFO) << "=== Output: " << node->DebugString(); - } - for (auto &node : info.compute_nodes) { - MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-(" << GetFusionTypeName(AnfAlgo::GetFusionType(node)) - << ")"; - } - MS_LOG(INFO) << "=== Dump FusionScopeInfo end"; -} -#endif - -bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set *record, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(record); - auto user_nodes = manager->node_users()[node]; - return (AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && - (user_nodes.size() <= ELTWISE_USE || record->size() == 0)); -} - -// Common method to check for predecessors and successors in a fusion pattern -std::tuple FindPredAndSuccEltWiseNodes(const int8_t &max_size, FuncGraphManager *manager, - std::unordered_set *visited_set, - std::deque *todo, - std::unordered_set *record, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(visited_set); - MS_EXCEPTION_IF_NULL(todo); - MS_EXCEPTION_IF_NULL(record); - MS_EXCEPTION_IF_NULL(node); - - CNodePtr new_node = node; - if (new_node->inputs().size() < ELTWISE_INPUT_SIZE) { - return std::make_tuple(false, new_node); - } - int8_t index = 1; - auto &users = manager->node_users(); - while (CheckEltWiseNode(manager, record, new_node)) { - (void)record->insert(new_node); - (void)visited_set->insert(new_node); - (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end()); - - auto cnode = new_node->input(1); - MS_EXCEPTION_IF_NULL(cnode); - if (!cnode->isa()) { - return std::make_tuple(false, new_node); - } - new_node = cnode->cast(); - MS_EXCEPTION_IF_NULL(new_node); - - if (!AnfAlgo::IsRealKernel(new_node) || new_node->inputs().size() < ELTWISE_INPUT_SIZE || - users[(new_node)].size() >= MULTI_ELTWISE_USE || visited_set->find(new_node) != visited_set->end()) { - return std::make_tuple(false, new_node); - } - - if (index >= max_size) { - break; - } - index++; - } - return std::make_tuple(true, new_node); -} - -std::tuple MatchGeneralPattern(FuncGraphManager *manager, std::unordered_set *record, - std::unordered_set *visited_set, - std::deque *todo, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(record); - MS_EXCEPTION_IF_NULL(visited_set); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(todo); - CNodePtr new_node = node; - auto &users = manager->node_users(); - if (users[(new_node)].size() >= MULTI_ELTWISE_USE) { - return std::make_tuple(false, new_node); - } - - (void)record->insert(node); - (void)visited_set->insert(node); - (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end()); - - if (node->inputs().size() < 2) { - return std::make_tuple(false, new_node); - } - // only check the first real input, will check all - auto cnode = node->input(1); - MS_EXCEPTION_IF_NULL(cnode); - if (!cnode->isa()) { - return std::make_tuple(false, new_node); - } - new_node = cnode->cast(); - MS_EXCEPTION_IF_NULL(new_node); - - if (!AnfAlgo::IsRealKernel(new_node) || users[(new_node)].size() >= MULTI_ELTWISE_USE || - visited_set->find(new_node) != visited_set->end()) { - return std::make_tuple(false, new_node); - } - return std::make_tuple(true, new_node); -} - -CNodePtr FindFusionAnfNode(FuncGraphManager *manager, std::unordered_set *visited_set, - std::unordered_set *record, std::deque *todo, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(visited_set); - MS_EXCEPTION_IF_NULL(record); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(todo); - // find fusion pattern predecessor nodes - auto ret = FindPredAndSuccEltWiseNodes(MAX_MULTI_ELTWISE_SIZE, manager, visited_set, todo, record, node); - auto new_node = std::get<1>(ret); - auto node_use_size = manager->node_users()[new_node].size(); - if (!std::get<0>(ret) || (record->size() > 1 && node_use_size > 1) || record->size() >= MAX_MULTI_ELTWISE_SIZE || - AnfAlgo::GetKernelType(new_node) != KernelType::TBE_KERNEL) { - return new_node; - } - - // key of fusion precessor - auto node_fusion_type = AnfAlgo::GetFusionType(new_node); - switch (node_fusion_type) { - case kernel::FusionType::COMMREDUCE: - case kernel::FusionType::SEGMENT: - ret = MatchGeneralPattern(manager, record, visited_set, todo, new_node); - new_node = std::get<1>(ret); - if (!std::get<0>(ret)) { - return new_node; - } - break; - case kernel::FusionType::ELEMWISE: - return new_node; - // -fallthrough to default and return - case kernel::FusionType::CONVLUTION: - (void)record->insert(new_node); - default: - (void)visited_set->insert(new_node); - if (new_node != nullptr) { - (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end()); - } - return new_node; - } - // find fusion pattern successor nodes - ret = FindPredAndSuccEltWiseNodes(MAX_PURE_BUFFER_SUCC_SIZE, manager, visited_set, todo, record, new_node); - return std::get<1>(ret); -} - -CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::vector &outputs_list, - const std::vector &anf_nodes, session::KernelGraph *kernel_graph) { - MS_LOG(DEBUG) << "Start Create FusionOp Kernel"; - MS_EXCEPTION_IF_NULL(kernel_graph); - std::string fusion_op_name = "FusionOp"; - for (auto node : anf_nodes) { - fusion_op_name += '_' + AnfAlgo::GetCNodeName(node); - } - auto fusion_op = std::make_shared(fusion_op_name); - MS_EXCEPTION_IF_NULL(fusion_op); - - std::vector input_names; - for (uint8_t i = 0; i < inputs_list.size(); i++) { - input_names.emplace_back("input" + std::to_string(i)); - } - std::vector output_names; - for (uint8_t i = 0; i < outputs_list.size(); i++) { - output_names.emplace_back("output" + std::to_string(i)); - } - - ValuePtr input_names_v = MakeValue(input_names); - ValuePtr output_names_v = MakeValue(output_names); - fusion_op->set_attr("input_names", input_names_v); - fusion_op->set_attr("output_names", output_names_v); - std::vector fusion_inputs_list = inputs_list; - auto value_node = std::make_shared(fusion_op); - (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); - auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list); - if (buffer_fusion_kernel == nullptr) { - MS_LOG(EXCEPTION) << "New FusionOp kernel failed!"; - } - buffer_fusion_kernel->set_scope((anf_nodes.back())->scope()); - - return buffer_fusion_kernel; -} - -kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, - const std::vector &outputs_list) { - MS_LOG(DEBUG) << "Start Create Kernel Info"; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - // inputs format and data type - std::vector inputs_format; - std::vector inputs_data_type; - for (const auto &input : inputs_list) { - auto real_input = AnfAlgo::VisitKernel(input, 0); - inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); - inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); - } - // outputs format and data type - std::vector outputs_format; - std::vector outputs_data_type; - for (const auto &output : outputs_list) { - if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { - auto tuple_getitem = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - outputs_format.push_back(AnfAlgo::GetOutputFormat( - tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( - tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - } else { - outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); - } - } - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_data_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_data_type); - builder.SetKernelType(KernelType::TBE_KERNEL); - return builder.Build(); -} - -AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph, - size_t output_index) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector tuple_getitem_inputs_list; - auto value = std::make_shared(prim::kPrimTupleGetItem); - MS_EXCEPTION_IF_NULL(value); - auto idx = NewValueNode(SizeToInt(output_index)); - MS_EXCEPTION_IF_NULL(idx); - int temp = SizeToInt(output_index); - auto imm = std::make_shared(temp); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - tuple_getitem_inputs_list.push_back(value); - tuple_getitem_inputs_list.push_back(buffer_fusion_kernel); - tuple_getitem_inputs_list.push_back(idx); - auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list); - MS_EXCEPTION_IF_NULL(tuple_item); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, - {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)}, - tuple_item.get()); - return tuple_item; -} - -void ReplaceInputNodeInOtherFusionScope(std::unordered_map *buffer_fusion_infos, - int32_t fusion_id, const AnfNodePtr &output_item, - const AnfNodePtr &replace_item) { - for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { - auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), - output_item); - if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { - MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; - *itr = replace_item; - } - } -} - -void ReplaceOldNode(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, - const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; - if (buffer_fusion_info.outputs_list.size() == 1) { // single output - (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); - ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], - buffer_fusion_kernel); - } else { // multiple output - for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) { - auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); - (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); - ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], - tuple_item); - } - } -} - -void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto nodes = TopoSort(kernel_graph->get_return()); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { - auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); - (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); - } - } -} - -void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; - for (const auto &node : fusion_info.anf_nodes) { - auto cnode = node->cast(); - for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { - auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == - fusion_info.anf_nodes.end()) { - if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), - (*buffer_fusion_infos)[fusion_id].inputs_list.end(), - cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { - (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); - } - } - } - } - } -} - -bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { - MS_EXCEPTION_IF_NULL(node1); - MS_EXCEPTION_IF_NULL(node2); - auto getitem1 = node1->cast(); - auto getitem2 = node2->cast(); - MS_EXCEPTION_IF_NULL(getitem1); - MS_EXCEPTION_IF_NULL(getitem2); - auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); - auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); - return output_idx1 < output_idx2; -} - -void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; - for (const auto &node : fusion_info.anf_nodes) { - if (AnfAlgo::GetOutputTensorNum(node) == 1) { - for (auto use_node : manager->node_users()[node]) { - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == - fusion_info.anf_nodes.end()) { - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); - break; - } - } - } else { - int prev_idx = 0; - std::vector tuple_getitem_nodes; - std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), - std::back_inserter(tuple_getitem_nodes), - [](const std::pair &use_node) { return use_node.first; }); - std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); - for (auto getitem : tuple_getitem_nodes) { - auto getitem_ptr = getitem->cast(); - auto input2 = getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { - auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); - } - prev_idx = output_idx + 1; - for (auto item_use_node : manager->node_users()[getitem]) { - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == - fusion_info.anf_nodes.end()) { - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); - break; - } - } - } - } - } - } -} - -void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, - const AnfNodePtr &fusion_kernel) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (size_t idx = 0; idx < outputs_list.size(); ++idx) { - auto output = outputs_list[idx]; - if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { - auto real_output = AnfAlgo::VisitKernel(output, 0); - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - auto input2 = output_cnode->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - session::AnfWithOutIndex out_pair(real_output.first, output_idx); - if (kernel_graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); - session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); - kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); - } - } else { - session::AnfWithOutIndex out_pair(output, 0); - if (kernel_graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); - session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); - kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); - } - } - } -} -} // namespace - -void BufferFusion::SetRecordFusionId(const std::unordered_set &record) { - auto id = fusion_id_allocator.AllocateFusionId(); - for (auto node : record) { - fusion_id_allocator.SetFusionId(node, id); - } -} - -void BufferFusion::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - auto conv = cnode->input(1); - if (conv->isa() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { - std::vector output_used_num{SizeToInt(manager->node_users()[conv].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); - std::unordered_set record{cnode, conv}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void BufferFusion::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - auto getitem = relu_input->cast(); - auto bnupdate = getitem->input(1); - if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { - std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); - for (auto out_getitem : manager->node_users()[bnupdate]) { - auto out_getitem_ptr = out_getitem.first->cast(); - auto input2 = out_getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); - } - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); - std::unordered_set record{cnode, bnupdate}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - auto add = relu_input->cast(); - MS_EXCEPTION_IF_NULL(add); - auto tuple_getitem = add->input(1); - if (tuple_getitem->isa() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { - auto getitem = tuple_getitem->cast(); - auto bnupdate = getitem->input(1); - if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { - std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); - for (auto out_getitem : manager->node_users()[bnupdate]) { - auto out_getitem_ptr = out_getitem.first->cast(); - auto input2 = out_getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); - } - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); - std::unordered_set record{cnode, relu_input, bnupdate}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion, bool is_order) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - if (is_order) { - // DepthwiseConvolution--->Elemwise - auto depthwise_conv = cnode->input(1); - MS_EXCEPTION_IF_NULL(depthwise_conv); - if (cnode->isa() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) { - std::vector output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv); - std::unordered_set record{cnode, depthwise_conv}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } else { - // Elemwise-->DepthwiseConvolution - auto relu = cnode->input(1); - MS_EXCEPTION_IF_NULL(relu); - if (cnode->isa() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) { - std::vector output_used_num{SizeToInt(manager->node_users()[relu].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); - std::unordered_set record{cnode, relu}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void BufferFusion::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::vector output_used_num{SizeToInt(manager->node_users()[relu_input].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); - std::unordered_set record{cnode, relu_input}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); -} - -void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { - MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); - } else if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { - auto eltwise_input = cnode->input(1); - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { - MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); - } - if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { - MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); - } else if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { - MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); - } else if (eltwise_input->isa() && - AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { - MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); - } - } - } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { - MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); - } - } -} - -void BufferFusion::MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(candidate_fusion); - - auto return_node = kernel_graph.get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() <= 1) { - return; - } - std::deque todo; - todo.push_back(return_node->input(1)); - std::unordered_set visited_set; - - while (!todo.empty()) { - auto node = todo.front(); - MS_EXCEPTION_IF_NULL(node); - todo.pop_front(); - std::unordered_set record; - if (visited_set.find(node) != visited_set.end() || fusion_id_allocator.HasFusionIdAttr(node)) { - continue; - } - // Only fuse real cnode - if (!AnfAlgo::IsRealCNodeKernel(node)) { - auto cnode = node->cast(); - if (cnode != nullptr) { - (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - } - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // cnode maybe updated - cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode); - if (record.size() >= MIN_PATTERN_SIZE && record.size() <= MAX_PATTERN_SIZE) { - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - if (record.find(cnode) == record.end()) { - todo.push_back(cnode); - } - // no node matched - if (record.size() == 0) { - (void)visited_set.insert(node); - } - (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - } -} - -void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) const { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); - GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); - GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - buffer_fusion_info.second.kernel_build_info = - CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); - } -} - -bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - bool change = false; - std::unordered_map buffer_fusion_infos; - buffer_fusion_infos.clear(); - GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); - - std::vector fusion_scope_infos; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - mindspore::kernel::FusionScopeInfo fusion_scope_info; - fusion_scope_info.scope_id = buffer_fusion_info.first; - fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; - fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; - fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; - fusion_scope_infos.push_back(fusion_scope_info); -#ifdef DEBUG - DumpFusionScopeInfo(fusion_scope_info); -#endif - } - auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); - - std::vector fusion_ids; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() - << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() - << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); - fusion_ids.push_back(buffer_fusion_info.first); - } - // Replace fusion op from return to head - std::sort(fusion_ids.begin(), fusion_ids.end()); - for (auto &fusion_id : fusion_ids) { - // Get kernel mod when supporting tbe - if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { - MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; - continue; - } - change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); - } - MS_LOG(DEBUG) << "End Buffer Fusion"; - return change; -} - -bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) { - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - auto return_node = kernel_graph.get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() <= 1) { - return false; - } - MS_LOG(DEBUG) << "MatchBufferFusionPattern start..."; - FusedNodeRecord candidate_fusion; - - MatchOpNamePattern(kernel_graph, &candidate_fusion); - MatchFusionTypePattern(kernel_graph, &candidate_fusion); - - if (candidate_fusion.empty()) { - return false; - } - MS_LOG(DEBUG) << "MatchBufferFusionPattern Success..."; - return true; -} - -bool BufferFusion::ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, - int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, - session::KernelGraph *kernel_graph) const { - auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; - auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, - buffer_fusion_info.anf_nodes, kernel_graph); - AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); - // Set abstract of fusion_op node - std::vector types; - std::vector> shapes; - for (const auto &out_node : buffer_fusion_info.outputs_list) { - for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { - types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); - shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); - } - } - if (types.empty() || shapes.empty()) { - MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty"; - return false; - } - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); - AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); - SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); - ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); - return true; -} - -bool BufferFusion::Run(const FuncGraphPtr &graph) { - bool changed = false; - MS_EXCEPTION_IF_NULL(graph); - auto kernel_graph = graph->cast>(); - MS_EXCEPTION_IF_NULL(kernel_graph); - - fusion_id_allocator.Init(); - if (MatchBufferFusionPattern(*kernel_graph)) { - changed = FuseBufferFusionPattern(kernel_graph.get()); - } - // clear fusion_id attr - for (auto &node : graph->nodes()) { - if (node != nullptr && node->isa()) { - AnfAlgo::EraseNodeAttr(kAttrFusionId, node); - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h deleted file mode 100644 index 008d072ed3f47e9f1e340ddd04cb34d8a60151a8..0000000000000000000000000000000000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_BUFFER_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_BUFFER_FUSION_H_ -#include -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -struct BufferFusionInfo_t { - std::vector anf_nodes; - std::vector inputs_list; - std::vector outputs_list; - kernel::KernelBuildInfoPtr kernel_build_info; -}; - -using FusedNodeRecord = std::vector>; - -class BufferFusion : public Pass { - public: - BufferFusion() : Pass("buffer_fusion") {} - ~BufferFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - void SetRecordFusionId(const std::unordered_set &record); - void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); - void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); - void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); - void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion, bool is_order); - void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); - void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); - void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); - - void GetBufferFusionInfo(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) const; - bool ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, - const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const; - bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph); - bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const; - - FusionIdAllocator fusion_id_allocator; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_BUFFER_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h index 8d2ed816077e338ad635256696b70a0def619040..421efa9716943d630580e13f01e451fd0509ce2c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h @@ -37,6 +37,13 @@ const int8_t ELTWISE_USE = 1; const int8_t MAX_ELTWISE_SIZE = 6; using FusedNodeRecord = std::vector>; +struct BufferFusionInfo_t { + std::vector anf_nodes; + std::vector inputs_list; + std::vector outputs_list; + kernel::KernelBuildInfoPtr kernel_build_info; +}; + class FusionBasePass : public Pass { public: FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator) diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc index 14f26b85acf7026707c073d1a3011bf167ca74a5..2293754106b57cef39a0a6721daa624a9172834c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc @@ -15,6 +15,7 @@ */ #include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" #include +#include #include #include #include @@ -51,7 +52,9 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) { (void)record.insert(eltwise_input); - auto previous_eltwise_input = cnode->input(1); + auto previous_input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_input_cnode); + auto previous_eltwise_input = previous_input_cnode->input(1); auto previous_size = record.size(); while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { (void)record.insert(previous_eltwise_input); @@ -71,6 +74,7 @@ void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(candidate_fusion); std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); for (auto &node : node_list) { if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc index 329f5eb1a4b9e7f28b69cd744afc0553b3b37ca2..c1c4df0167e354f9ee8b185452905489f2e44de6 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc @@ -15,6 +15,7 @@ */ #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" #include +#include #include #include #include @@ -51,7 +52,9 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) { (void)record.insert(eltwise_input); - auto previous_eltwise_input = cnode->input(1); + auto previous_input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_input_cnode); + auto previous_eltwise_input = previous_input_cnode->input(1); auto previous_size = record.size(); while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { (void)record.insert(previous_eltwise_input); @@ -71,6 +74,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(candidate_fusion); std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); for (auto &node : node_list) { if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h index 4c2f91472e81679205145c146da5c3c82923f6ce..7099c92772f09ec46aec808854dc155ff3d1f9ac 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h @@ -19,13 +19,13 @@ #include #include +#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" #include "ir/anf.h" #include "pre_activate/common/pass.h" #include "pre_activate/common/fusion_id_allocator.h" #include "device/kernel_info.h" #include "kernel/kernel.h" #include "session/kernel_graph.h" -#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" namespace mindspore { namespace opt { diff --git a/mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.cc b/mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.cc deleted file mode 100644 index 8215fdff90cd36c9d14bb4703abe1f6c2bf80015..0000000000000000000000000000000000000000 --- a/mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/remove_nop_nodes.h" - -#include "common/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const AnfNodePtr RemoveNopNodes::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsNopNode(node)) { - return nullptr; - } - return cnode->input(1); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.h b/mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.h deleted file mode 100644 index f0b68642e9eb41df1e2503de05cf52cc782f45ca..0000000000000000000000000000000000000000 --- a/mindspore/ccsrc/pre_activate/pass/remove_nop_nodes.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REMOVE_NOP_NODES_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REMOVE_NOP_NODES_H_ -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class RemoveNopNodes : public PatternProcessPass { - public: - explicit RemoveNopNodes(bool multigraph = true) : PatternProcessPass("remove_nop_nodes", multigraph) {} - ~RemoveNopNodes() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REMOVE_NOP_NODES_H_ diff --git a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc index aa548a5351ae5d3db65e1d1cef98c70c56d2d170..483c144930e4a7a0f1c55a87a9b18dff470f35e7 100644 --- a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc @@ -21,7 +21,19 @@ #include "device/kernel_info.h" #include "pre_activate/common/optimizer.h" #include "session/anf_runtime_algorithm.h" -#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" +#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" +#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" namespace mindspore { namespace opt { @@ -79,10 +91,13 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) { cast->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get()); + auto fusion_id_allocator = std::make_shared(); + MS_EXCEPTION_IF_NULL(fusion_id_allocator); + fusion_id_allocator->Init(); auto optimizer = std::make_shared(); auto pm = std::make_shared(); - auto buffer_fusion_pass = std::make_shared(); - pm->AddPass(buffer_fusion_pass); + pm->AddPass(std::make_shared(fusion_id_allocator)); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(kg); @@ -168,10 +183,13 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) { biasadd->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), biasadd.get()); + auto fusion_id_allocator = std::make_shared(); + MS_EXCEPTION_IF_NULL(fusion_id_allocator); + fusion_id_allocator->Init(); auto optimizer = std::make_shared(); auto pm = std::make_shared(); - auto buffer_fusion_pass = std::make_shared(); - pm->AddPass(buffer_fusion_pass); + pm->AddPass(std::make_shared(fusion_id_allocator)); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(kg); @@ -255,10 +273,13 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) { biasaddgrad->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), biasaddgrad.get()); + auto fusion_id_allocator = std::make_shared(); + MS_EXCEPTION_IF_NULL(fusion_id_allocator); + fusion_id_allocator->Init(); auto optimizer = std::make_shared(); auto pm = std::make_shared(); - auto buffer_fusion_pass = std::make_shared(); - pm->AddPass(buffer_fusion_pass); + pm->AddPass(std::make_shared(fusion_id_allocator)); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(kg); @@ -321,10 +342,13 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) { cast->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get()); + auto fusion_id_allocator = std::make_shared(); + MS_EXCEPTION_IF_NULL(fusion_id_allocator); + fusion_id_allocator->Init(); auto optimizer = std::make_shared(); auto pm = std::make_shared(); - auto buffer_fusion_pass = std::make_shared(); - pm->AddPass(buffer_fusion_pass); + pm->AddPass(std::make_shared(fusion_id_allocator)); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(kg);