提交 2b5b35ea 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4125 anf fusion

Merge pull request !4125 from zhengjun10/master
......@@ -65,6 +65,7 @@ if(BUILD_CONVERTER)
${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc
)
else()
......@@ -212,9 +213,8 @@ if(BUILD_CONVERTER)
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc
${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc
${LITE_DIR}/tools/optimizer/common/node_pass.cc
${LITE_DIR}/tools/optimizer/common/optimizer.cc
${LITE_DIR}/tools/optimizer/common/pass_manager.cc
${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
${LITE_DIR}/tools/optimizer/common/gllo_utils.cc
${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_activation_fusion.cc
......
......@@ -4,6 +4,10 @@ hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite
hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite
hiai_cn_recognize_modify_padv2.tflite
hiai_model_normalize_object_scene_ps_20200519.tflite
hiai_detectmodel_06_23_960_480_1180700.tflite
hiai_detect_curve_model_float32.tflite
hiai_detectmodel_desnet_256_128_64_32.tflite
mtk_AADB_HADB_MBV2_model_fp32.tflite
mobilenet_v1_0.25_128.tflite
mobilenet_v1_0.25_160.tflite
mobilenet_v1_0.25_192.tflite
......
......@@ -51,6 +51,7 @@ set(ANF_SRC
${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc
)
......@@ -74,9 +75,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc
../optimizer/common/node_pass.cc
../optimizer/common/optimizer.cc
../optimizer/common/pass_manager.cc
../optimizer/common/node_pass_extends.cc
../optimizer/common/pass_manager_extends.cc
../optimizer/common/gllo_utils.cc
../optimizer/fusion/conv_biasadd_fusion.cc
../optimizer/fusion/conv_activation_fusion.cc
......
......@@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
return nullptr;
}
graph = anfTransform->Transform(graph);
// graph = anfTransform->Transform(graph);
CreateQuantizer(graph, flag);
if (mQuantizer != nullptr) {
......
......@@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// }
// fusion
// {
// Optimizer fusionOptimizer;
// fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
// status = fusionOptimizer.Run(graphDefT);
// if (status != RET_OK && status != RET_NO_CHANGE) {
// MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
// return status;
// }
// }
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
}
// weight format trans
if (ctx.formatTrans) {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/common/optimizer.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <initializer_list>
#include "backend/optimizer/common/pass_manager.h"
#include "ir/manager.h"
namespace mindspore {
namespace opt {
// Constructs a pattern-rewrite pass.
// name:       pass name forwarded to the NodePass base.
// multigraph: forwarded to SexpToNode when the match pattern is built (see Build()).
// The pattern engine is configured with AnfEqual / CNodeTypeEqual as its node- and
// type-equality predicates; primitive_vars_ collects the Var bindings for primitives
// discovered while converting the pattern s-expression.
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
    : NodePass(name),
      multigraph_(multigraph),
      pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
                                    std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
                                    std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
      primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
// Default match pattern: a single unconstrained variable, i.e. match any node.
// Subclasses override this to describe the subgraph shape they rewrite.
const BaseRef PatternProcessPass::DefinePattern() const {
  VarPtr wildcard = std::make_shared<Var>();
  return BaseRef({wildcard});
}
// Lazily materializes pattern_ from the subclass-provided pattern expression.
// The root of the pattern graph is bound to a fresh "RootG" variable;
// primitive Var bindings encountered during conversion are recorded in
// primitive_vars_ for later use by Match().
// Fix: DefinePattern() returns `const BaseRef`, so the previous
// `std::move(DefinePattern())` could never move — std::move on a const
// rvalue silently degrades to a copy (clang-tidy: performance-move-const-arg).
// Plain copy-initialization expresses what actually happens.
void PatternProcessPass::Build() {
  VarPtr fg = std::make_shared<Var>("RootG");
  BaseRef pattern = DefinePattern();
  pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
}
// Applies the pass to one node: build the pattern on first use, try to match
// it against `node`, and on a non-trivial match hand the bindings to the
// subclass Process() hook. Returns nullptr when the node does not match.
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (pattern_ == nullptr) {
    Build();
  }
  MS_EXCEPTION_IF_NULL(primitive_vars_);
  auto seed_equiv = std::make_shared<Equiv>();
  EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, seed_equiv);
  if (equiv == nullptr || equiv->empty()) {
    return nullptr;
  }
  return Process(func_graph, node, equiv);
}
bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
VarPtr fg = std::make_shared<Var>("RootG");
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
EquivPtr another_equiv =
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
*child_primitive_vars_, empty_equiv);
if (another_equiv != nullptr && !another_equiv->empty()) {
return IsShareNodes(equiv, another_equiv);
}
return false;
}
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
if (pass_manager != nullptr) {
pass_managers_.push_back(pass_manager);
}
}
// Runs every registered pass manager over `func_graph` until a fixed point is
// reached (no manager reports a change), or just once when run_only_once is
// requested. Ensures the graph has a manager, prunes graphs unreachable from
// the root afterwards, and re-establishes topological order.
// Returns the (possibly rewritten) root graph.
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // With a single pass manager a second sweep can never make more progress.
  run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, false);
    func_graph->set_manager(manager);
  }
  bool changed = true;
  while (changed) {
    changed = false;
    // Idiomatic range-for over the pass managers instead of an index loop;
    // shared_ptr taken by const reference to avoid refcount churn.
    for (const auto &pm : pass_managers_) {
      if (pm != nullptr && pm->Run(func_graph)) {
        changed = true;
      }
    }
    if (run_only_once_) {
      break;
    }
  }
  // Keep only graphs still reachable from the root, then refresh topo order.
  std::vector<FuncGraphPtr> func_graphs{func_graph};
  manager->KeepRoots(func_graphs);
  (void)TopoSort(func_graph->get_return());
  return func_graph;
}
} // namespace opt
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册