未验证 提交 c761b48b 编写于 作者: Z Zhen Wang 提交者: GitHub

Apply TransposeFolding & GemmRewriter passes. (#41084)

上级 922e076e
......@@ -26,7 +26,7 @@ add_definitions(-w)
######################################
include(ExternalProject)
set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN)
set(CINN_GIT_TAG e11c5e672f9961e28cfa403d86f99808beb58817)
set(CINN_GIT_TAG 1fd85187b6c18da4dd51f22619d093ef08d61b01)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION}
-DWITH_CUDA=${WITH_GPU}
-DWITH_CUDNN=${WITH_GPU}
......
......@@ -223,9 +223,12 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
const Target& target, std::int64_t compiled_num, void* stream) const {
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
auto frontend_program = symbol();
ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
auto fetch_ids = symbol.GetFetchIds();
ProgramPass::Apply(&frontend_program, fetch_ids, target, {"Decomposer"});
::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity");
::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "TransposeFolding");
ProgramPass::Apply(&frontend_program, fetch_ids, target, {"GemmRewriter"});
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
frontend_program, target);
VLOG(1) << "-- The " << compiled_num << "-th compilation ("
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册