From 29a3d2db81a6ccc6b9147c20418964c33e7b05fb Mon Sep 17 00:00:00 2001 From: Fisher Date: Wed, 23 Aug 2023 19:17:41 +0800 Subject: [PATCH] [CINN] Dump compilation info by fusion group while compiling (#56530) Support dumping compilation information for each fusion group stage by stage while compilation is in progress, instead of only after the whole compilation has finished. --- paddle/cinn/backends/compiler.cc | 42 +++++++++++++++++++ paddle/cinn/backends/compiler.h | 18 ++++++-- paddle/cinn/hlir/framework/graph_compiler.cc | 3 -- .../cinn/hlir/framework/parallel_compiler.cc | 27 ++++++------ .../cinn/hlir/framework/parallel_compiler.h | 4 +- 5 files changed, 73 insertions(+), 21 deletions(-) diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc index 57b41163707..cd6a38ec16c 100644 --- a/paddle/cinn/backends/compiler.cc +++ b/paddle/cinn/backends/compiler.cc @@ -41,6 +41,48 @@ using ir::Module; static constexpr int DebugLogMaxLen = 30000; +void CompilationInfoDumper::DumpLoweredFuncByGroupIndex( + const ir::LoweredFunc& lowered_func, const int gidx) { + if (FLAGS_cinn_dump_group_lowered_func.empty() || + lowered_func.get() == nullptr) { + return; + } + std::stringstream content; + content << lowered_func; + Dump(FLAGS_cinn_dump_group_lowered_func, + gidx, + "lowered_function.txt", + content.str()); +} + +void CompilationInfoDumper::DumpSourceCodeByGroupIndex( + const std::string& source_code, const int gidx) { + if (FLAGS_cinn_dump_group_source_code.empty()) { + return; + } + Dump(FLAGS_cinn_dump_group_source_code, gidx, "source_code.cu", source_code); +} + +void CompilationInfoDumper::DumpPtxCodeByGroupIndex( + const std::string& source_ptx, const int gidx) { + if (FLAGS_cinn_dump_group_ptx.empty()) { + return; + } + Dump(FLAGS_cinn_dump_group_ptx, gidx, "source_ptx.ptx", source_ptx); +} + +void CompilationInfoDumper::DumpInstructionByGroupIndex( + const std::unique_ptr& instr, + const int gidx) { + if (FLAGS_cinn_dump_group_instruction.empty() 
|| instr.get() == nullptr) { + return; + } + Dump(FLAGS_cinn_dump_group_instruction, + gidx, + "instruction.txt", + instr->DumpInstruction()); +} + void CompilationInfoDumper::DumpLoweredFunc() { if (FLAGS_cinn_dump_group_lowered_func.empty()) { return; diff --git a/paddle/cinn/backends/compiler.h b/paddle/cinn/backends/compiler.h index e708ea9cc3c..8b09573b522 100644 --- a/paddle/cinn/backends/compiler.h +++ b/paddle/cinn/backends/compiler.h @@ -51,15 +51,25 @@ class CompilationInfoDumper { DumpInstruction(); } + static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func, + const int gidx); + static void DumpSourceCodeByGroupIndex(const std::string& source_code, + const int gidx); + static void DumpPtxCodeByGroupIndex(const std::string& source_ptx, + const int gidx); + static void DumpInstructionByGroupIndex( + const std::unique_ptr& instr, + const int gidx); + private: void DumpLoweredFunc(); void DumpSourceCode(); void DumpPtxCode(); void DumpInstruction(); - void Dump(const std::string& base_path, - const int idx, - const std::string& file_name, - const std::string& content); + static void Dump(const std::string& base_path, + const int idx, + const std::string& file_name, + const std::string& content); const hlir::framework::CompilationResult& info_; }; diff --git a/paddle/cinn/hlir/framework/graph_compiler.cc b/paddle/cinn/hlir/framework/graph_compiler.cc index d1834101fe8..b316e1d95ca 100644 --- a/paddle/cinn/hlir/framework/graph_compiler.cc +++ b/paddle/cinn/hlir/framework/graph_compiler.cc @@ -64,9 +64,6 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) { parallel_compiler_ = std::make_shared(context); CompilationResult result = (*parallel_compiler_.get())(); - // Dump compilation result - backends::CompilationInfoDumper dumper(result); - if (context->stage != CompilationStage::DEFAULT) { return result; } diff --git a/paddle/cinn/hlir/framework/parallel_compiler.cc b/paddle/cinn/hlir/framework/parallel_compiler.cc index 
759ce719564..2ded4ffd917 100644 --- a/paddle/cinn/hlir/framework/parallel_compiler.cc +++ b/paddle/cinn/hlir/framework/parallel_compiler.cc @@ -58,7 +58,7 @@ void ParallelCompiler::SplitTask() { context_->graph->fusion_groups.size() == context_->lowered_funcs.size()); for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) { - tasks_.emplace_back(this, context_, i); + tasks_.emplace_back(i, this, context_); } } @@ -114,20 +114,17 @@ void ParallelCompiler::Task::Lowering() { if (!context->lowered_funcs.empty()) { CHECK_EQ(context->lowered_funcs.size(), context->graph->fusion_groups.size()); - } - auto& dtype_dict = - context->graph->GetMutableAttrs>( - "inferdtype"); - auto& shape_dict = - context->graph - ->GetMutableAttrs>( - "infershape"); - - OpLowerer op_lowerer(dtype_dict, shape_dict, context->target); - if (!context->lowered_funcs.empty()) { pcompiler->result_.lowered_funcs[group_id] = context->lowered_funcs[group_id]; } else { + auto& dtype_dict = + context->graph->GetMutableAttrs>( + "inferdtype"); + auto& shape_dict = + context->graph + ->GetMutableAttrs>( + "infershape"); + OpLowerer op_lowerer(dtype_dict, shape_dict, context->target); auto& group = context->graph->fusion_groups[group_id]; VLOG(4) << "Start Lowering Group " << group_id << " at " << std::this_thread::get_id() << " :\n" @@ -138,6 +135,8 @@ void ParallelCompiler::Task::Lowering() { CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!"; pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group); } + backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex( + pcompiler->result_.lowered_funcs[group_id].front(), group_id); } void ParallelCompiler::Task::CodegenAndJit() { @@ -168,6 +167,8 @@ void ParallelCompiler::Task::CodegenAndJit() { } CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n" << dmodule; + backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(cuda_c, + group_id); pcompiler->result_.source_codes[group_id] = 
cuda_c; cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c); @@ -176,6 +177,7 @@ void ParallelCompiler::Task::CodegenAndJit() { backends::nvrtc::Compiler compiler; auto ptx = compiler(cuda_c); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; + backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(ptx, group_id); pcompiler->result_.source_ptxs[group_id] = ptx; // load cumodule cumodule = std::make_unique(ptx, @@ -217,6 +219,7 @@ void ParallelCompiler::Task::BuildInstruction() { instr->SetLoweredFunc(reinterpret_cast(fn_ptr), group->GetFuncName()); instr->Finalize(); + backends::CompilationInfoDumper::DumpInstructionByGroupIndex(instr, group_id); pcompiler->result_.instructions[group_id] = std::move(instr); } diff --git a/paddle/cinn/hlir/framework/parallel_compiler.h b/paddle/cinn/hlir/framework/parallel_compiler.h index c2d41fb6215..7eb22b1fbc3 100644 --- a/paddle/cinn/hlir/framework/parallel_compiler.h +++ b/paddle/cinn/hlir/framework/parallel_compiler.h @@ -33,8 +33,8 @@ namespace framework { class ParallelCompiler { public: struct Task { - Task(ParallelCompiler* compiler, CompilationContext* context, int group_id) - : pcompiler(compiler), context(context), group_id(group_id) {} + Task(int group_id, ParallelCompiler* compiler, CompilationContext* context) + : group_id(group_id), pcompiler(compiler), context(context) {} void Lowering(); void CodegenAndJit(); void BuildInstruction(); -- GitLab