未验证 提交 29a3d2db 编写于 作者: F Fisher 提交者: GitHub

[CINN] Dump compilation info by fusion group while compiling (#56530)

Support dumping compilation information per fusion group, stage by stage, while compilation is in progress, instead of only after the entire compilation has finished.
上级 9ed58bff
...@@ -41,6 +41,48 @@ using ir::Module; ...@@ -41,6 +41,48 @@ using ir::Module;
static constexpr int DebugLogMaxLen = 30000; static constexpr int DebugLogMaxLen = 30000;
// Streams `lowered_func` into a string and forwards it to Dump(), using
// FLAGS_cinn_dump_group_lowered_func as the base path, `gidx` as the index and
// "lowered_function.txt" as the file name.
// Nothing is dumped when the flag is empty or `lowered_func` holds a null
// pointer.
void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
    const ir::LoweredFunc& lowered_func, const int gidx) {
  // The flag is checked first so the handle is not inspected when dumping is
  // switched off.
  if (FLAGS_cinn_dump_group_lowered_func.empty()) {
    return;
  }
  if (lowered_func.get() == nullptr) {
    return;
  }

  std::stringstream func_text;
  func_text << lowered_func;
  Dump(FLAGS_cinn_dump_group_lowered_func,
       gidx,
       "lowered_function.txt",
       func_text.str());
}
// Forwards `source_code` to Dump(), using FLAGS_cinn_dump_group_source_code as
// the base path, `gidx` as the index and "source_code.cu" as the file name.
// An empty flag disables the dump.
void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
    const std::string& source_code, const int gidx) {
  const auto& dump_root = FLAGS_cinn_dump_group_source_code;
  if (!dump_root.empty()) {
    Dump(dump_root, gidx, "source_code.cu", source_code);
  }
}
// Forwards `source_ptx` to Dump(), using FLAGS_cinn_dump_group_ptx as the base
// path, `gidx` as the index and "source_ptx.ptx" as the file name.
// An empty flag disables the dump.
void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
    const std::string& source_ptx, const int gidx) {
  const bool dump_enabled = !FLAGS_cinn_dump_group_ptx.empty();
  if (dump_enabled) {
    Dump(FLAGS_cinn_dump_group_ptx,
         gidx,
         "source_ptx.ptx",
         source_ptx);
  }
}
// Forwards the result of `instr->DumpInstruction()` to Dump(), using
// FLAGS_cinn_dump_group_instruction as the base path, `gidx` as the index and
// "instruction.txt" as the file name.
// Nothing is dumped when the flag is empty or `instr` is null.
void CompilationInfoDumper::DumpInstructionByGroupIndex(
    const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
    const int gidx) {
  // Same short-circuit order as before: the flag first, then the pointer.
  const bool dump_disabled = FLAGS_cinn_dump_group_instruction.empty();
  if (dump_disabled || !instr) {
    return;
  }

  Dump(FLAGS_cinn_dump_group_instruction,
       gidx,
       "instruction.txt",
       instr->DumpInstruction());
}
void CompilationInfoDumper::DumpLoweredFunc() { void CompilationInfoDumper::DumpLoweredFunc() {
if (FLAGS_cinn_dump_group_lowered_func.empty()) { if (FLAGS_cinn_dump_group_lowered_func.empty()) {
return; return;
......
...@@ -51,12 +51,22 @@ class CompilationInfoDumper { ...@@ -51,12 +51,22 @@ class CompilationInfoDumper {
DumpInstruction(); DumpInstruction();
} }
static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
const int gidx);
static void DumpSourceCodeByGroupIndex(const std::string& source_code,
const int gidx);
static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
const int gidx);
static void DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx);
private: private:
void DumpLoweredFunc(); void DumpLoweredFunc();
void DumpSourceCode(); void DumpSourceCode();
void DumpPtxCode(); void DumpPtxCode();
void DumpInstruction(); void DumpInstruction();
void Dump(const std::string& base_path, static void Dump(const std::string& base_path,
const int idx, const int idx,
const std::string& file_name, const std::string& file_name,
const std::string& content); const std::string& content);
......
...@@ -64,9 +64,6 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) { ...@@ -64,9 +64,6 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) {
parallel_compiler_ = std::make_shared<ParallelCompiler>(context); parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())(); CompilationResult result = (*parallel_compiler_.get())();
// Dump compilation result
backends::CompilationInfoDumper dumper(result);
if (context->stage != CompilationStage::DEFAULT) { if (context->stage != CompilationStage::DEFAULT) {
return result; return result;
} }
......
...@@ -58,7 +58,7 @@ void ParallelCompiler::SplitTask() { ...@@ -58,7 +58,7 @@ void ParallelCompiler::SplitTask() {
context_->graph->fusion_groups.size() == context_->graph->fusion_groups.size() ==
context_->lowered_funcs.size()); context_->lowered_funcs.size());
for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) { for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) {
tasks_.emplace_back(this, context_, i); tasks_.emplace_back(i, this, context_);
} }
} }
...@@ -114,7 +114,9 @@ void ParallelCompiler::Task::Lowering() { ...@@ -114,7 +114,9 @@ void ParallelCompiler::Task::Lowering() {
if (!context->lowered_funcs.empty()) { if (!context->lowered_funcs.empty()) {
CHECK_EQ(context->lowered_funcs.size(), CHECK_EQ(context->lowered_funcs.size(),
context->graph->fusion_groups.size()); context->graph->fusion_groups.size());
} pcompiler->result_.lowered_funcs[group_id] =
context->lowered_funcs[group_id];
} else {
auto& dtype_dict = auto& dtype_dict =
context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>( context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
"inferdtype"); "inferdtype");
...@@ -122,12 +124,7 @@ void ParallelCompiler::Task::Lowering() { ...@@ -122,12 +124,7 @@ void ParallelCompiler::Task::Lowering() {
context->graph context->graph
->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>( ->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape"); "infershape");
OpLowerer op_lowerer(dtype_dict, shape_dict, context->target); OpLowerer op_lowerer(dtype_dict, shape_dict, context->target);
if (!context->lowered_funcs.empty()) {
pcompiler->result_.lowered_funcs[group_id] =
context->lowered_funcs[group_id];
} else {
auto& group = context->graph->fusion_groups[group_id]; auto& group = context->graph->fusion_groups[group_id];
VLOG(4) << "Start Lowering Group " << group_id << " at " VLOG(4) << "Start Lowering Group " << group_id << " at "
<< std::this_thread::get_id() << " :\n" << std::this_thread::get_id() << " :\n"
...@@ -138,6 +135,8 @@ void ParallelCompiler::Task::Lowering() { ...@@ -138,6 +135,8 @@ void ParallelCompiler::Task::Lowering() {
CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!"; CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!";
pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group); pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group);
} }
backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
pcompiler->result_.lowered_funcs[group_id].front(), group_id);
} }
void ParallelCompiler::Task::CodegenAndJit() { void ParallelCompiler::Task::CodegenAndJit() {
...@@ -168,6 +167,8 @@ void ParallelCompiler::Task::CodegenAndJit() { ...@@ -168,6 +167,8 @@ void ParallelCompiler::Task::CodegenAndJit() {
} }
CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n" CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
<< dmodule; << dmodule;
backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(cuda_c,
group_id);
pcompiler->result_.source_codes[group_id] = cuda_c; pcompiler->result_.source_codes[group_id] = cuda_c;
cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c); cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
...@@ -176,6 +177,7 @@ void ParallelCompiler::Task::CodegenAndJit() { ...@@ -176,6 +177,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
backends::nvrtc::Compiler compiler; backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c); auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(ptx, group_id);
pcompiler->result_.source_ptxs[group_id] = ptx; pcompiler->result_.source_ptxs[group_id] = ptx;
// load cumodule // load cumodule
cumodule = std::make_unique<CUDAModule>(ptx, cumodule = std::make_unique<CUDAModule>(ptx,
...@@ -217,6 +219,7 @@ void ParallelCompiler::Task::BuildInstruction() { ...@@ -217,6 +219,7 @@ void ParallelCompiler::Task::BuildInstruction() {
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName()); instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName());
instr->Finalize(); instr->Finalize();
backends::CompilationInfoDumper::DumpInstructionByGroupIndex(instr, group_id);
pcompiler->result_.instructions[group_id] = std::move(instr); pcompiler->result_.instructions[group_id] = std::move(instr);
} }
......
...@@ -33,8 +33,8 @@ namespace framework { ...@@ -33,8 +33,8 @@ namespace framework {
class ParallelCompiler { class ParallelCompiler {
public: public:
struct Task { struct Task {
Task(ParallelCompiler* compiler, CompilationContext* context, int group_id) Task(int group_id, ParallelCompiler* compiler, CompilationContext* context)
: pcompiler(compiler), context(context), group_id(group_id) {} : group_id(group_id), pcompiler(compiler), context(context) {}
void Lowering(); void Lowering();
void CodegenAndJit(); void CodegenAndJit();
void BuildInstruction(); void BuildInstruction();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册