未验证 提交 2f2f987c 编写于 作者: Z Zhen Wang 提交者: GitHub

[Cherry-Pick]Move pass optimizations into CINN. (#42047) (#42070)

* Move pass optimizations into CINN.
上级 dbdb56d1
...@@ -26,7 +26,7 @@ add_definitions(-w) ...@@ -26,7 +26,7 @@ add_definitions(-w)
###################################### ######################################
include(ExternalProject) include(ExternalProject)
set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN)
set(CINN_GIT_TAG 08d7680dd91dfaa65787969050eb8f1143654f10) set(CINN_GIT_TAG release/v0.2)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION}
-DWITH_CUDA=${WITH_GPU} -DWITH_CUDA=${WITH_GPU}
-DWITH_CUDNN=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU}
......
...@@ -25,14 +25,10 @@ ...@@ -25,14 +25,10 @@
#include "cinn/auto_schedule/tuning.h" #include "cinn/auto_schedule/tuning.h"
#include "cinn/common/target.h" #include "cinn/common/target.h"
#include "cinn/common/type.h" #include "cinn/common/type.h"
#include "cinn/frontend/decomposer/use_decomposer.h" #include "cinn/frontend/optimize.h"
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/frontend/syntax.h" #include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -58,13 +54,11 @@ namespace paddle2cinn { ...@@ -58,13 +54,11 @@ namespace paddle2cinn {
using ir::Graph; using ir::Graph;
using ir::Node; using ir::Node;
using inference::analysis::Dot; using inference::analysis::Dot;
using ::cinn::common::Target;
using ::cinn::common::Float;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::auto_schedule::AutoTuner; using ::cinn::auto_schedule::AutoTuner;
using ::cinn::common::Target;
using ::cinn::frontend::Optimize;
using ::cinn::hlir::framework::BuildScope; using ::cinn::hlir::framework::BuildScope;
using ::cinn::frontend::ProgramPass; using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::ApplyPass;
CinnCompiler* CinnCompiler::GetInstance() { CinnCompiler* CinnCompiler::GetInstance() {
static CinnCompiler instance; static CinnCompiler instance;
...@@ -75,7 +69,7 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -75,7 +69,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph, const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target, void* stream) { const Target& target, void* stream) {
VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph); VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
CinnCacheKeyByAddress cur_key_by_address(graph, input_tensors, CinnCacheKeyByAddress cur_key_by_address(graph, input_tensors,
target.arch_str()); target.arch_str());
CinnCacheKeyByStructure cur_key_by_struct; CinnCacheKeyByStructure cur_key_by_struct;
...@@ -258,22 +252,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph( ...@@ -258,22 +252,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
auto frontend_program = symbol(); auto frontend_program = symbol();
auto fetch_ids = symbol.GetFetchIds(); auto fetch_ids = symbol.GetFetchIds();
ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"}); VLOG(4) << "All fetch var ids in CINN: "
::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity"); << string::join_strings(fetch_ids, ',');
::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding");
ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"});
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( auto cinn_graph = Optimize(&frontend_program, fetch_ids, target);
frontend_program, target); VLOG(4) << "-- The " << compiled_num << "-th compilation ("
VLOG(1) << "-- The " << compiled_num << "-th compilation ("
<< target.arch_str() << "), and its related graph:\n" << target.arch_str() << "), and its related graph:\n"
<< cinn_graph->Visualize(); << cinn_graph->Visualize();
ApplyPass(cinn_graph.get(), "OpFusion");
auto scope = BuildScope(target, cinn_graph);
VLOG(4) << "All fetch var ids in CINN: "
<< string::join_strings(fetch_ids, ',');
auto scope = BuildScope(target, cinn_graph);
auto graph_compiler = auto graph_compiler =
std::make_unique<GraphCompiler>(target, scope, cinn_graph); std::make_unique<GraphCompiler>(target, scope, cinn_graph);
GraphCompiler::CompileOptions options; GraphCompiler::CompileOptions options;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册