Unverified commit 04b6035d authored by 傅剑寒, committed by GitHub

modify default parallel compile thread (#56282)

Parent 47686692
@@ -92,9 +92,7 @@ void ParallelCompiler::LaunchTask() {
   RunTask(&tasks_[0]);
   // syncthreads.
-  for (auto& worker : threads) {
-    worker.join();
-  }
+  for_each(threads.begin(), threads.end(), std::mem_fn(&std::thread::join));
 }

 ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() {
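The per-thread join loop is collapsed into a single `for_each` call, using `std::mem_fn` to adapt `std::thread::join` into a callable. A minimal standalone sketch of the same idiom (the worker lambda and thread count are illustrative, not taken from the CINN sources):

```cpp
#include <algorithm>
#include <functional>
#include <thread>
#include <vector>

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back([] { /* worker body would run one compile task */ });
  }
  // std::mem_fn wraps the member function std::thread::join so that
  // std::for_each can invoke it on every element of the vector.
  std::for_each(threads.begin(), threads.end(),
                std::mem_fn(&std::thread::join));
  return 0;
}
```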
@@ -113,7 +111,7 @@ ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() {
       res.instructions.emplace_back(std::move(instruction));
     }
   }
-  return std::move(res);
+  return res;
 }

 void ParallelCompiler::Task::Lowering() {
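Dropping the `std::move` on the return statement is the right call: returning a named local with `std::move` blocks named return value optimization (NRVO), while a plain `return res;` lets the compiler elide the copy/move entirely (and still moves as a fallback). A small illustration with a hypothetical stand-in for the result type:

```cpp
#include <utility>
#include <vector>

struct Result {                // stand-in for CompilationResult
  std::vector<int> instructions;
};

Result BuildPessimized() {
  Result res;
  return std::move(res);       // disables NRVO; always performs a move
}

Result BuildOptimal() {
  Result res;
  return res;                  // NRVO may construct res directly in the caller
}

int main() {
  Result a = BuildPessimized();
  Result b = BuildOptimal();
  return static_cast<int>(a.instructions.size() + b.instructions.size());
}
```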
@@ -167,7 +165,6 @@ void ParallelCompiler::Task::CodegenAndJit() {
   auto cuda_c = codegen.Compile(dmodule);
   CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
                          << dmodule;
-  source_codes.emplace_back(cuda_c);
   cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
@@ -175,13 +172,16 @@ void ParallelCompiler::Task::CodegenAndJit() {
   backends::nvrtc::Compiler compiler;
   auto ptx = compiler(cuda_c);
   CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
-  source_ptxs.emplace_back(ptx);
   // load cumodule
   cumodule.reset(new CUDAModule(ptx,
                                 compiler.compile_to_cubin()
                                     ? CUDAModule::Kind::CUBIN
                                     : CUDAModule::Kind::PTX));
+  source_codes.emplace_back(std::move(cuda_c));
+  source_ptxs.emplace_back(std::move(ptx));
   // register kernel
   backends::RuntimeSymbols symbols;
   for (auto& fn : dmodule.functions()) {
......
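Within `CodegenAndJit`, the `source_codes` / `source_ptxs` bookkeeping is deferred until after the last reads of `cuda_c` and `ptx` (the NVRTC compile and the `CUDAModule` construction), so both strings can be moved into the vectors instead of copied. A hedged sketch of the pattern with hypothetical stand-in steps:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the codegen and NVRTC compile steps.
std::string GenerateCudaC() { return "__global__ void k() {}"; }
std::string CompileToPtx(const std::string& cuda_c) { return "// ptx for: " + cuda_c; }

int main() {
  std::vector<std::string> source_codes;
  std::vector<std::string> source_ptxs;

  std::string cuda_c = GenerateCudaC();
  std::string ptx = CompileToPtx(cuda_c);  // last read of cuda_c
  // ... the real task builds a CUDA module from ptx here (last read of ptx) ...

  // Recording the buffers only after their last read makes the moves safe
  // and avoids copying two potentially large strings.
  source_codes.emplace_back(std::move(cuda_c));
  source_ptxs.emplace_back(std::move(ptx));

  std::cout << source_codes.size() << " source(s), "
            << source_ptxs.size() << " ptx blob(s)\n";
  return 0;
}
```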
@@ -19,7 +19,7 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 #include <unistd.h>
+#include <thread>
 #include <unordered_set>
 #include "paddle/cinn/common/target.h"
@@ -46,7 +46,8 @@ DEFINE_string(cinn_nvcc_cmd_path,
              "Setting nvcc default path!");
 DEFINE_int32(cinn_parallel_compile_thread,
-             Int32FromEnv("FLAGS_cinn_parallel_compile_thread", 16),
+             Int32FromEnv("FLAGS_cinn_parallel_compile_thread",
+                          (std::thread::hardware_concurrency() >> 1)),
              "How much thread the parallel compile used.");
 DEFINE_bool(cinn_use_op_fusion,
......
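The default for `FLAGS_cinn_parallel_compile_thread` changes from a hard-coded 16 to half of the detected hardware concurrency, which is why `<thread>` is now included in this file. A small sketch of the same default computation, with an added guard for the case where `std::thread::hardware_concurrency()` reports 0 (the flag definition itself does not guard this):

```cpp
#include <iostream>
#include <thread>

int main() {
  // hardware_concurrency() may return 0 when the value is not computable.
  unsigned int hw = std::thread::hardware_concurrency();
  unsigned int default_threads = hw > 1 ? (hw >> 1) : 1;
  std::cout << "default parallel compile threads: " << default_threads << "\n";
  return 0;
}
```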