未验证 提交 39b59603 编写于 作者: F Fisher 提交者: GitHub

[CINN] Dump more compilation result and optimize parallel compiler flags (#55935)

1. `Parallel Compiler`:
    - 合并`FLAGS_cinn_parallel_compile_size`和`FLAGS_cinn_parallel_compile_thread`,通过`FLAGS_cinn_parallel_compile_thread`即可指定编译时使用的线程数,所有的`fusion_groups`将会平均分配到可用的线程上
    - 增强编译完成后返回的信息,除`instruction`外,将`lowered_function`、`source_code`、`source_ptx`返回,供上层进一步使用
2. Debug信息:
    - 新增`FLAGS_ cinn_dump_group_lowered_func`、`FLAGS_cinn_dump_group_source_code`、`FLAGS_ cinn_dump_group_ptx`、`FLAGS_ cinn_dump_group_instruction`,可分别按`fusion_groups`储存编译的每个阶段中的中间代码
    - 重新整理`graph_visualization`,所有的可视化图、单测代码均能正确分组储存
3. Bug修复:
    - 修复`MakeDirectory`不能正确创建文件夹的问题
4. 其他:
    - 清除了一些无用代码
上级 469a0392
...@@ -55,7 +55,6 @@ DEFINE_string(resnet50_model_dir, ...@@ -55,7 +55,6 @@ DEFINE_string(resnet50_model_dir,
DEFINE_int32(evaluate_knobs, DEFINE_int32(evaluate_knobs,
-1, -1,
"the options to control which schedule tests will be run."); "the options to control which schedule tests will be run.");
DECLARE_int32(cinn_parallel_compile_size);
namespace cinn { namespace cinn {
namespace auto_schedule { namespace auto_schedule {
...@@ -78,8 +77,6 @@ class PerformanceTester : public ::testing::Test { ...@@ -78,8 +77,6 @@ class PerformanceTester : public ::testing::Test {
std::bitset<3> evaluate_knobs = 0UL; std::bitset<3> evaluate_knobs = 0UL;
}; };
void SetUp() override { FLAGS_cinn_parallel_compile_size = 0; }
void Evaluate(const frontend::Program& program) { void Evaluate(const frontend::Program& program) {
if (FLAGS_evaluate_knobs >= 0) { if (FLAGS_evaluate_knobs >= 0) {
options_.evaluate_knobs = FLAGS_evaluate_knobs; options_.evaluate_knobs = FLAGS_evaluate_knobs;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/common/context.h" #include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h" #include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h" #include "paddle/cinn/backends/codegen_cuda_host.h"
...@@ -29,6 +30,10 @@ ...@@ -29,6 +30,10 @@
#endif #endif
DECLARE_string(cinn_source_code_save_path); DECLARE_string(cinn_source_code_save_path);
DECLARE_string(cinn_dump_group_lowered_func);
DECLARE_string(cinn_dump_group_source_code);
DECLARE_string(cinn_dump_group_ptx);
DECLARE_string(cinn_dump_group_instruction);
namespace cinn { namespace cinn {
namespace backends { namespace backends {
...@@ -36,6 +41,81 @@ using ir::Module; ...@@ -36,6 +41,81 @@ using ir::Module;
static constexpr int DebugLogMaxLen = 30000; static constexpr int DebugLogMaxLen = 30000;
// Dumps the first lowered function of every fusion group to
// "<FLAGS_cinn_dump_group_lowered_func>/fusion_group_<idx>/lowered_function.txt".
// No-op when the flag is unset.
void CompilationInfoDumper::DumpLoweredFunc() {
  if (FLAGS_cinn_dump_group_lowered_func.empty()) {
    return;
  }
  // Cache the size as int to avoid a signed/unsigned comparison in the loop.
  const int group_num = static_cast<int>(info_.lowered_funcs.size());
  for (int idx = 0; idx < group_num; ++idx) {
    std::stringstream content;
    // Each group is expected to hold exactly one lowered function
    // (enforced by CHECK_EQ in ParallelCompiler::Task::Lowering).
    content << info_.lowered_funcs[idx].front();
    Dump(FLAGS_cinn_dump_group_lowered_func,
         idx,
         "lowered_function.txt",
         content.str());
  }
}
// Dumps the generated source code of every fusion group to
// "<FLAGS_cinn_dump_group_source_code>/fusion_group_<idx>/source_code.cu".
// No-op when the flag is unset.
void CompilationInfoDumper::DumpSourceCode() {
  if (FLAGS_cinn_dump_group_source_code.empty()) {
    return;
  }
  // Cache the size as int to avoid a signed/unsigned comparison in the loop.
  const int group_num = static_cast<int>(info_.source_codes.size());
  for (int idx = 0; idx < group_num; ++idx) {
    Dump(FLAGS_cinn_dump_group_source_code,
         idx,
         "source_code.cu",
         info_.source_codes[idx]);
  }
}
// Dumps the compiled PTX of every fusion group to
// "<FLAGS_cinn_dump_group_ptx>/fusion_group_<idx>/source_ptx.ptx".
// No-op when the flag is unset.
void CompilationInfoDumper::DumpPtxCode() {
  if (FLAGS_cinn_dump_group_ptx.empty()) {
    return;
  }
  // Cache the size as int to avoid a signed/unsigned comparison in the loop.
  const int group_num = static_cast<int>(info_.source_ptxs.size());
  for (int idx = 0; idx < group_num; ++idx) {
    Dump(FLAGS_cinn_dump_group_ptx,
         idx,
         "source_ptx.ptx",
         info_.source_ptxs[idx]);
  }
}
// Dumps a textual rendering of every group's instruction to
// "<FLAGS_cinn_dump_group_instruction>/fusion_group_<idx>/instruction.txt".
// No-op when the flag is unset.
void CompilationInfoDumper::DumpInstruction() {
  if (FLAGS_cinn_dump_group_instruction.empty()) {
    return;
  }
  // Cache the size as int to avoid a signed/unsigned comparison in the loop.
  const int group_num = static_cast<int>(info_.instructions.size());
  for (int idx = 0; idx < group_num; ++idx) {
    Dump(FLAGS_cinn_dump_group_instruction,
         idx,
         "instruction.txt",
         info_.instructions[idx]->DumpInstruction());
  }
}
// Writes `content` to "<base_path>/fusion_group_<idx>/<file_name>", creating
// the per-group directory first. Failures are logged as warnings and never
// abort compilation — dumping is a best-effort debug facility.
void CompilationInfoDumper::Dump(const std::string& base_path,
                                 const int idx,
                                 const std::string& file_name,
                                 const std::string& content) {
  auto dump_path =
      utils::StringFormat("%s/fusion_group_%d", base_path.c_str(), idx);
  if (!hlir::framework::MakeDirectory(
          dump_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
    // Generic wording: this helper dumps every kind of compilation artifact
    // (lowered funcs, source, ptx, instructions), not only instructions.
    LOG(WARNING) << "Failed to make directory: \"" << dump_path
                 << "\", the files for this group will not dump.";
  } else {
    auto dump_file =
        utils::StringFormat("%s/%s", dump_path.c_str(), file_name.c_str());
    VLOG(7) << "Dump file to: " << dump_file;
    std::ofstream of(dump_file, std::ios_base::out);
    if (of.is_open()) {
      of << content;
      of.close();
    } else {
      LOG(WARNING) << "Failed to open file: " << dump_file
                   << ", please check your path.";
    }
  }
}
SourceCodePrint::SourceCodePrint() { SourceCodePrint::SourceCodePrint() {
if (!FLAGS_cinn_source_code_save_path.empty()) { if (!FLAGS_cinn_source_code_save_path.empty()) {
LOG(INFO) LOG(INFO)
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/cinn/backends/llvm/codegen_llvm.h" #include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/execution_engine.h" #include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/simple_jit.h" #include "paddle/cinn/backends/llvm/simple_jit.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
#include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/lang/packed_func.h"
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_module.h" #include "paddle/cinn/runtime/cuda/cuda_module.h"
...@@ -32,6 +33,38 @@ ...@@ -32,6 +33,38 @@
namespace cinn { namespace cinn {
namespace backends { namespace backends {
/**
 * A class for dumping each stage of the compilation result by fusion group.
 *
 * Dumping is driven entirely by flags; each stage is written only when the
 * corresponding flag names a non-empty directory:
 *  - FLAGS_cinn_dump_group_lowered_func : lowered functions
 *  - FLAGS_cinn_dump_group_source_code  : generated source code
 *  - FLAGS_cinn_dump_group_ptx          : compiled ptx
 *  - FLAGS_cinn_dump_group_instruction  : instructions
 *
 * All dumping happens in the constructor, so constructing an instance is the
 * whole operation; the object keeps no state callers need afterwards.
 */
class CompilationInfoDumper {
 public:
  // Dumps all enabled stages of `info` immediately. `info` is held by
  // reference and must outlive the constructor call.
  explicit CompilationInfoDumper(
      const hlir::framework::ParallelCompiler::CompilationResult& info)
      : info_(info) {
    DumpLoweredFunc();
    DumpSourceCode();
    DumpPtxCode();
    DumpInstruction();
  }

 private:
  // One dump routine per compilation stage; each is a no-op when its flag
  // is unset.
  void DumpLoweredFunc();
  void DumpSourceCode();
  void DumpPtxCode();
  void DumpInstruction();
  // Shared helper: writes `content` to
  // "<base_path>/fusion_group_<idx>/<file_name>".
  void Dump(const std::string& base_path,
            const int idx,
            const std::string& file_name,
            const std::string& content);

  // Borrowed compilation result to dump; not owned.
  const hlir::framework::ParallelCompiler::CompilationResult& info_;
};
class SourceCodePrint { class SourceCodePrint {
public: public:
static SourceCodePrint* GetInstance() { static SourceCodePrint* GetInstance() {
......
...@@ -308,66 +308,40 @@ void Graph::VisualizeGroupedGraph( ...@@ -308,66 +308,40 @@ void Graph::VisualizeGroupedGraph(
return; return;
} }
int viz_id = viz_count_.fetch_add(1); // Dump debug info for each group
{ LOG(INFO) << "Dump graph debug info to: "
// create base Directory << FLAGS_cinn_fusion_groups_graphviz_dir;
viz_path_ =
utils::StringFormat("%s/fusion_groups_%d/",
FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
viz_id);
if (!MakeDirectory(viz_path_,
S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
LOG_IF(WARNING, viz_id == 0)
<< "Failed to make directory: \"" << viz_path_
<< "\", the CINN subgraph's fusion group information will not print.";
viz_path_.clear();
return;
}
LOG_IF(INFO, viz_id == 0) << "The CINN subgraph's fusion group information "
"will writing into path: \""
<< FLAGS_cinn_fusion_groups_graphviz_dir << "\"";
}
const auto& groups = RemoveAccCheckGroups(origin_groups); const auto& groups = RemoveAccCheckGroups(origin_groups);
{ const auto& group_dots = VisualizeGroups(groups, fetch_var_ids);
// save python test file for (int idx = 0; idx < groups.size(); ++idx) {
std::string py_test_path = viz_path_ + "/tests/"; // Create fusion_group_x folder
if (!MakeDirectory(py_test_path, auto group_path =
S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { utils::StringFormat("%s/fusion_group_%d",
LOG_IF(WARNING, viz_id == 0) FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
<< "Failed to make directory: \"" << py_test_path idx);
<< "\", the CINN subgraph's python test file will not generate.";
py_test_path.clear();
}
if (!py_test_path.empty()) {
for (int i = 0; i < groups.size(); i++) {
WriteToFile(py_test_path + "test_group_" + std::to_string(i) + ".py",
GenerateGroupPythonCode(groups[i], fetch_var_ids));
}
}
}
Summary(groups, viz_path_);
WriteToFile(viz_path_ + "grouped_graph.dot",
VisualizeGraph(groups, fetch_var_ids));
{
// save each group's graphviz dot file
std::string group_path = viz_path_ + "/groups/";
if (!MakeDirectory(group_path, if (!MakeDirectory(group_path,
S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) { S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
LOG_IF(WARNING, viz_id == 0) LOG(WARNING) << "Failed to make directory: \"" << group_path
<< "Failed to make directory: \"" << group_path << "\", skip dump info for this group.";
<< "\", the CINN subgraph's group graphviz file will not save."; continue;
group_path.clear();
}
if (!group_path.empty()) {
const auto& group_dots = VisualizeGroups(groups, fetch_var_ids);
for (int i = 0; i < group_dots.size(); ++i) {
WriteToFile(GetFilePathForGroup(groups, i, group_path), group_dots[i]);
}
}
} }
// Create test_group_x.py
auto python_test_file =
utils::StringFormat("%s/test_group_%d.py", group_path.c_str(), idx);
WriteToFile(python_test_file,
GenerateGroupPythonCode(groups[idx], fetch_var_ids));
// Create x_group_name.dot
auto graph_group_file =
utils::StringFormat("%s/graph_group_%d.dot", group_path.c_str(), idx);
WriteToFile(graph_group_file, group_dots[idx]);
}
// Summary
Summary(groups, FLAGS_cinn_fusion_groups_graphviz_dir);
// Grouped graph
auto grouped_graph_file = utils::StringFormat(
"%s/grouped_graph.dot", FLAGS_cinn_fusion_groups_graphviz_dir.c_str());
WriteToFile(grouped_graph_file, VisualizeGraph(groups, fetch_var_ids));
} }
std::string Graph::VisualizeGraph( std::string Graph::VisualizeGraph(
...@@ -494,8 +468,6 @@ std::vector<std::string> Graph::VisualizeGroups( ...@@ -494,8 +468,6 @@ std::vector<std::string> Graph::VisualizeGroups(
return dot_vec; return dot_vec;
} }
std::atomic_size_t Graph::viz_count_{0};
std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() { std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
std::unordered_set<NodeData*> group_inputs; std::unordered_set<NodeData*> group_inputs;
...@@ -543,25 +515,6 @@ std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() { ...@@ -543,25 +515,6 @@ std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() {
return group_outputs; return group_outputs;
} }
// Saves generated CUDA source code under the visualization directory.
// Does nothing when group visualization is disabled or no viz path was set up.
void Graph::SaveSourceCode(const std::string& code) {
  const bool viz_disabled = cinn::runtime::CheckStringFlagFalse(
      FLAGS_cinn_fusion_groups_graphviz_dir);
  if (viz_disabled || viz_path_.empty()) {
    return;
  }
  WriteToFile(viz_path_ + "source_code.cu", code);
}
// Saves compiled PTX under the visualization directory.
// Does nothing when group visualization is disabled or no viz path was set up.
void Graph::SavePTXCode(const std::string& ptx) {
  const bool viz_disabled = cinn::runtime::CheckStringFlagFalse(
      FLAGS_cinn_fusion_groups_graphviz_dir);
  if (viz_disabled || viz_path_.empty()) {
    return;
  }
  WriteToFile(viz_path_ + "source_code.ptx", ptx);
}
} // namespace framework } // namespace framework
} // namespace hlir } // namespace hlir
} // namespace cinn } // namespace cinn
...@@ -283,9 +283,6 @@ class Graph : public cinn::common::Graph { ...@@ -283,9 +283,6 @@ class Graph : public cinn::common::Graph {
const std::vector<std::vector<Node*>>& groups, const std::vector<std::vector<Node*>>& groups,
const std::unordered_set<std::string>& fetch_var_ids = {}); const std::unordered_set<std::string>& fetch_var_ids = {});
void SaveSourceCode(const std::string& code);
void SavePTXCode(const std::string& ptx);
private: private:
std::string DebugGroupedGraph( std::string DebugGroupedGraph(
const std::vector<std::vector<Node*>>& groups, const std::vector<std::vector<Node*>>& groups,
...@@ -301,9 +298,6 @@ class Graph : public cinn::common::Graph { ...@@ -301,9 +298,6 @@ class Graph : public cinn::common::Graph {
std::vector<std::vector<Node*>> FusionGroupsToGroups(); std::vector<std::vector<Node*>> FusionGroupsToGroups();
std::string viz_path_;
static std::atomic_size_t viz_count_;
CINN_DISALLOW_COPY_AND_ASSIGN(Graph); CINN_DISALLOW_COPY_AND_ASSIGN(Graph);
}; };
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <unordered_set> #include <unordered_set>
#include "paddle/cinn/backends/codegen_cuda_dev.h" #include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/common/context.h" #include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/instruction.h" #include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_lowering_util.h" #include "paddle/cinn/hlir/framework/op_lowering_util.h"
...@@ -77,21 +78,24 @@ GraphCompiler::CompilationResult GraphCompiler::Build( ...@@ -77,21 +78,24 @@ GraphCompiler::CompilationResult GraphCompiler::Build(
parallel_compiler_ = parallel_compiler_ =
std::make_shared<ParallelCompiler>(scope_, graph_, option, target_); std::make_shared<ParallelCompiler>(scope_, graph_, option, target_);
auto instructions = (*parallel_compiler_.get())(); auto result = (*parallel_compiler_.get())();
// Dump compilation result
backends::CompilationInfoDumper dumper(result);
if (options.remove_unused_variables) { if (options.remove_unused_variables) {
RemoveInvalidVariables(instructions); RemoveInvalidVariables(result.instructions);
} }
if (options.with_buffer_handle_instruction_inserted) { if (options.with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable"; VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(&instructions); InsertBufferHandlers(&result.instructions);
} }
VLOG(2) << "Compile With Parallel Compiler Done!"; VLOG(2) << "Compile With Parallel Compiler Done!";
GraphCompiler::CompilationResult compilation_result; GraphCompiler::CompilationResult compilation_result;
compilation_result.runtime_program.reset( compilation_result.runtime_program.reset(
new Program(scope_, std::move(instructions))); new Program(scope_, std::move(result.instructions)));
return compilation_result; return compilation_result;
} }
......
...@@ -365,6 +365,29 @@ void Instruction::Run( ...@@ -365,6 +365,29 @@ void Instruction::Run(
// } // }
} }
// Renders this instruction as a human-readable multi-line string: for every
// registered function it lists the function name, its raw function pointer,
// and its input/output argument names (sorted for a stable, diff-friendly
// dump order).
std::string Instruction::DumpInstruction() {
  std::stringstream ss;
  ss << "Instruction {" << std::endl;
  for (size_t i = 0; i < fn_names_.size(); ++i) {
    ss << "  Function " << fn_names_[i] << ":" << std::endl;
    ss << "    function ptr: " << fn_ptrs_[i] << std::endl;
    // Sort a copy so the instruction's own argument order is untouched.
    auto in_arg = in_args_[i];
    std::sort(in_arg.begin(), in_arg.end());
    for (auto& in_name : in_arg) {
      ss << "    input: " << in_name << std::endl;
    }
    auto out_arg = out_args_[i];
    std::sort(out_arg.begin(), out_arg.end());
    for (auto& out_name : out_arg) {
      ss << "    output: " << out_name << std::endl;
    }
  }
  ss << "}" << std::endl;
  return ss.str();
}
void Instruction::CheckResults( void Instruction::CheckResults(
const std::map<std::string, cinn_pod_value_t>* name2podargs, void* stream) { const std::map<std::string, cinn_pod_value_t>* name2podargs, void* stream) {
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
......
...@@ -132,6 +132,8 @@ class Instruction { ...@@ -132,6 +132,8 @@ class Instruction {
int size() { return fn_ptrs_.size(); } int size() { return fn_ptrs_.size(); }
std::string DumpInstruction();
std::vector<std::vector<std::string>> GetInArgs() { return in_args_; } std::vector<std::vector<std::string>> GetInArgs() { return in_args_; }
std::vector<std::vector<std::string>> GetOutArgs() { return out_args_; } std::vector<std::vector<std::string>> GetOutArgs() { return out_args_; }
void ClearInArgs() { in_args_.clear(); } void ClearInArgs() { in_args_.clear(); }
......
...@@ -30,15 +30,13 @@ ...@@ -30,15 +30,13 @@
#include "paddle/cinn/ir/module.h" #include "paddle/cinn/ir/module.h"
#include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/runtime/flags.h"
DECLARE_int32(cinn_parallel_compile_size);
DECLARE_int32(cinn_parallel_compile_thread); DECLARE_int32(cinn_parallel_compile_thread);
namespace cinn { namespace cinn {
namespace hlir { namespace hlir {
namespace framework { namespace framework {
static constexpr int DebugLogMaxLen = 30000;
std::vector<std::unique_ptr<Instruction>> ParallelCompiler::operator()() { ParallelCompiler::CompilationResult ParallelCompiler::operator()() {
if (graph_->fusion_groups.size() == 0) { if (graph_->fusion_groups.size() == 0) {
hlir::framework::ApplyPasses(graph_.get(), {"BuildNonFusedGroupsPass"}); hlir::framework::ApplyPasses(graph_.get(), {"BuildNonFusedGroupsPass"});
} }
...@@ -50,48 +48,31 @@ std::vector<std::unique_ptr<Instruction>> ParallelCompiler::operator()() { ...@@ -50,48 +48,31 @@ std::vector<std::unique_ptr<Instruction>> ParallelCompiler::operator()() {
return MergeResult(); return MergeResult();
} }
// Looks up the registered OpPattern kind for `node`'s operator.
// Binary ops are registered as kBroadcast, but behave element-wise unless the
// op really is "broadcast_to", so that case is reported as kElementWise.
OpPatternKind GetOpKind(const framework::Node* node) {
  auto& op_pattern_dict =
      framework::Operator::GetAttrs<OpPatternKind>("OpPattern");
  CHECK(op_pattern_dict.Find(node->op()))
      << "Don't find the pattern of op : " << node->id();
  auto kind = op_pattern_dict[node->op()];
  // Collapse the original nested-if into one guard.
  if (kind == framework::kBroadcast && node->op()->name != "broadcast_to") {
    return framework::kElementWise;
  }
  return kind;
}
void ParallelCompiler::SplitTask() { void ParallelCompiler::SplitTask() {
CHECK(graph_->fusion_groups.size()); CHECK(graph_->fusion_groups.size());
CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() || CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() ||
option_.lowered_funcs.size() == 0); option_.lowered_funcs.size() == 0);
// split task // Assign fusion_group to each task.
int max_task_num = FLAGS_cinn_parallel_compile_thread > 0 // The maximum number of tasks is determined by the number of threads.
// Fusion_group is assigned to tasks in order and continuous.
int fusion_group_size = graph_->fusion_groups.size();
int thread_size = FLAGS_cinn_parallel_compile_thread > 0
? FLAGS_cinn_parallel_compile_thread ? FLAGS_cinn_parallel_compile_thread
: graph_->fusion_groups.size(); : 1;
int group_per_task =
int group_per_task = graph_->fusion_groups.size(); (graph_->fusion_groups.size() + thread_size - 1) / thread_size;
if (max_task_num > 1) {
group_per_task = FLAGS_cinn_parallel_compile_size > 0
? FLAGS_cinn_parallel_compile_size
: ((graph_->fusion_groups.size() + max_task_num - 1) /
max_task_num);
}
for (int idx = 0; idx < graph_->fusion_groups.size(); idx += group_per_task) { for (int idx = 0; idx < graph_->fusion_groups.size(); idx += group_per_task) {
tasks_.emplace_back(this, scope_, graph_, option_, target_); Task task(this, scope_, graph_, option_, target_);
task.start_gidx = idx;
task.stop_gidx =
(idx + group_per_task > fusion_group_size ? fusion_group_size
: idx + group_per_task);
tasks_.emplace_back(std::move(task));
} }
VLOG(2) << "Split task to " << tasks_.size() << " sub-task!"; VLOG(2) << "Split task to " << tasks_.size() << " sub-task!";
} }
void RunTask(ParallelCompiler::Task* task) { void ParallelCompiler::RunTask(ParallelCompiler::Task* task) {
VLOG(2) << "Stark run sub-task, Thread Id : " << std::this_thread::get_id(); VLOG(2) << "Stark run sub-task, Thread Id : " << std::this_thread::get_id();
VLOG(4) << "Start Lowering"; VLOG(4) << "Start Lowering";
task->Lowering(); task->Lowering();
...@@ -106,7 +87,7 @@ void ParallelCompiler::LaunchTask() { ...@@ -106,7 +87,7 @@ void ParallelCompiler::LaunchTask() {
// start sub-task. // start sub-task.
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (int idx = 1; idx < tasks_.size(); ++idx) { for (int idx = 1; idx < tasks_.size(); ++idx) {
threads.emplace_back(RunTask, &tasks_[idx]); threads.emplace_back(&ParallelCompiler::RunTask, this, &tasks_[idx]);
} }
RunTask(&tasks_[0]); RunTask(&tasks_[0]);
...@@ -116,11 +97,20 @@ void ParallelCompiler::LaunchTask() { ...@@ -116,11 +97,20 @@ void ParallelCompiler::LaunchTask() {
} }
} }
std::vector<std::unique_ptr<Instruction>> ParallelCompiler::MergeResult() { ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() {
std::vector<std::unique_ptr<Instruction>> res(graph_->fusion_groups.size()); ParallelCompiler::CompilationResult res;
for (auto& task : tasks_) { for (auto& task : tasks_) {
for (int idx = 0; idx < task.gidx.size(); ++idx) { for (auto& lowered_func : task.lowered_funcs) {
res[task.gidx[idx]] = std::move(task.instructions[idx]); res.lowered_funcs.emplace_back(lowered_func);
}
for (auto& source_code : task.source_codes) {
res.source_codes.emplace_back(source_code);
}
for (auto& source_ptx : task.source_ptxs) {
res.source_ptxs.emplace_back(source_ptx);
}
for (auto& instruction : task.instructions) {
res.instructions.emplace_back(std::move(instruction));
} }
} }
return std::move(res); return std::move(res);
...@@ -138,13 +128,7 @@ void ParallelCompiler::Task::Lowering() { ...@@ -138,13 +128,7 @@ void ParallelCompiler::Task::Lowering() {
"infershape"); "infershape");
OpLowerer op_lowerer(dtype_dict, shape_dict, target); OpLowerer op_lowerer(dtype_dict, shape_dict, target);
while (true) { for (int idx = start_gidx; idx < stop_gidx; ++idx) {
int idx = compiler->GetGroupIdx();
if (idx < 0) {
break;
}
gidx.push_back(idx);
if (options.lowered_funcs.size()) { if (options.lowered_funcs.size()) {
lowered_funcs.push_back(options.lowered_funcs[idx]); lowered_funcs.push_back(options.lowered_funcs[idx]);
continue; continue;
...@@ -154,16 +138,15 @@ void ParallelCompiler::Task::Lowering() { ...@@ -154,16 +138,15 @@ void ParallelCompiler::Task::Lowering() {
<< std::this_thread::get_id() << " :\n" << std::this_thread::get_id() << " :\n"
<< "Group " << idx << " {\n" << "Group " << idx << " {\n"
<< graph->DebugGroupedGraph(group->CollectNodes()) << "}\n"; << graph->DebugGroupedGraph(group->CollectNodes()) << "}\n";
lowered_funcs.emplace_back(std::move(op_lowerer.Lower(group))); auto lowered_group = op_lowerer.Lower(group);
CHECK_EQ(lowered_funcs.back().size(), 1) CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!";
<< "Lowerd Function Is Not Equal 1!"; lowered_funcs.emplace_back(std::move(lowered_group));
} }
} }
void ParallelCompiler::Task::CodegenAndJit() { void ParallelCompiler::Task::CodegenAndJit() {
VLOG(2) << "Start Codegen and JIT with Group [" VLOG(2) << "Start Codegen and JIT with Group [" << start_gidx << "-"
<< cinn::utils::Join(this->gidx, ", ") << "] at " << stop_gidx << ") at thread" << std::this_thread::get_id();
<< std::this_thread::get_id();
// build module // build module
ir::Module::Builder builder(common::UniqName("module"), target); ir::Module::Builder builder(common::UniqName("module"), target);
for (auto& func : lowered_funcs) { for (auto& func : lowered_funcs) {
...@@ -172,7 +155,6 @@ void ParallelCompiler::Task::CodegenAndJit() { ...@@ -172,7 +155,6 @@ void ParallelCompiler::Task::CodegenAndJit() {
} }
auto ir_module = builder.Build(); auto ir_module = builder.Build();
// codegen compile
if (target == common::DefaultNVGPUTarget()) { if (target == common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
auto splited_module = backends::SplitCudaAndHostModule(ir_module); auto splited_module = backends::SplitCudaAndHostModule(ir_module);
...@@ -185,14 +167,15 @@ void ParallelCompiler::Task::CodegenAndJit() { ...@@ -185,14 +167,15 @@ void ParallelCompiler::Task::CodegenAndJit() {
auto cuda_c = codegen.Compile(dmodule); auto cuda_c = codegen.Compile(dmodule);
CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n" CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
<< dmodule; << dmodule;
source_codes.emplace_back(cuda_c);
cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c); cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
graph->SaveSourceCode(cuda_c);
using runtime::cuda::CUDAModule; using runtime::cuda::CUDAModule;
backends::nvrtc::Compiler compiler; backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c); auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
source_ptxs.emplace_back(ptx);
// load cumodule // load cumodule
cumodule.reset(new CUDAModule(ptx, cumodule.reset(new CUDAModule(ptx,
compiler.compile_to_cubin() compiler.compile_to_cubin()
...@@ -218,7 +201,7 @@ void ParallelCompiler::Task::CodegenAndJit() { ...@@ -218,7 +201,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
void ParallelCompiler::Task::BuildInstruction() { void ParallelCompiler::Task::BuildInstruction() {
// create instruction. // create instruction.
for (int idx : gidx) { for (int idx = start_gidx; idx < stop_gidx; ++idx) {
VLOG(2) << "Start BuildInstruction of Group " << idx << " at " VLOG(2) << "Start BuildInstruction of Group " << idx << " at "
<< std::this_thread::get_id(); << std::this_thread::get_id();
auto& group = graph->fusion_groups[idx]; auto& group = graph->fusion_groups[idx];
...@@ -240,15 +223,6 @@ void ParallelCompiler::Task::BuildInstruction() { ...@@ -240,15 +223,6 @@ void ParallelCompiler::Task::BuildInstruction() {
} }
} }
// Hands out the next unclaimed fusion-group index, or -1 once all groups have
// been taken. Guarded by mtx_ so worker threads may call it concurrently.
int ParallelCompiler::GetGroupIdx() {
  std::lock_guard<std::mutex> lock(mtx_);
  if (index >= graph_->fusion_groups.size()) {
    return -1;
  }
  return index++;
}
} // namespace framework } // namespace framework
} // namespace hlir } // namespace hlir
} // namespace cinn } // namespace cinn
...@@ -35,23 +35,18 @@ class ParallelCompiler { ...@@ -35,23 +35,18 @@ class ParallelCompiler {
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs; std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
}; };
public: struct CompilationResult {
explicit ParallelCompiler(std::shared_ptr<Scope>& scope, // NOLINT // Lower result
std::shared_ptr<Graph>& graph, // NOLINT std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
const CompileOptions& option, // Host/CUDA codegen result
const common::Target& target) std::vector<std::string> source_codes;
: scope_(scope), graph_(graph), option_(option), target_(target) {} // CUDA ptx result
~ParallelCompiler() {} std::vector<std::string> source_ptxs;
std::vector<std::unique_ptr<Instruction>> operator()(); // Instruction result
std::vector<std::unique_ptr<Instruction>> instructions;
private: };
void SplitTask();
void LaunchTask();
std::vector<std::unique_ptr<Instruction>> MergeResult();
public:
struct Task { struct Task {
public:
Task(ParallelCompiler* p, Task(ParallelCompiler* p,
std::shared_ptr<Scope>& s, // NOLINT std::shared_ptr<Scope>& s, // NOLINT
std::shared_ptr<Graph>& g, // NOLINT std::shared_ptr<Graph>& g, // NOLINT
...@@ -62,30 +57,40 @@ class ParallelCompiler { ...@@ -62,30 +57,40 @@ class ParallelCompiler {
void CodegenAndJit(); void CodegenAndJit();
void BuildInstruction(); void BuildInstruction();
public:
const Target target; const Target target;
ParallelCompiler* compiler; ParallelCompiler* compiler;
std::shared_ptr<Scope> scope; std::shared_ptr<Scope> scope;
std::shared_ptr<Graph> graph; std::shared_ptr<Graph> graph;
const CompileOptions& options; const CompileOptions& options;
std::vector<int> gidx; int start_gidx;
int stop_gidx;
std::vector<std::unique_ptr<Instruction>> instructions; std::vector<std::unique_ptr<Instruction>> instructions;
std::vector<std::vector<ir::LoweredFunc>> lowered_funcs; std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
std::vector<std::string> source_codes;
std::vector<std::string> source_ptxs;
public:
std::unique_ptr<backends::ExecutionEngine> engine; std::unique_ptr<backends::ExecutionEngine> engine;
#ifdef CINN_WITH_CUDA #ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cumodule; std::unique_ptr<runtime::cuda::CUDAModule> cumodule;
#endif #endif
}; };
std::vector<Task> tasks_;
int GetGroupIdx(); explicit ParallelCompiler(std::shared_ptr<Scope>& scope, // NOLINT
std::shared_ptr<Graph>& graph, // NOLINT
const CompileOptions& option,
const common::Target& target)
: scope_(scope), graph_(graph), option_(option), target_(target) {}
~ParallelCompiler() {}
CompilationResult operator()();
private: private:
int index{0}; void SplitTask();
std::mutex mtx_; void LaunchTask();
void RunTask(Task* task);
CompilationResult MergeResult();
std::vector<Task> tasks_;
const common::Target target_; const common::Target target_;
const CompileOptions& option_; const CompileOptions& option_;
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
......
...@@ -148,66 +148,30 @@ bool PassPrinter::End() { ...@@ -148,66 +148,30 @@ bool PassPrinter::End() {
} }
bool MakeDirectory(const std::string& dirname, mode_t mode) { bool MakeDirectory(const std::string& dirname, mode_t mode) {
auto len = dirname.length(); struct stat st;
std::vector<char> dir_path(len + 1, '\0'); std::string path;
strncpy(dir_path.data(), dirname.c_str(), len); for (int i = 0; i < dirname.size(); ++i) {
char* path = dir_path.data(); path.push_back(dirname[i]);
for (char* p = strchr(path + 1, '/'); p; p = strchr(p + 1, '/')) { if (!(dirname[i] == '/' || i + 1 == dirname.size())) {
*p = '\0'; continue;
if (mkdir(path, mode) == -1) { }
if (errno != EEXIST) { if (stat(path.c_str(), &st) == 0) {
*p = '/'; if (S_ISDIR(st.st_mode)) {
continue;
} else {
LOG(WARNING) << path << " is not a directory, please check your path.";
return false; return false;
} }
} } else {
*p = '/'; if (mkdir(path.c_str(), mode) == 0) {
} continue;
return true; } else {
} LOG(WARNING) << "Make directory fail: " << path;
return false;
std::string GetFilePathForGroup(const std::vector<std::vector<Node*>>& groups,
const int group_id,
const std::string& viz_path) {
std::string filename = "";
for (auto* node : groups[group_id]) {
filename += "_" + node->id();
}
int max_len = 50;
std::string simplified_filename = filename;
if (filename.size() > max_len) {
static std::unordered_map<std::string, std::string> funcname_map = {
{"const_scalar", "scalar"},
{"fill_constant", "fill"},
{"identity", "copy"},
{"broadcast_to", "broadcast"},
{"elementwise_add", "add"},
{"subtract", "sub"},
{"elementwise_mul", "mul"},
{"divide", "div"},
{"reduce_sum", "reduce"},
{"reduce_prod", "reduce"},
{"reduce_max", "reduce"},
{"reduce_min", "reduce"}};
for (auto& item : funcname_map) {
size_t index = 0;
while (true) {
index = simplified_filename.find(item.first, index);
if (index == std::string::npos) {
break;
}
simplified_filename.replace(index, item.first.size(), item.second);
index += item.second.size();
} }
} }
} }
return true;
int width = std::to_string(groups.size()).size();
std::stringstream ss;
ss << viz_path;
ss << std::setw(width) << std::setfill('0') << group_id;
ss << simplified_filename.substr(0, 50) << ".dot";
return ss.str();
} }
std::string GenNodeDataLabel( std::string GenNodeDataLabel(
...@@ -313,7 +277,7 @@ void Summary(const std::vector<std::vector<Node*>>& groups, ...@@ -313,7 +277,7 @@ void Summary(const std::vector<std::vector<Node*>>& groups,
<< "Numbers\n"; << "Numbers\n";
print_table(fusion_group_detail); print_table(fusion_group_detail);
std::string filepath = viz_path + "summary.txt"; std::string filepath = viz_path + "/summary.txt";
WriteToFile(filepath, ss.str()); WriteToFile(filepath, ss.str());
} }
......
...@@ -133,10 +133,6 @@ inline std::vector<utils::DotAttr> GetGroupAttrs(size_t group_size) { ...@@ -133,10 +133,6 @@ inline std::vector<utils::DotAttr> GetGroupAttrs(size_t group_size) {
bool MakeDirectory(const std::string& dirname, mode_t mode); bool MakeDirectory(const std::string& dirname, mode_t mode);
std::string GetFilePathForGroup(const std::vector<std::vector<Node*>>& groups,
const int group_id,
const std::string& viz_path);
std::string GenNodeDataLabel( std::string GenNodeDataLabel(
const NodeData* node, const NodeData* node,
const absl::flat_hash_map<std::string, shape_t>& shape_dict, const absl::flat_hash_map<std::string, shape_t>& shape_dict,
......
...@@ -44,13 +44,8 @@ DEFINE_string(cinn_nvcc_cmd_path, ...@@ -44,13 +44,8 @@ DEFINE_string(cinn_nvcc_cmd_path,
StringFromEnv("FLAGS_cinn_nvcc_cmd_path", "/usr/local/cuda/bin"), StringFromEnv("FLAGS_cinn_nvcc_cmd_path", "/usr/local/cuda/bin"),
"Setting nvcc default path!"); "Setting nvcc default path!");
DEFINE_int32(cinn_parallel_compile_size,
Int32FromEnv("FLAGS_cinn_parallel_compile_size", 16),
"When use parallel compile, set the number of group compiled by "
"each thread.");
DEFINE_int32(cinn_parallel_compile_thread, DEFINE_int32(cinn_parallel_compile_thread,
Int32FromEnv("FLAGS_cinn_parallel_compile_thread", -1), Int32FromEnv("FLAGS_cinn_parallel_compile_thread", 16),
"How much thread the parallel compile used."); "How much thread the parallel compile used.");
DEFINE_bool(cinn_use_op_fusion, DEFINE_bool(cinn_use_op_fusion,
...@@ -131,6 +126,26 @@ DEFINE_string(cinn_source_code_save_path, ...@@ -131,6 +126,26 @@ DEFINE_string(cinn_source_code_save_path,
"Specify the directory path of generated source code, which is " "Specify the directory path of generated source code, which is "
"used for debug."); "used for debug.");
// Debug flags: each names a directory under which the corresponding
// per-fusion-group compilation artifact is dumped; empty (the default)
// disables that dump.
DEFINE_string(cinn_dump_group_lowered_func,
              StringFromEnv("FLAGS_cinn_dump_group_lowered_func", ""),
              "Specify the path for dump lowered functions by group, which is "
              "used for debug.");

DEFINE_string(
    cinn_dump_group_source_code,
    StringFromEnv("FLAGS_cinn_dump_group_source_code", ""),
    "Specify the path for dump source code by group, which is used for debug.");

DEFINE_string(
    cinn_dump_group_ptx,
    StringFromEnv("FLAGS_cinn_dump_group_ptx", ""),
    "Specify the path for dump ptx by group, which is used for debug.");

DEFINE_string(
    cinn_dump_group_instruction,
    StringFromEnv("FLAGS_cinn_dump_group_instruction", ""),
    "Specify the path for dump instruction by group, which is used for debug.");
DEFINE_string(cinn_pass_visualize_dir, DEFINE_string(cinn_pass_visualize_dir,
StringFromEnv("FLAGS_cinn_pass_visualize_dir", ""), StringFromEnv("FLAGS_cinn_pass_visualize_dir", ""),
"Specify the directory path of pass visualize file of graph, " "Specify the directory path of pass visualize file of graph, "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册