未验证 提交 2d345148 编写于 作者: F Fisher 提交者: GitHub

[CINN] Optimize parallel compiler and support dumping more compilation information (#55590)

graph_compiler_util.h/cc:
整合GraphCompiler与ParallelCompiler共同持有的数据结构CompilationStage、CompilationStatus、CompilationContext、CompilationResult
Parallel Compiler:
整合数据结构至CompilationContext
支持分阶段编译,通过指定CompilationContext::Stage
添加编译状态信息,包括状态CompilationResult::Status和信息CompilationResult::message
一个Task对应一个fusion_group,每一阶段编译完成后,将编译结果放入CompilationResult数组的对应下标中,删去原Task中内部的局部变量,省去了MergeResult
Graph Compiler:
丰富CompilationResult,提供编译各阶段的中间结果
整合数据结构至CompilationContext
添加编译状态信息,包括状态CompilationResult::Status和信息CompilationResult::message
启用单测
其他:相关单测、前端接口、paddle2cinn适配CompilationContext
上级 88a975a0
......@@ -26,6 +26,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/ir_base.h"
......@@ -37,6 +38,7 @@ namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::CompilationContext;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
......@@ -53,6 +55,7 @@ class TestAutoTuner : public ::testing::Test {
std::shared_ptr<Graph> graph;
std::shared_ptr<Scope> compiled_scope;
CompilationContext context;
std::unique_ptr<GraphCompiler> graph_compiler;
std::unique_ptr<AutoTuner> tuner;
......@@ -73,8 +76,10 @@ class TestAutoTuner : public ::testing::Test {
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph);
graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
context.graph = graph;
context.scope = compiled_scope;
context.target = target;
graph_compiler = std::make_unique<GraphCompiler>(context);
tuner = std::make_unique<AutoTuner>(target, graph.get());
}
......@@ -99,16 +104,14 @@ class TestAutoTuner : public ::testing::Test {
virtual void ApplyTunedAndRun(const TuningResult& result) {
// build runtime program with tuning result
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
compile_options.Apply(result);
ASSERT_EQ(1, compile_options.groups.size());
ASSERT_EQ(1, compile_options.lowered_funcs.size());
context.with_instantiate_variables = true;
context.ApplyTuningResult(result);
ASSERT_EQ(1, context.groups.size());
ASSERT_EQ(1, context.lowered_funcs.size());
VLOG(6) << "Print lowered_funcs before building";
VLOG(6) << compile_options.lowered_funcs[0][0];
VLOG(6) << compile_options.lowered_funcs[1][0];
auto runtime_program =
graph_compiler->Build(compile_options).runtime_program;
VLOG(6) << context.lowered_funcs[0][0];
VLOG(6) << context.lowered_funcs[1][0];
auto runtime_program = graph_compiler->Build(&context).runtime_program;
ASSERT_EQ(1, runtime_program->size());
runtime_program->Execute();
}
......
......@@ -25,12 +25,14 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/runtime/flags.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::CompilationContext;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
......@@ -62,7 +64,8 @@ class TestMeasurer : public ::testing::Test {
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph);
CompilationContext context(graph, scope, target);
graph_compiler = std::make_unique<GraphCompiler>(context);
TaskCreator task_creator;
tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict =
......
......@@ -17,6 +17,8 @@
namespace cinn {
namespace auto_schedule {
using hlir::framework::CompilationContext;
using hlir::framework::CompilationResult;
using hlir::framework::GraphCompiler;
SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler)
......@@ -25,15 +27,14 @@ SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler)
BuildResult SimpleBuilder::Build(const MeasureInput& input) {
CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr))
<< "empty handle to GraphCompiler";
GraphCompiler::CompileOptions compile_options;
compile_options.groups.emplace_back(input.task->subgraph);
compile_options.lowered_funcs.emplace_back(input.lowered_funcs);
compile_options.remove_unused_variables = false;
CompilationContext& context = graph_compiler_->GetCompilationContext();
context.groups.emplace_back(input.task->subgraph);
context.lowered_funcs.emplace_back(input.lowered_funcs);
context.remove_unused_variables = false;
VLOG(5) << "call GraphCompiler to Build with Graph::Group size="
<< compile_options.groups.size() << ", lowered_funcs group size="
<< compile_options.lowered_funcs.size();
GraphCompiler::CompilationResult compiled_result =
graph_compiler_->Build(compile_options);
<< context.groups.size()
<< ", lowered_funcs group size=" << context.lowered_funcs.size();
CompilationResult compiled_result = graph_compiler_->Build(&context);
BuildResult build_result;
build_result.compiled_scope = graph_compiler_->GetScope().get();
......
......@@ -16,6 +16,7 @@
#include "paddle/cinn/auto_schedule/measure/measure.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
namespace cinn {
namespace auto_schedule {
......
......@@ -25,11 +25,13 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::CompilationContext;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
......@@ -56,8 +58,8 @@ class TestSimpleRunner : public ::testing::Test {
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
compiled_scope = BuildScope(target, graph);
graph_compiler =
std::make_unique<GraphCompiler>(target, compiled_scope, graph);
CompilationContext context(graph, compiled_scope, target);
graph_compiler = std::make_unique<GraphCompiler>(context);
auto runtime_program = graph_compiler->Build();
const auto& instructions = runtime_program->GetRunInstructions();
ASSERT_EQ(1, instructions.size());
......@@ -123,8 +125,8 @@ TEST_F(TestSimpleRunner, TimeMeasured) {
"sleep_fn"));
instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn));
instructions.back()->Finalize();
build_result.runtime_program.reset(
new hlir::framework::Program(nullptr, std::move(instructions)));
build_result.runtime_program = std::make_unique<hlir::framework::Program>(
nullptr, std::move(instructions));
// to skip the condition check of params in Instruction::PreparePodArgs
std::map<std::string, cinn_pod_value_t> preset_args;
......
......@@ -25,6 +25,7 @@
#include "paddle/cinn/frontend/paddle_model_convertor.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/ir_base.h"
......@@ -61,6 +62,7 @@ namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::CompilationContext;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
......@@ -94,8 +96,8 @@ class PerformanceTester : public ::testing::Test {
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
VLOG(3) << "Build " << schedule_name << " program.";
auto scope = BuildScope(target_, graph);
auto graph_compiler =
std::make_unique<GraphCompiler>(target_, scope, graph);
CompilationContext context(graph, scope, target_);
auto graph_compiler = std::make_unique<GraphCompiler>(context);
auto runtime_program =
(this->*build_fn)(graph.get(), graph_compiler.get());
if (execute) {
......@@ -145,16 +147,16 @@ class PerformanceTester : public ::testing::Test {
std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target_);
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
CompilationContext& context = graph_compiler->GetCompilationContext();
context.with_instantiate_variables = true;
if (graph->fusion_groups.empty()) {
hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"});
}
compile_options.groups = graph->fusion_groups;
context.groups = graph->fusion_groups;
for (auto group : graph->fusion_groups) {
compile_options.lowered_funcs.push_back(
context.lowered_funcs.push_back(
op_lowerer->Lower(group,
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false));
......@@ -162,7 +164,7 @@ class PerformanceTester : public ::testing::Test {
VLOG(3) << "===========================No Schedule LoweredFunc "
"Begin===========================";
for (const auto& funcvec : compile_options.lowered_funcs) {
for (const auto& funcvec : context.lowered_funcs) {
for (const auto& func : funcvec) {
VLOG(3) << func;
}
......@@ -170,7 +172,7 @@ class PerformanceTester : public ::testing::Test {
VLOG(3) << "===========================No Schedule LoweredFunc "
"End=============================";
return graph_compiler->Build(compile_options).runtime_program;
return graph_compiler->Build();
}
std::unique_ptr<hlir::framework::Program> BuildManualScheduleProgram(
......@@ -191,13 +193,13 @@ class PerformanceTester : public ::testing::Test {
tuner->Initialize(tuning_config, graph_compiler);
TuningResult tuning_result = tuner->Tune(tuning_options);
GraphCompiler::CompileOptions compile_options;
compile_options.with_instantiate_variables = true;
compile_options.Apply(tuning_result);
CompilationContext& context = graph_compiler->GetCompilationContext();
context.with_instantiate_variables = true;
context.ApplyTuningResult(tuning_result);
VLOG(3) << "===========================Auto Schedule LoweredFunc "
"Begin===========================";
for (const auto& funcvec : compile_options.lowered_funcs) {
for (const auto& funcvec : context.lowered_funcs) {
for (const auto& func : funcvec) {
VLOG(3) << func;
}
......@@ -205,7 +207,7 @@ class PerformanceTester : public ::testing::Test {
VLOG(3) << "===========================Auto Schedule LoweredFunc "
"End=============================";
return graph_compiler->Build(compile_options).runtime_program;
return graph_compiler->Build();
}
#ifdef CINN_WITH_CUDA
......
......@@ -43,8 +43,7 @@ namespace backends {
*/
class CompilationInfoDumper {
public:
explicit CompilationInfoDumper(
const hlir::framework::ParallelCompiler::CompilationResult& info)
explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info)
: info_(info) {
DumpLoweredFunc();
DumpSourceCode();
......@@ -62,7 +61,7 @@ class CompilationInfoDumper {
const std::string& file_name,
const std::string& content);
const hlir::framework::ParallelCompiler::CompilationResult& info_;
const hlir::framework::CompilationResult& info_;
};
class SourceCodePrint {
......
......@@ -72,16 +72,18 @@ std::shared_ptr<ComputationContext> CompileProgram(
}
ctx->scope = hlir::framework::BuildScope(target, ctx->graph, scope);
ctx->graph_compiler.reset(
new hlir::framework::GraphCompiler(target, ctx->scope, ctx->graph));
std::unordered_set<std::string> fetch_var_ids;
for (auto &out : outputs) {
fetch_var_ids.insert(out->id);
}
ctx->program = ctx->graph_compiler->Build(options, std::move(fetch_var_ids))
.runtime_program;
ctx->compile_options.graph = ctx->graph;
ctx->compile_options.scope = ctx->scope;
ctx->compile_options.fetch_var_ids = fetch_var_ids;
ctx->graph_compiler.reset(
new hlir::framework::GraphCompiler(ctx->compile_options));
ctx->program = ctx->graph_compiler->Build();
if (ctx->compile_options.do_prerun) {
ctx->program->PreRun();
}
......
......@@ -27,8 +27,7 @@ struct ComputationContext;
class CinnComputation {
public:
struct CompileOptions
: public hlir::framework::GraphCompiler::CompileOptions {
struct CompileOptions : public hlir::framework::CompilationContext {
bool use_decomposer = false;
bool do_prerun = true;
bool use_default_passes = true;
......
......@@ -86,7 +86,8 @@ TEST(Decomposer, softmax_decomposer) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
std::vector<float> x(n * c * h * w);
......
......@@ -200,7 +200,8 @@ TEST(Decomposer, BatchNormTrain) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
// set input
......@@ -399,7 +400,8 @@ TEST(Decomposer, BatchNormGrad) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
// set input
......
......@@ -27,6 +27,7 @@
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
......@@ -208,7 +209,8 @@ void RunAndCheckShape(NetBuilder* builder,
auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
std::vector<std::vector<T>> input_vecs_internal;
......
......@@ -38,7 +38,8 @@ TEST(Decomposer, top_k_decomposer) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
std::vector<float> x(10 * 5);
......
......@@ -19,6 +19,7 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -120,10 +121,8 @@ void Interpreter::Impl::Build(const Target& target,
graph->attrs["model_name"] = std::make_shared<absl::any>(model_name);
scope_ = hlir::framework::BuildScope(target, graph, scope_);
graph_compiler_.reset(
new hlir::framework::GraphCompiler(target, scope_, graph));
hlir::framework::GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;
hlir::framework::CompilationContext context(graph, scope_, target);
context.with_instantiate_variables = true;
if (FLAGS_enable_auto_tuner) {
VLOG(4) << "Compile with auto-tune";
auto_schedule::AutoTuner auto_tuner(target, graph.get());
......@@ -131,10 +130,10 @@ void Interpreter::Impl::Build(const Target& target,
graph_compiler_.get());
auto_schedule::TuningOptions tuning_options;
auto_schedule::TuningResult tuning_result = auto_tuner.Tune(tuning_options);
options.Apply(tuning_result);
context.ApplyTuningResult(tuning_result);
}
runtime_program_ =
graph_compiler_->Build(options, std::move(fetch_var_ids)).runtime_program;
graph_compiler_ = std::make_unique<hlir::framework::GraphCompiler>(context);
runtime_program_ = graph_compiler_->Build();
runtime_program_->PreRun();
}
......@@ -150,4 +149,4 @@ Interpreter::Interpreter(
} // namespace cinn::frontend
cinn::frontend::Interpreter::~Interpreter() {}
cinn::frontend::Interpreter::~Interpreter() = default;
......@@ -26,6 +26,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/utils/data_util.h"
......@@ -99,7 +100,8 @@ TEST(net_build, program_execute_multi_elementwise_add) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -139,7 +141,8 @@ TEST(net_build, program_execute_fc) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(a.id()));
......@@ -183,7 +186,8 @@ TEST(net_build, program_execute_multi_elementwise_add_bf16) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -224,7 +228,8 @@ TEST(net_build, program_execute_fc_bf16) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(a.id()));
......@@ -285,7 +290,8 @@ TEST(net_build, program_execute_pool2d) {
std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -318,7 +324,8 @@ TEST(net_build, program_execute_reverse) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -349,7 +356,8 @@ TEST(net_build, program_execute_gather) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
......@@ -409,7 +417,8 @@ TEST(net_build, program_execute_gather_nd) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input1.id()));
......@@ -469,7 +478,8 @@ TEST(net_build, program_execute_cast) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -523,7 +533,8 @@ TEST(net_build, program_execute_squeeze_case0) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -582,7 +593,8 @@ TEST(net_build, program_execute_squeeze_case1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -641,7 +653,8 @@ TEST(net_build, program_execute_squeeze_case2) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -699,7 +712,8 @@ TEST(net_build, program_execute_squeeze_case3) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -757,7 +771,8 @@ TEST(net_build, program_execute_squeeze_case4) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -813,7 +828,8 @@ TEST(net_build, program_execute_argsort) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -874,7 +890,8 @@ TEST(net_build, program_execute_sort) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -934,7 +951,8 @@ TEST(net_build, program_execute_arange_float) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(out->id));
......@@ -975,7 +993,8 @@ TEST(net_build, program_execute_arange_int) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(out->id));
......@@ -1018,7 +1037,8 @@ TEST(net_build, program_argmax_case1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1092,7 +1112,8 @@ TEST(net_build, program_argmax_case2) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1170,7 +1191,8 @@ TEST(net_build, program_argmin_case1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1247,7 +1269,8 @@ TEST(net_build, program_argmin_case2) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1324,7 +1347,8 @@ TEST(net_build, program_execute_repeat_axis_0) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1379,7 +1403,8 @@ TEST(net_build, program_execute_repeat_axis_1) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......@@ -1440,7 +1465,8 @@ TEST(net_build, program_execute_one_hot) {
auto graph = Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
......
......@@ -69,7 +69,8 @@ void RunProgram(const Target& target, Program* prog) {
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
for (size_t i = 0; i < input_names.size(); ++i) {
......
......@@ -73,7 +73,8 @@ TEST(DecomposePass, basic) {
auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -61,7 +61,8 @@ std::unordered_map<std::string, hlir::framework::Tensor> RunWithProgram(
hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
for (auto& data : input_data) {
scope->Var<hlir::framework::Tensor>(data.first);
......
......@@ -40,7 +40,8 @@ std::vector<float> RunWithProgram(const Program& program,
hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
......
......@@ -36,6 +36,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
......@@ -79,14 +80,12 @@ inline void RunGraph(std::shared_ptr<hlir::framework::Graph> graph,
hlir::framework::ApplyPasses(graph.get(), graph_passes);
VLOG(3) << "Graph Viz:\n" << graph->Visualize();
BuildScope(target, graph, scope);
hlir::framework::GraphCompiler::CompileOptions options;
options.attached_code = "";
options.with_instantiate_variables = true;
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build(options,
std::unordered_set<std::string>(
output_ids.begin(), output_ids.end()))
.runtime_program;
hlir::framework::CompilationContext context(graph, scope, target);
context.attached_source_code = "";
context.with_instantiate_variables = true;
context.fetch_var_ids.insert(output_ids.begin(), output_ids.end());
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
}
......
......@@ -25,6 +25,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -40,7 +41,8 @@ void RunWithProgram(const Program& program,
hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
}
......
......@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -117,11 +118,10 @@ class PassTest {
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
auto scope = hlir::framework::BuildScope(target_, graph);
hlir::framework::GraphCompiler gc(target_, scope, graph);
hlir::framework::GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;
auto result = gc.Build(options, std::move(fetch_var_ids));
auto runtime_program = std::move(result.runtime_program);
hlir::framework::CompilationContext context(graph, scope, target_);
context.with_instantiate_variables = true;
hlir::framework::GraphCompiler gc(context);
auto runtime_program = std::move(gc.Build());
for (auto& name : input_names) {
SetInputTensor(name, scope);
......
......@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -68,7 +69,8 @@ std::vector<std::vector<float>> RunWithProgram(
hlir::framework::ApplyPasses(graph.get(), {"InferShape", "OpFusionPass"});
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
......
......@@ -24,6 +24,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
......@@ -38,7 +39,8 @@ void RunWithProgram(const Program& program,
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
hlir::framework::ApplyPasses(graph.get(), {"OpFusionPass"});
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
runtime_program->Execute();
}
......
......@@ -23,6 +23,7 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/op/use_ops.h"
......@@ -69,7 +70,8 @@ TEST(syntax, program_execute_multi_elementwise_add) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
scope->Var<hlir::framework::Tensor>("B");
......@@ -88,7 +90,8 @@ TEST(syntax, program_execute_multi_elementwise_add2) {
LOG(INFO) << "graph:\n" << graph->Visualize();
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -121,7 +124,8 @@ std::get<2>(programTuple);
auto graph = cinn::frontend::Optimize(program.get(), fetch_ids, target);
scope = BuildScope(target, graph, scope);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope,target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto at = scope->GetTensor("A");
......@@ -133,11 +137,12 @@ std::get<2>(programTuple);
LOG(INFO) << "scope.names: " << Join(scope->var_names(), ",");
const std::string output_name = "fc_0.tmp_2";
auto tensor =
scope->GetTensor(var_map_paddle_to_program.at(output_name)); LOG(INFO) <<
"tensor.shape: " << utils::Join(tensor->shape().data(), ","); auto data =
GetTensorData<float>(tensor, target); for (int i = 0; i < 10; i++) LOG(INFO) <<
"data: " << data[i];
auto tensor = scope->GetTensor(var_map_paddle_to_program.at(output_name));
LOG(INFO) << "tensor.shape: " << utils::Join(tensor->shape().data(), ",");
auto data = GetTensorData<float>(tensor, target);
for (int i = 0; i < 10; i++) {
LOG(INFO) << "data: " << data[i];
}
}
*/
......
......@@ -12,6 +12,7 @@ gather_srcs(
program.cc
parallel_compiler.cc
graph_compiler.cc
graph_compiler_util.cc
graph.cc
node.cc
pass.cc
......@@ -52,5 +53,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc
DEPS cinncore)
cinn_cc_test(test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore)
#cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
DEPS cinncore)
......@@ -40,90 +40,127 @@ using cinn::common::float16;
std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph);
GraphCompiler::CompileOptions options;
options.attached_code = code;
options.with_instantiate_variables = true;
compilation_context_.ApplySourceCode(code);
compilation_context_.with_instantiate_variables = true;
auto&& result = Build(options);
auto&& result = Build(&compilation_context_);
return std::move(result.runtime_program);
}
void GraphCompiler::CompileOptions::Apply(
const auto_schedule::TuningResult& tuning_result) {
// assign options with TuningResult directly
groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end());
lowered_funcs.assign(tuning_result.function_groups.begin(),
tuning_result.function_groups.end());
}
GraphCompiler::CompilationResult GraphCompiler::Build(
const GraphCompiler::CompileOptions& options,
std::unordered_set<std::string>&& fetch_var_ids,
void* stream) {
CompilationResult GraphCompiler::Build(CompilationContext* context) {
Context::Global().ResetNameId();
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_
: fetch_var_ids);
context->graph->VisualizeGroupedGraph(context->fetch_var_ids);
if (options.with_instantiate_variables) {
InstantiateVariables();
if (context->with_instantiate_variables) {
InstantiateVariables(context);
}
VLOG(2) << "Compile With Parallel Compiler!";
utils::RecordEvent("GraphCompiler CompileResult",
utils::EventType::kOrdinary);
ParallelCompiler::CompileOptions option;
option.lowered_funcs = options.lowered_funcs;
parallel_compiler_ =
std::make_shared<ParallelCompiler>(scope_, graph_, option, target_);
auto result = (*parallel_compiler_.get())();
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
// Dump compilation result
backends::CompilationInfoDumper dumper(result);
if (options.remove_unused_variables) {
RemoveInvalidVariables(result.instructions);
if (context->stage != CompilationStage::DEFAULT) {
return result;
}
if (options.with_buffer_handle_instruction_inserted) {
if (context->remove_unused_variables) {
RemoveInvalidVariables(context, result.instructions);
}
if (context->with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(&result.instructions);
InsertBufferHandlers(context, &result.instructions);
}
VLOG(2) << "Compile With Parallel Compiler Done!";
GraphCompiler::CompilationResult compilation_result;
compilation_result.runtime_program.reset(
new Program(scope_, std::move(result.instructions)));
return compilation_result;
result.runtime_program =
std::make_unique<Program>(context->scope, std::move(result.instructions));
return result;
}
CompilationResult GraphCompiler::Lowering() {
return Lowering(&compilation_context_);
}
CompilationResult GraphCompiler::Lowering(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just lowering!";
context->stage = CompilationStage::LOWERING;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
CompilationResult GraphCompiler::CodegenAndJit() {
return CodegenAndJit(&compilation_context_);
}
CompilationResult GraphCompiler::CodegenAndJit(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just codegen and jit!";
context->stage = CompilationStage::CODEGEN_AND_JIT;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
CompilationResult GraphCompiler::BuildInstruction() {
return BuildInstruction(&compilation_context_);
}
CompilationResult GraphCompiler::BuildInstruction(CompilationContext* context) {
// Global setting
Context::Global().ResetNameId();
// Setting compile options
VLOG(2) << "Compile With Parallel Compiler! But just build instruction!";
context->stage = CompilationStage::BUILD_INSTRUCTION;
// Compile with parallel compiler
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
return result;
}
void GraphCompiler::InstantiateVariables() {
void GraphCompiler::InstantiateVariables(CompilationContext* context) {
VLOG(3) << "Instantiate all variables on compile-time";
utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary);
// All variables reside in scope_, so traverse it to instantiate each one
for (auto& name : scope_->var_names()) {
auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()}));
for (auto& name : context->scope->var_names()) {
auto* var =
context->scope->Var<Tensor>(std::string({name.data(), name.size()}));
auto& tensor = absl::get<Tensor>(*var);
if (reuse_vars_map_.count(name)) {
auto src_var_name = reuse_vars_map_.at(name);
auto* src_var = scope_->Var<Tensor>(src_var_name);
if (context->reuse_vars_map.count(name)) {
auto src_var_name = context->reuse_vars_map.at(name);
auto* src_var = context->scope->Var<Tensor>(src_var_name);
auto& src_tensor = absl::get<Tensor>(*src_var);
tensor->set_buffer(src_tensor->get_buffer());
} else {
tensor->mutable_data(target_, tensor->type());
tensor->mutable_data(context->target, tensor->type());
}
}
}
void GraphCompiler::RemoveInvalidVariables(
CompilationContext* context,
const std::vector<std::unique_ptr<Instruction>>& instructions) {
// mark all variables are invalid initially
utils::RecordEvent("GraphCompiler RemoveInvalidVariables",
utils::EventType::kOrdinary);
std::unordered_set<std::string> invalid_variables;
auto var_names = scope_->var_names();
auto var_names = context->scope->var_names();
invalid_variables.reserve(var_names.size());
std::transform(
var_names.begin(),
......@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables(
<< " invalid variables to be removed from scope";
std::for_each(invalid_variables.begin(),
invalid_variables.end(),
[this](const std::string& var_name) {
scope_->EraseVar(var_name);
[context](const std::string& var_name) {
context->scope->EraseVar(var_name);
VLOG(3) << "Variable(" << var_name << ") is erased";
});
}
......@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime(
}
void GraphCompiler::InsertBufferHandlers(
CompilationContext* context,
std::vector<std::unique_ptr<Instruction>>* instructions) {
utils::RecordEvent("GraphCompiler InsertBufferHandlers",
utils::EventType::kOrdinary);
......@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers(
auto function_name = "malloc_buffer_instruction_" + std::to_string(step);
auto malloc_instr =
std::make_unique<Instruction>(common::DefaultHostTarget(),
scope_.get(),
context->scope.get(),
malloc_var_names,
std::vector<std::string>({}),
function_name);
......@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers(
auto function_name = "free_buffer_instruction_" + std::to_string(step);
auto free_instr =
std::make_unique<Instruction>(common::DefaultHostTarget(),
scope_.get(),
context->scope.get(),
std::vector<std::string>({}),
free_var_names,
function_name);
......
......@@ -28,6 +28,7 @@
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
......@@ -46,48 +47,41 @@ namespace framework {
*/
class GraphCompiler final {
public:
GraphCompiler(Target target,
const std::shared_ptr<Scope>& scope,
const std::shared_ptr<Graph>& graph)
: target_(std::move(target)), scope_(scope), graph_(graph) {}
struct CompilationResult {
std::unique_ptr<Program> runtime_program;
};
struct CompileOptions {
std::string attached_code = "";
bool with_instantiate_variables = false;
bool with_buffer_handle_instruction_inserted = false;
bool remove_unused_variables = true;
// nodes group, it may come from the result of op fusion or graph tuning.
// nodes in a group will be built into an Instruction
std::vector<std::shared_ptr<Graph::Group>> groups;
// corresponding LoweredFuncs of above grouped nodes,
// if it is empty then graph_compiler will generate for them
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
// apply results of auto-tune to compile
void Apply(const auto_schedule::TuningResult& tuning_result);
};
GraphCompiler(CompilationContext context) : compilation_context_(context) {}
// Compile with a packing option and result, to be extended easily.
CompilationResult Build(const CompileOptions& options,
std::unordered_set<std::string>&& fetch_var_ids = {},
void* stream = nullptr);
CompilationResult Build(CompilationContext* context);
std::unique_ptr<Program> Build(const std::string& code = "");
const std::shared_ptr<Scope>& GetScope() const { return scope_; }
CompilationResult Lowering();
CompilationResult Lowering(CompilationContext* context);
CompilationResult CodegenAndJit();
CompilationResult CodegenAndJit(CompilationContext* context);
CompilationResult BuildInstruction();
CompilationResult BuildInstruction(CompilationContext* context);
const std::shared_ptr<Scope>& GetScope() const {
return compilation_context_.scope;
}
CompilationContext& GetCompilationContext() { return compilation_context_; }
void SetCompilationContext(const CompilationContext& context) {
compilation_context_ = context;
}
private:
// instantiate all variables on compile time
void InstantiateVariables();
void InstantiateVariables(CompilationContext* context);
// some variables are eliminated by optimized passes(such as OpFusion),
// we can filter out them according to arguments of the built instructions,
// and erase them from the scope to avoid unnecessary buffer allocation
void RemoveInvalidVariables(
CompilationContext* context,
const std::vector<std::unique_ptr<Instruction>>& instructions);
// find the first and last instruction where a variable used, and mark the
......@@ -102,21 +96,14 @@ class GraphCompiler final {
// firstly used in the next instruction, and insert a buffer free instruction
// applying on variables after no instruction will use them anymore
void InsertBufferHandlers(
CompilationContext* context,
std::vector<std::unique_ptr<Instruction>>* instructions);
private:
// parallel compiler
std::shared_ptr<ParallelCompiler> parallel_compiler_;
Target target_;
std::shared_ptr<Graph> graph_;
std::shared_ptr<Scope> scope_;
// fetch var ids in cinn and the corresponding var nodes will not be fused so
// as to get the result
std::unordered_set<std::string> fetch_var_ids_;
// map dst reuse var to the src var sharing buffer
absl::flat_hash_map<std::string, std::string> reuse_vars_map_;
CompilationContext compilation_context_;
CINN_DISALLOW_COPY_AND_ASSIGN(GraphCompiler);
};
......
......@@ -19,6 +19,7 @@
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/op/use_ops.h"
......@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) {
ASSERT_EQ(scope->var_names().size(), 6);
EXPECT_NE(scope->FindVar(c->id), nullptr);
GraphCompiler gc(target, scope, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
auto runtime_program = gc.Build();
ASSERT_EQ(scope->var_names().size(), 3);
EXPECT_EQ(scope->FindVar(c->id), nullptr);
......@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
GraphCompiler gc_disable(target, scope, graph);
GraphCompiler::CompileOptions options;
CompilationContext context_disable(graph, scope, target);
GraphCompiler gc_disable(context_disable);
// disable with_buffer_handle_instruction_inserted: only 1 instruction
auto runtime_program_disable = gc_disable.Build(options).runtime_program;
auto runtime_program_disable =
gc_disable.Build(&context_disable).runtime_program;
ASSERT_EQ(runtime_program_disable->size(), 1);
const auto& computation_instr_disable =
runtime_program_disable->GetRunInstructions().front();
......@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
// enable with_buffer_handle_instruction_inserted: 3 instructions, 1st ->
// malloc instruction(a, b, d), 2nd -> the real computation
// instruction(add + relu) and 3rd -> free instruction
GraphCompiler gc_enable(target, scope, graph);
options.with_buffer_handle_instruction_inserted = true;
auto runtime_program_enable = gc_enable.Build(options).runtime_program;
CompilationContext context_enable(graph, scope, target);
context_enable.with_buffer_handle_instruction_inserted = true;
GraphCompiler gc_enable(context_enable);
auto runtime_program_enable =
gc_enable.Build(&context_enable).runtime_program;
const auto& instructions = runtime_program_enable->GetRunInstructions();
ASSERT_EQ(instructions.size(), 3);
......@@ -193,7 +198,8 @@ void RunCublas(
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
auto scope = BuildScope(target, graph);
GraphCompiler gc(target, scope, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
auto exe_program = gc.Build();
auto data_a = scope->GetTensor("A");
......@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) {
RunCublas(64, 128, 128, true, true);
}
TEST(GraphCompilerTest, TestLowering) {
frontend::NetBuilder builder("test_lowering_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.Lowering();
ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
}
TEST(GraphCompilerTest, TestCodegenAndJit) {
frontend::NetBuilder builder("test_codegen_and_jit_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.CodegenAndJit();
ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
}
TEST(GraphCompilerTest, TestBuildInstruction) {
frontend::NetBuilder builder("test_build_instruction_on_graph_compiler");
auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
auto b = builder.CreateInput(Float(32), {64}, "B");
auto c = builder.Add(a, b, 1);
auto d = builder.Relu(c);
auto target = common::DefaultNVGPUTarget();
auto program = builder.Build();
auto graph = Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
CompilationContext context(graph, scope, target);
GraphCompiler gc(context);
CompilationResult result = gc.BuildInstruction();
ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
}
#endif
} // namespace framework
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
namespace cinn {
namespace hlir {
namespace framework {
void CompilationContext::ApplyTuningResult(
const auto_schedule::TuningResult& tuning_result) {
// assign options with TuningResult directly
groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end());
lowered_funcs.assign(tuning_result.function_groups.begin(),
tuning_result.function_groups.end());
}
void CompilationContext::ApplySourceCode(const std::string& code) {
attached_source_code = code;
}
void CompilationResult::InitCompilationResult(int group_size) {
status = CompilationStatus::SUCCESS;
lowered_funcs.resize(group_size);
source_codes.resize(group_size);
source_ptxs.resize(group_size);
instructions.resize(group_size);
}
} // namespace framework
} // namespace hlir
} // namespace cinn
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/cinn/ir/lowered_func.h"
namespace cinn {
namespace hlir {
namespace framework {
// An enum class used to control the compilation stage.
enum class CompilationStage {
// Fully compiled by default, the following compilation result can be
// obtained: lowered_function, source_code, source_ptx, instruction and
// runtime_program.
DEFAULT = 0,
// Just do lowering, we can only get lowered_function from compilation result.
LOWERING = 1,
// Stop after codegen and jit, we can get: lowered_function, source_code and
// source_ptx from compilation result.
CODEGEN_AND_JIT = 2,
// Stop after build instruction, we can get: lowered_function, source_code,
// source_ptx and runtime_program from compilation result.
BUILD_INSTRUCTION = 3,
};
// An enum class used to represent the compilation status.
enum class CompilationStatus {
// Compile successfully.
SUCCESS = 0,
// An unknown error occurred during compilation.
UNKNOWN_FAIL = 1,
// An error occurred during lowering.
LOWERING_FAIL = 2,
// An error occurred during codegen and jit.
CODEGEN_JIT_FAIL = 3,
// An error occurred during build instruction.
INSTUCTION_FAIL = 4,
// An error occurred during build runtime program.
PROGRAM_FAIL = 5,
};
struct CompilationContext {
CompilationContext() = default;
CompilationContext(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Scope>& scope,
const Target& target)
: graph(graph), scope(scope), target(target) {}
std::string attached_source_code = "";
// Compile options.
bool with_instantiate_variables = false;
bool with_buffer_handle_instruction_inserted = false;
bool remove_unused_variables = true;
// Compile stage, full compile by default.
CompilationStage stage = CompilationStage::DEFAULT;
// Compile target.
Target target;
// Computation graph.
std::shared_ptr<Graph> graph;
// Variable scope
std::shared_ptr<Scope> scope;
// Fetch var ids in cinn and the corresponding var nodes will not be fused
// so as to get the result.
std::unordered_set<std::string> fetch_var_ids;
// Map dst reuse var to the src var sharing buffer
absl::flat_hash_map<std::string, std::string> reuse_vars_map;
// Nodes group, it may come from the result of op fusion or graph tuning.
// Nodes in a group will be built into an Instruction.
std::vector<std::shared_ptr<Graph::Group>> groups;
// Corresponding lowered functions of above grouped nodes,
// if it is empty then graph_compiler will generate for them.
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
// CUDA stream.
void* stream = nullptr;
// Set attached source code, if code is not empty, these codes will replace
// the device_module code after SplitCudaAndHostModule.
void ApplySourceCode(const std::string& code);
// Apply results of auto-tune to compile.
// Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
// results are applied.
void ApplyTuningResult(const auto_schedule::TuningResult& tuning_result);
};
struct CompilationResult {
CompilationStatus status;
std::string message;
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
std::vector<std::string> source_codes;
std::vector<std::string> source_ptxs;
std::vector<std::unique_ptr<Instruction>> instructions;
std::unique_ptr<Program> runtime_program;
void InitCompilationResult(int group_size);
};
} // namespace framework
} // namespace hlir
} // namespace cinn
......@@ -26,6 +26,7 @@
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/runtime/flags.h"
......@@ -36,135 +37,138 @@ namespace cinn {
namespace hlir {
namespace framework {
ParallelCompiler::CompilationResult ParallelCompiler::operator()() {
if (graph_->fusion_groups.size() == 0) {
hlir::framework::ApplyPasses(graph_.get(), {"BuildNonFusedGroupsPass"});
CompilationResult ParallelCompiler::operator()() {
if (context_->graph->fusion_groups.empty()) {
hlir::framework::ApplyPasses(context_->graph.get(),
{"BuildNonFusedGroupsPass"});
}
// Task Spilt
// init compilation result
result_.InitCompilationResult(context_->graph->fusion_groups.size());
// task spilt
SplitTask();
// launch task
LaunchTask();
// merge instruction
return MergeResult();
// return compilation result
return std::move(result_);
}
void ParallelCompiler::SplitTask() {
CHECK(graph_->fusion_groups.size());
CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() ||
option_.lowered_funcs.size() == 0);
// Assign fusion_group to each task.
// The maximum number of tasks is determined by the number of threads.
// Fusion_group is assigned to tasks in order and continuous.
int fusion_group_size = graph_->fusion_groups.size();
int thread_size = FLAGS_cinn_parallel_compile_thread > 0
? FLAGS_cinn_parallel_compile_thread
: 1;
int group_per_task =
(graph_->fusion_groups.size() + thread_size - 1) / thread_size;
for (int idx = 0; idx < graph_->fusion_groups.size(); idx += group_per_task) {
Task task(this, scope_, graph_, option_, target_);
task.start_gidx = idx;
task.stop_gidx =
(idx + group_per_task > fusion_group_size ? fusion_group_size
: idx + group_per_task);
tasks_.emplace_back(std::move(task));
CHECK(!context_->graph->fusion_groups.empty());
CHECK(context_->lowered_funcs.empty() ||
context_->graph->fusion_groups.size() ==
context_->lowered_funcs.size());
for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) {
tasks_.emplace_back(this, context_, i);
}
VLOG(2) << "Split task to " << tasks_.size() << " sub-task!";
}
void ParallelCompiler::RunTask(ParallelCompiler::Task* task) {
VLOG(2) << "Stark run sub-task, Thread Id : " << std::this_thread::get_id();
VLOG(4) << "Start Lowering";
task->Lowering();
VLOG(4) << "Start CodegenAndJit";
task->CodegenAndJit();
VLOG(4) << "Start BuildInstruction";
task->BuildInstruction();
VLOG(2) << "Finish run sub-task, Thread Id : " << std::this_thread::get_id();
void ParallelCompiler::RunTask() {
while (true) {
int idx = GetTaskIdx();
if (idx < 0) {
return;
}
VLOG(4) << "Start run task " << idx
<< " on thread: " << std::this_thread::get_id();
VLOG(4) << "Start lowering on task " << idx;
tasks_[idx].Lowering();
if (context_->stage == CompilationStage::LOWERING) {
VLOG(4) << "Just lowering, finish task " << idx
<< " on thread: " << std::this_thread::get_id();
return;
}
VLOG(4) << "Start CodegenAndJit";
tasks_[idx].CodegenAndJit();
if (context_->stage == CompilationStage::CODEGEN_AND_JIT) {
VLOG(4) << "Just codegen and jit, finish task " << idx
<< " on thread: " << std::this_thread::get_id();
return;
}
VLOG(4) << "Start BuildInstruction";
tasks_[idx].BuildInstruction();
if (context_->stage == CompilationStage::BUILD_INSTRUCTION) {
VLOG(4) << "Just build instruction, finish task " << idx
<< " on thread: " << std::this_thread::get_id();
return;
}
VLOG(4) << "Finish task " << idx
<< " on thread: " << std::this_thread::get_id();
}
}
void ParallelCompiler::LaunchTask() {
// start sub-task.
// multi thread compilation
std::vector<std::thread> threads;
for (int idx = 1; idx < tasks_.size(); ++idx) {
threads.emplace_back(&ParallelCompiler::RunTask, this, &tasks_[idx]);
VLOG(4) << "Compile with " << FLAGS_cinn_parallel_compile_thread
<< " threads";
for (int idx = 1; idx < FLAGS_cinn_parallel_compile_thread; ++idx) {
threads.emplace_back(&ParallelCompiler::RunTask, this);
}
RunTask(&tasks_[0]);
RunTask();
// syncthreads.
for_each(threads.begin(), threads.end(), std::mem_fn(&std::thread::join));
}
ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() {
ParallelCompiler::CompilationResult res;
for (auto& task : tasks_) {
for (auto& lowered_func : task.lowered_funcs) {
res.lowered_funcs.emplace_back(lowered_func);
}
for (auto& source_code : task.source_codes) {
res.source_codes.emplace_back(source_code);
}
for (auto& source_ptx : task.source_ptxs) {
res.source_ptxs.emplace_back(source_ptx);
}
for (auto& instruction : task.instructions) {
res.instructions.emplace_back(std::move(instruction));
}
}
return res;
}
void ParallelCompiler::Task::Lowering() {
if (options.lowered_funcs.size()) {
CHECK_EQ(options.lowered_funcs.size(), graph->fusion_groups.size());
if (!context->lowered_funcs.empty()) {
CHECK_EQ(context->lowered_funcs.size(),
context->graph->fusion_groups.size());
}
auto& dtype_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
"inferdtype");
auto& shape_dict =
graph->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
OpLowerer op_lowerer(dtype_dict, shape_dict, target);
for (int idx = start_gidx; idx < stop_gidx; ++idx) {
if (options.lowered_funcs.size()) {
lowered_funcs.push_back(options.lowered_funcs[idx]);
continue;
}
auto& group = graph->fusion_groups[idx];
VLOG(1) << "Start Lowering Group " << idx << " at "
context->graph
->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
OpLowerer op_lowerer(dtype_dict, shape_dict, context->target);
if (!context->lowered_funcs.empty()) {
pcompiler->result_.lowered_funcs[group_id] =
context->lowered_funcs[group_id];
} else {
auto& group = context->graph->fusion_groups[group_id];
VLOG(4) << "Start Lowering Group " << group_id << " at "
<< std::this_thread::get_id() << " :\n"
<< "Group " << idx << " {\n"
<< graph->DebugGroupedGraph(group->CollectNodes()) << "}\n";
<< "Group " << group_id << " {\n"
<< context->graph->DebugGroupedGraph(group->CollectNodes())
<< "}\n";
auto lowered_group = op_lowerer.Lower(group);
CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!";
lowered_funcs.emplace_back(std::move(lowered_group));
pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group);
}
}
void ParallelCompiler::Task::CodegenAndJit() {
VLOG(2) << "Start Codegen and JIT with Group [" << start_gidx << "-"
<< stop_gidx << ") at thread" << std::this_thread::get_id();
VLOG(2) << "Start Codegen and JIT on Group " << group_id
<< " at thread: " << std::this_thread::get_id();
// build module
ir::Module::Builder builder(common::UniqName("module"), target);
for (auto& func : lowered_funcs) {
CHECK_EQ(func.size(), 1);
builder.AddFunction(func[0]);
ir::Module::Builder builder(common::UniqName("module"), context->target);
for (auto& func : pcompiler->result_.lowered_funcs[group_id]) {
builder.AddFunction(func);
}
auto ir_module = builder.Build();
if (target == common::DefaultNVGPUTarget()) {
if (context->target == common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA
auto splited_module = backends::SplitCudaAndHostModule(ir_module);
auto hmodule = std::get<0>(splited_module);
auto dmodule = std::get<1>(splited_module);
VLOG(3) << "Host Code:\n" << hmodule;
VLOG(3) << "Device Code:\n" << dmodule;
backends::CodeGenCUDA_Dev codegen(target);
auto cuda_c = codegen.Compile(dmodule);
VLOG(4) << "Host Code:\n" << hmodule;
VLOG(4) << "Device Code:\n" << dmodule;
std::string cuda_c;
if (context->attached_source_code.empty()) {
backends::CodeGenCUDA_Dev codegen(context->target);
cuda_c = codegen.Compile(dmodule);
} else {
VLOG(4) << "Codegen and jit with attached source code.";
cuda_c = context->attached_source_code;
}
CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
<< dmodule;
pcompiler->result_.source_codes[group_id] = cuda_c;
cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
......@@ -172,15 +176,12 @@ void ParallelCompiler::Task::CodegenAndJit() {
backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
pcompiler->result_.source_ptxs[group_id] = ptx;
// load cumodule
cumodule.reset(new CUDAModule(ptx,
compiler.compile_to_cubin()
? CUDAModule::Kind::CUBIN
: CUDAModule::Kind::PTX));
source_codes.emplace_back(std::move(cuda_c));
source_ptxs.emplace_back(std::move(ptx));
cumodule = std::make_unique<CUDAModule>(ptx,
compiler.compile_to_cubin()
? CUDAModule::Kind::CUBIN
: CUDAModule::Kind::PTX);
// register kernel
backends::RuntimeSymbols symbols;
......@@ -201,25 +202,30 @@ void ParallelCompiler::Task::CodegenAndJit() {
void ParallelCompiler::Task::BuildInstruction() {
// create instruction.
for (int idx = start_gidx; idx < stop_gidx; ++idx) {
VLOG(2) << "Start BuildInstruction of Group " << idx << " at "
<< std::this_thread::get_id();
auto& group = graph->fusion_groups[idx];
CHECK(group->input_names.size() > 0 || group->output_names.size() > 0);
auto instr =
std::unique_ptr<Instruction>(new Instruction(target,
scope.get(),
group->input_names,
group->output_names,
group->GetFuncName()));
auto fn_ptr = engine->Lookup(group->GetFuncName());
CHECK(fn_ptr) << "Can't find jit function : " << group->GetFuncName();
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr),
group->GetFuncName());
instr->Finalize();
instructions.push_back(std::move(instr));
VLOG(4) << "Start BuildInstruction of Group " << group_id
<< " at thread: " << std::this_thread::get_id();
auto& group = context->graph->fusion_groups[group_id];
CHECK(!group->input_names.empty() || !group->output_names.empty());
auto instr = std::make_unique<Instruction>(context->target,
context->scope.get(),
group->input_names,
group->output_names,
group->GetFuncName());
auto fn_ptr = engine->Lookup(group->GetFuncName());
CHECK(fn_ptr) << "Can't find jit function : " << group->GetFuncName();
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName());
instr->Finalize();
pcompiler->result_.instructions[group_id] = std::move(instr);
}
int ParallelCompiler::GetTaskIdx() {
std::lock_guard<std::mutex> lock(mtx_);
if (task_idx_ < tasks_.size()) {
return task_idx_++;
} else {
return -1;
}
}
......
......@@ -19,6 +19,7 @@
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/ir/lowered_func.h"
......@@ -31,44 +32,20 @@ namespace framework {
class ParallelCompiler {
public:
struct CompileOptions {
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
};
struct CompilationResult {
// Lower result
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
// Host/CUDA codegen result
std::vector<std::string> source_codes;
// CUDA ptx result
std::vector<std::string> source_ptxs;
// Instruction result
std::vector<std::unique_ptr<Instruction>> instructions;
};
struct Task {
Task(ParallelCompiler* p,
std::shared_ptr<Scope>& s, // NOLINT
std::shared_ptr<Graph>& g, // NOLINT
const CompileOptions& cp,
const Target& t)
: compiler(p), scope(s), graph(g), options(cp), target(t) {}
Task(ParallelCompiler* compiler, CompilationContext* context, int group_id)
: pcompiler(compiler), context(context), group_id(group_id) {}
void Lowering();
void CodegenAndJit();
void BuildInstruction();
const Target target;
ParallelCompiler* compiler;
std::shared_ptr<Scope> scope;
std::shared_ptr<Graph> graph;
const CompileOptions& options;
ParallelCompiler* pcompiler;
CompilationContext* context;
CompilationStatus status = CompilationStatus::SUCCESS;
std::string message;
int start_gidx;
int stop_gidx;
std::vector<std::unique_ptr<Instruction>> instructions;
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
std::vector<std::string> source_codes;
std::vector<std::string> source_ptxs;
int group_id;
std::unique_ptr<backends::ExecutionEngine> engine;
#ifdef CINN_WITH_CUDA
......@@ -76,25 +53,22 @@ class ParallelCompiler {
#endif
};
explicit ParallelCompiler(std::shared_ptr<Scope>& scope, // NOLINT
std::shared_ptr<Graph>& graph, // NOLINT
const CompileOptions& option,
const common::Target& target)
: scope_(scope), graph_(graph), option_(option), target_(target) {}
~ParallelCompiler() {}
explicit ParallelCompiler(CompilationContext* context) : context_(context) {}
~ParallelCompiler() = default;
CompilationResult operator()();
private:
void SplitTask();
void LaunchTask();
void RunTask(Task* task);
CompilationResult MergeResult();
void RunTask();
int GetTaskIdx();
int task_idx_{0};
std::mutex mtx_;
std::vector<Task> tasks_;
const common::Target target_;
const CompileOptions& option_;
std::shared_ptr<Scope> scope_;
std::shared_ptr<Graph> graph_;
CompilationContext* context_;
CompilationResult result_;
};
} // namespace framework
......
......@@ -20,6 +20,7 @@
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
namespace cinn {
namespace hlir {
......@@ -35,9 +36,9 @@ TEST(ParallelCompilerTest, Add_TEST_0) {
auto graph = std::make_shared<Graph>(program, target);
auto scope = BuildScope(target, graph);
ParallelCompiler::CompileOptions option;
ParallelCompiler pc(scope, graph, option, target);
auto runtime_program = pc();
CompilationContext context(graph, scope, target);
ParallelCompiler pc(&context);
auto compilation_result = pc();
}
TEST(ParallelCompilerTest, Conv2d_Test_0) {
......@@ -53,9 +54,9 @@ TEST(ParallelCompilerTest, Conv2d_Test_0) {
auto graph = frontend::Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
ParallelCompiler::CompileOptions option;
ParallelCompiler pc(scope, graph, option, target);
auto runtime_program = pc();
CompilationContext context(graph, scope, target);
ParallelCompiler pc(&context);
auto compilation_result = pc();
}
TEST(ParallelCompilerTest, Matmul_Test_0) {
......@@ -71,9 +72,9 @@ TEST(ParallelCompilerTest, Matmul_Test_0) {
auto graph = frontend::Optimize(&program, {}, target);
auto scope = BuildScope(target, graph);
ParallelCompiler::CompileOptions option;
ParallelCompiler pc(scope, graph, option, target);
auto runtime_program = pc();
CompilationContext context(graph, scope, target);
ParallelCompiler pc(&context);
auto compilation_result = pc();
}
} // namespace framework
......
......@@ -76,7 +76,8 @@ TEST(conv, conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -122,7 +123,8 @@ TEST(conv_relu_conv, conv_relu_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -171,7 +173,8 @@ TEST(conv_add_conv, conv_add_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -227,7 +230,8 @@ TEST(conv_bn_conv, conv_bn_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -283,7 +287,8 @@ TEST(conv_pool2d_conv, conv_pool2d_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -334,7 +339,8 @@ TEST(conv_softmax_conv, conv_softmax_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -382,7 +388,8 @@ TEST(conv_sigmoid_conv, conv_sigmoid_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -434,7 +441,8 @@ TEST(conv_mul_conv, conv_mul_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -46,7 +46,8 @@ void RunTest(const Target& target,
const std::shared_ptr<Graph>& graph,
const std::vector<std::string>& input_names) {
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
for (size_t i = 0; i < input_names.size(); ++i) {
scope->Var<hlir::framework::Tensor>(input_names[i]);
......
......@@ -71,7 +71,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
hlir::framework::ApplyPass(graph.get(), "BuildNonFusedGroupsPass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......@@ -115,7 +116,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
hlir::framework::ApplyPass(graph.get(), "BuildNonFusedGroupsPass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......@@ -180,7 +182,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......
......@@ -57,7 +57,8 @@ TEST(const_conv, const_conv) {
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......@@ -101,7 +102,8 @@ TEST(const_bn, const_bn) {
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
......
......@@ -46,7 +46,8 @@ std::unordered_map<std::string, std::vector<float>> RunModelTest(
hlir::framework::ApplyPasses(graph.get(), passes);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (auto& data : input_data) {
......
......@@ -46,7 +46,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......@@ -72,7 +73,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......
......@@ -45,7 +45,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......@@ -71,7 +72,8 @@ void RunModelTest(Program& program, // NOLINT
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (int idx = 0; idx < inputs.size(); ++idx) {
......
......@@ -80,7 +80,8 @@ TEST(complex2, complex2) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -135,7 +136,8 @@ TEST(complex1, complex1) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -172,7 +174,8 @@ TEST(fuse_add_relu, fuse_add_relu) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -210,7 +213,8 @@ TEST(fuse_add, fuse_add) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -268,7 +272,8 @@ TEST(conv_bn_conv, conv_bn_conv) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -319,7 +324,8 @@ TEST(fuse_conv_add, fuse_conv_add) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -377,7 +383,8 @@ TEST(conv_add_mul, conv_add_mul) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -426,7 +433,8 @@ TEST(fuse_conv_add1, fuse_conv_add1) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -465,7 +473,8 @@ TEST(transpose_reshape_concat, transpose_reshape_concat) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -517,7 +526,8 @@ TEST(conv_bn, conv_bn) {
hlir::framework::ApplyPass(graph.get(), "OpFusion");
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -30,7 +30,8 @@ std::unordered_map<std::string, std::vector<float>> RunModelTest(
hlir::framework::ApplyPasses(graph.get(), passes);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto run_program = gc.Build();
for (auto& data : input_data) {
......
......@@ -66,7 +66,8 @@ TEST(batch_norm_meta, batch_norm_meta) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -104,7 +105,8 @@ TEST(reduction, reduce) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -136,7 +138,8 @@ TEST(Compare, Compare) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -203,15 +203,14 @@ void BindFrontend(pybind11::module *m) {
auto graph = Optimize(&self, fetch_ids, target, passes);
scope = hlir::framework::BuildScope(target, graph, scope);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
// Keep compile option same as paddle
hlir::framework::GraphCompiler::CompileOptions options;
options.with_instantiate_variables = true;
options.remove_unused_variables = false;
auto gc_fetch_ids = fetch_ids;
const auto &result = gc.Build(options, std::move(gc_fetch_ids));
const auto &program = result.runtime_program;
context.with_instantiate_variables = true;
context.remove_unused_variables = false;
context.fetch_var_ids = fetch_ids;
hlir::framework::GraphCompiler gc(context);
const auto &program = gc.Build();
for (size_t i = 0; i < tensor_inputs.size(); i++) {
auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
......@@ -305,7 +304,8 @@ void BindFrontend(pybind11::module *m) {
hlir::framework::ApplyPass(g.get(), "InferShape");
std::shared_ptr<hlir::framework::Scope> scope =
hlir::framework::BuildScope(target, g);
hlir::framework::GraphCompiler gc(target, scope, g);
hlir::framework::CompilationContext context(g, scope, target);
hlir::framework::GraphCompiler gc(context);
auto program = gc.Build();
for (size_t i = 0; i < tensor_inputs.size(); i++) {
auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
......@@ -354,7 +354,8 @@ void BindFrontend(pybind11::module *m) {
std::shared_ptr<hlir::framework::Scope> scope =
hlir::framework::BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto program = gc.Build(code);
for (size_t i = 0; i < tensor_inputs.size(); i++) {
auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
......
......@@ -32,6 +32,7 @@
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -64,6 +65,8 @@ using ::cinn::common::Target;
using ::cinn::frontend::Optimize;
using ::cinn::frontend::paddle::InplaceOutSuffix;
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::CompilationContext;
using ::cinn::hlir::framework::CompilationResult;
using ::cinn::hlir::framework::GraphCompiler;
using inference::analysis::Dot;
using ir::Graph;
......@@ -318,12 +321,11 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
<< cinn_graph->Visualize();
auto scope = BuildScope(target, cinn_graph);
auto graph_compiler =
std::make_unique<GraphCompiler>(target, scope, cinn_graph);
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = false;
CompilationContext context(cinn_graph, scope, target);
auto graph_compiler = std::make_unique<GraphCompiler>(context);
context.with_instantiate_variables = false;
if (!FLAGS_enable_pe_launch_cinn) {
options.with_buffer_handle_instruction_inserted = true;
context.with_buffer_handle_instruction_inserted = true;
}
std::unique_ptr<AutoTuner> auto_tuner;
if (FLAGS_enable_cinn_auto_tune) {
......@@ -333,14 +335,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
::cinn::auto_schedule::TuningOptions tuning_options;
tuning_options.num_measure_trials = 0;
auto tuning_result = auto_tuner->Tune(tuning_options);
options.Apply(tuning_result);
context.ApplyTuningResult(tuning_result);
}
auto compiled_res =
graph_compiler->Build(options, std::move(fetch_ids), stream);
context.fetch_var_ids = std::move(fetch_ids);
context.stream = stream;
auto compiled_res = graph_compiler->Build();
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(graph_compiler),
std::move(auto_tuner),
std::move(compiled_res.runtime_program),
std::move(compiled_res),
scope,
symbol.var_model_to_program_map()};
compiled_obj->cached_index = compiled_num;
......
......@@ -97,10 +97,11 @@ class TestConstantOpShape(TestCaseHelper):
{
"shape": [1024],
},
# Update: stack over flow while compiling
# very slow for the shape 2048
{
"shape": [2048],
},
# {
# "shape": [2048],
# },
{
"shape": [1, 1, 1, 1],
},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册