From 04b6035d171ccec1126d59b4b34c46e4847333f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Tue, 15 Aug 2023 17:22:42 +0800 Subject: [PATCH] modify default parallel compile thread (#56282) --- paddle/cinn/hlir/framework/parallel_compiler.cc | 12 ++++++------ paddle/cinn/runtime/flags.cc | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/hlir/framework/parallel_compiler.cc b/paddle/cinn/hlir/framework/parallel_compiler.cc index 1b2cbca9e05..ac3ab03e588 100644 --- a/paddle/cinn/hlir/framework/parallel_compiler.cc +++ b/paddle/cinn/hlir/framework/parallel_compiler.cc @@ -92,9 +92,7 @@ void ParallelCompiler::LaunchTask() { RunTask(&tasks_[0]); // syncthreads. - for (auto& worker : threads) { - worker.join(); - } + for_each(threads.begin(), threads.end(), std::mem_fn(&std::thread::join)); } ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() { @@ -113,7 +111,7 @@ ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() { res.instructions.emplace_back(std::move(instruction)); } } - return std::move(res); + return res; } void ParallelCompiler::Task::Lowering() { @@ -167,7 +165,6 @@ void ParallelCompiler::Task::CodegenAndJit() { auto cuda_c = codegen.Compile(dmodule); CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n" << dmodule; - source_codes.emplace_back(cuda_c); cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c); @@ -175,13 +172,16 @@ void ParallelCompiler::Task::CodegenAndJit() { backends::nvrtc::Compiler compiler; auto ptx = compiler(cuda_c); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; - source_ptxs.emplace_back(ptx); + // load cumodule cumodule.reset(new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); + source_codes.emplace_back(std::move(cuda_c)); + source_ptxs.emplace_back(std::move(ptx)); + // register kernel backends::RuntimeSymbols symbols; for (auto& fn : dmodule.functions()) { diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 98d1d41bd16..91bf9d1f47c 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -19,7 +19,7 @@ #include #include #include - +#include #include #include "paddle/cinn/common/target.h" @@ -46,7 +46,8 @@ DEFINE_string(cinn_nvcc_cmd_path, "Setting nvcc default path!"); DEFINE_int32(cinn_parallel_compile_thread, - Int32FromEnv("FLAGS_cinn_parallel_compile_thread", 16), + Int32FromEnv("FLAGS_cinn_parallel_compile_thread", + (std::thread::hardware_concurrency() >> 1)), "How much thread the parallel compile used."); DEFINE_bool(cinn_use_op_fusion, -- GitLab