未验证 提交 29a3d2db 编写于 作者: F Fisher 提交者: GitHub

[CINN] Dump compilation info by fusion group while compiling (#56530)

Support to dump information in stages according to the fusion group during the compilation process, instead of after the compilation is completely completed.
上级 9ed58bff
......@@ -41,6 +41,48 @@ using ir::Module;
static constexpr int DebugLogMaxLen = 30000;
void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
const ir::LoweredFunc& lowered_func, const int gidx) {
if (FLAGS_cinn_dump_group_lowered_func.empty() ||
lowered_func.get() == nullptr) {
return;
}
std::stringstream content;
content << lowered_func;
Dump(FLAGS_cinn_dump_group_lowered_func,
gidx,
"lowered_function.txt",
content.str());
}
void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
const std::string& source_code, const int gidx) {
if (FLAGS_cinn_dump_group_source_code.empty()) {
return;
}
Dump(FLAGS_cinn_dump_group_source_code, gidx, "source_code.cu", source_code);
}
void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
const std::string& source_ptx, const int gidx) {
if (FLAGS_cinn_dump_group_ptx.empty()) {
return;
}
Dump(FLAGS_cinn_dump_group_ptx, gidx, "source_ptx.ptx", source_ptx);
}
void CompilationInfoDumper::DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx) {
if (FLAGS_cinn_dump_group_instruction.empty() || instr.get() == nullptr) {
return;
}
Dump(FLAGS_cinn_dump_group_instruction,
gidx,
"instruction.txt",
instr->DumpInstruction());
}
void CompilationInfoDumper::DumpLoweredFunc() {
if (FLAGS_cinn_dump_group_lowered_func.empty()) {
return;
......
......@@ -51,12 +51,22 @@ class CompilationInfoDumper {
DumpInstruction();
}
static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
const int gidx);
static void DumpSourceCodeByGroupIndex(const std::string& source_code,
const int gidx);
static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
const int gidx);
static void DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx);
private:
void DumpLoweredFunc();
void DumpSourceCode();
void DumpPtxCode();
void DumpInstruction();
void Dump(const std::string& base_path,
static void Dump(const std::string& base_path,
const int idx,
const std::string& file_name,
const std::string& content);
......
......@@ -64,9 +64,6 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) {
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();
// Dump compilation result
backends::CompilationInfoDumper dumper(result);
if (context->stage != CompilationStage::DEFAULT) {
return result;
}
......
......@@ -58,7 +58,7 @@ void ParallelCompiler::SplitTask() {
context_->graph->fusion_groups.size() ==
context_->lowered_funcs.size());
for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) {
tasks_.emplace_back(this, context_, i);
tasks_.emplace_back(i, this, context_);
}
}
......@@ -114,7 +114,9 @@ void ParallelCompiler::Task::Lowering() {
if (!context->lowered_funcs.empty()) {
CHECK_EQ(context->lowered_funcs.size(),
context->graph->fusion_groups.size());
}
pcompiler->result_.lowered_funcs[group_id] =
context->lowered_funcs[group_id];
} else {
auto& dtype_dict =
context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
"inferdtype");
......@@ -122,12 +124,7 @@ void ParallelCompiler::Task::Lowering() {
context->graph
->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
"infershape");
OpLowerer op_lowerer(dtype_dict, shape_dict, context->target);
if (!context->lowered_funcs.empty()) {
pcompiler->result_.lowered_funcs[group_id] =
context->lowered_funcs[group_id];
} else {
auto& group = context->graph->fusion_groups[group_id];
VLOG(4) << "Start Lowering Group " << group_id << " at "
<< std::this_thread::get_id() << " :\n"
......@@ -138,6 +135,8 @@ void ParallelCompiler::Task::Lowering() {
CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!";
pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group);
}
backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
pcompiler->result_.lowered_funcs[group_id].front(), group_id);
}
void ParallelCompiler::Task::CodegenAndJit() {
......@@ -168,6 +167,8 @@ void ParallelCompiler::Task::CodegenAndJit() {
}
CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
<< dmodule;
backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(cuda_c,
group_id);
pcompiler->result_.source_codes[group_id] = cuda_c;
cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
......@@ -176,6 +177,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(ptx, group_id);
pcompiler->result_.source_ptxs[group_id] = ptx;
// load cumodule
cumodule = std::make_unique<CUDAModule>(ptx,
......@@ -217,6 +219,7 @@ void ParallelCompiler::Task::BuildInstruction() {
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName());
instr->Finalize();
backends::CompilationInfoDumper::DumpInstructionByGroupIndex(instr, group_id);
pcompiler->result_.instructions[group_id] = std::move(instr);
}
......
......@@ -33,8 +33,8 @@ namespace framework {
class ParallelCompiler {
public:
struct Task {
Task(ParallelCompiler* compiler, CompilationContext* context, int group_id)
: pcompiler(compiler), context(context), group_id(group_id) {}
Task(int group_id, ParallelCompiler* compiler, CompilationContext* context)
: group_id(group_id), pcompiler(compiler), context(context) {}
void Lowering();
void CodegenAndJit();
void BuildInstruction();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册