Unverified commit dd3d45de, authored by Leo Chen, committed by GitHub

reduce time cost on atomic in interpretercore (#46688)

* reduce time cost on atomic in interpretercore

* clear code of PrepareAtomic in interpretercore

* refine threadpool cache
Parent: c333af2f
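In short, this change drops the per-run preparation of std::vector<std::atomic<size_t>> buffers (and the dedicated "Prepare" work-queue thread that refilled them) and instead keeps one OpDepInfo per op and one VarRefInfo per variable alive across runs; each run decrements them in place and a ResetAtomicGuard restores them when the run finishes. The sketch below is illustrative only, not the Paddle classes themselves; it shows the two ideas the new counters rely on: the static count is stored alongside the atomic one so the object can be reused, and a static count of 1 short-circuits the atomic read-modify-write entirely.

// Minimal standalone counter (an illustration, not Paddle's OpDepInfo itself).
#include <atomic>
#include <cstddef>
#include <cstdio>

class ReusableCounter {
 public:
  explicit ReusableCounter(size_t n) : static_count_(n), dynamic_count_(n) {}

  // True when the caller consumed the last outstanding dependency/reference.
  bool CheckAndDecrease() {
    return static_count_ == 1 || dynamic_count_.fetch_sub(1) == 1;
  }

 private:
  const size_t static_count_;
  std::atomic<size_t> dynamic_count_;
};

int main() {
  ReusableCounter single(1), triple(3);
  std::printf("%d\n", single.CheckAndDecrease());  // 1: fast path, no atomic op
  for (int i = 0; i < 3; ++i) {
    std::printf("%d", triple.CheckAndDecrease());  // prints 0, 0, then 1
  }
  std::printf("\n");
  return 0;
}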
@@ -326,7 +326,9 @@ inline void RunProgramAPI(
       paddle::framework::InterpreterCoreInfoCache::Instance();
   if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/false)) {
-    VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
+    VLOG(2) << "No interpretercore cahce, so create a new interpretercore "
+               "for program: "
+            << program_id;
     // Step 1. share input_vars & parameters into scope
     details::ShareTensorsIntoScope(x, global_inner_scope);
     details::ShareTensorsIntoScope(params, global_inner_scope);
@@ -545,19 +547,14 @@ inline void RunProgramGradAPI(
   // share threadpool
   // NOTE(zhiqiu): this only works interpreter_core is executed strictly
   // after the related fwd_interpreter_core.
-  PADDLE_ENFORCE_EQ(
-      interpretercore_info_cache.Has(program_id, false),
-      true,
-      paddle::platform::errors::NotFound(
-          "The forward interpretercore of program %d is not found",
-          program_id));
-  auto fwd_interpreter_core =
-      interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
-          .core_;
-  interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
-  VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
-          << interpreter_core.get();
+  if (interpretercore_info_cache.Has(program_id, false)) {
+    auto fwd_interpreter_core =
+        interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
+            .core_;
+    interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
+    VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get()
+            << " to " << interpreter_core.get();
+  }
   // get all eager gc vars
   std::set<std::string> skip_eager_delete_vars;
   // all out_vars are skip_eager_var
......
@@ -295,6 +295,8 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
   auto &interpretercore_info_cache =
       framework::InterpreterCoreInfoCache::Instance();
   if (interpretercore_info_cache.Size() > 4u /* max_cached_size*/) {
+    VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
+               "all cache!";
     interpretercore_info_cache.Finalize();
   }
   auto core = std::make_shared<InterpreterCore>(
......
@@ -21,11 +21,13 @@ namespace framework {
 namespace interpreter {
 void WaitEvent(const Instruction& instruction, const platform::Place& place) {
   // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place)) return;
+  if (platform::is_cpu_place(place)) {
+    return;
+  }
   VLOG(3) << "Deal StreamWaitEventOrSync for " << instruction.OpBase()->Type();
-  for (auto& event_iter : instruction.InputEvents()) {
+  for (const auto& event_iter : instruction.InputEvents()) {
     platform::RecordEvent record(
         "WaitStreamEvent", platform::TracerEventType::UserDefined, 10);
     VLOG(3) << "wait var_id: " << event_iter.var_id_
@@ -37,9 +39,11 @@ void WaitEvent(const Instruction& instruction, const platform::Place& place) {
 void RecordEvent(const Instruction& instruction, const platform::Place& place) {
   // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place)) return;
+  if (platform::is_cpu_place(place)) {
+    return;
+  }
-  for (auto& event : instruction.OutputEvents()) {
+  for (const auto& event : instruction.OutputEvents()) {
     platform::RecordEvent record(
         "RecordStreamEvent", platform::TracerEventType::UserDefined, 10);
     VLOG(3) << "Record event in out_var_id: " << event.var_id_;
......
@@ -32,32 +32,19 @@ namespace interpreter {
 static constexpr size_t kHostNumThreads = 4;
 static constexpr size_t kDeviceNumThreads = 1;
 static constexpr size_t kNumGcThreads = 1;
-static constexpr size_t kNumPrepareThreads = 0;
-static constexpr size_t kMinOpNumForAsyncPrepare = 1000;
 // By default, one interpretercore contains:
 // 1-size thread pool for device kernel launch (or 0 for cpu execution),
 // 1-size thread pool for host kernel launch (or more if the system contains
 // enough processors).
-// And it may contain:
-// 1-size thread pool for gc if it is can not use FastGC,
-// 1-size thread pool for preparation if the program contains two many ops
-// (1000+).
 // Note that the purpose of the config is to limit the total 'possible'
 // threads introduced by interpretercore to avoid hurting performance.
-inline std::tuple<int, int, int> GetThreadPoolConfig(const phi::Place& place,
-                                                     size_t op_num) {
+inline std::tuple<int, int> GetThreadPoolConfig(const phi::Place& place,
+                                                size_t op_num) {
   int num_device_threads = kDeviceNumThreads,
-      num_host_threads = kHostNumThreads,
-      num_prepare_threads = kNumPrepareThreads;
-  if (op_num > kMinOpNumForAsyncPrepare) {
-    num_prepare_threads = 1;
-  }
+      num_host_threads = kHostNumThreads;
   int device_count = 0, processor_count = 0;
   if (platform::is_cpu_place(place)) {
@@ -109,7 +96,7 @@ inline std::tuple<int, int, int> GetThreadPoolConfig(const phi::Place& place,
   if (device_count) {
     auto num = processor_count / device_count / 2 -
-               (kNumGcThreads + kNumPrepareThreads + num_device_threads);
+               (kNumGcThreads + num_device_threads);
     num_host_threads =
         num > 0 ? (num > kHostNumThreads ? kHostNumThreads : num) : 1;
   }
@@ -126,14 +113,13 @@ inline std::tuple<int, int, int> GetThreadPoolConfig(const phi::Place& place,
           << ", device_count:" << device_count
           << ", serial_run:" << FLAGS_new_executor_serial_run
           << ", num_host_threads:" << num_host_threads
-          << ", num_device_threads:" << num_device_threads
-          << ", num_prepare_threads:" << num_prepare_threads;
-  return std::make_tuple(
-      num_host_threads, num_device_threads, num_prepare_threads);
+          << ", num_device_threads:" << num_device_threads;
+  return std::make_tuple(num_host_threads, num_device_threads);
 }
 ExecutionConfig::ExecutionConfig(const phi::Place& place, size_t op_num) {
-  std::tie(host_num_threads, deivce_num_threads, prepare_num_threads) =
+  std::tie(host_num_threads, deivce_num_threads) =
       GetThreadPoolConfig(place, op_num);
 }
@@ -143,7 +129,6 @@ void ExecutionConfig::Log(int log_level) {
   VLOG(log_level) << "create_local_scope = " << create_local_scope;
   VLOG(log_level) << "host_num_threads = " << host_num_threads;
   VLOG(log_level) << "deivce_num_threads = " << deivce_num_threads;
-  VLOG(log_level) << "prepare_num_threads = " << prepare_num_threads;
   VLOG(log_level) << "skip_gc_vars = ";
   for (const std::string& var : skip_gc_vars) {
     VLOG(log_level) << var;
......
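For reference, with the Prepare thread gone the sizing above reduces to a pair of counts, and the host-thread arithmetic in the hunk above becomes easy to trace by hand. The following is a simplified standalone illustration of that arithmetic (it reuses the constants defined in the hunk but is not the exact Paddle helper):

// Illustration only: half the processors per device, minus the GC and
// device-launch threads, clamped to [1, kHostNumThreads].
#include <algorithm>
#include <cstdio>

constexpr int kHostNumThreads = 4;
constexpr int kDeviceNumThreads = 1;
constexpr int kNumGcThreads = 1;

int HostThreads(int processor_count, int device_count) {
  int num_host_threads = kHostNumThreads;
  if (device_count > 0) {
    int num = processor_count / device_count / 2 -
              (kNumGcThreads + kDeviceNumThreads);
    num_host_threads = num > 0 ? std::min(num, kHostNumThreads) : 1;
  }
  return num_host_threads;
}

int main() {
  // 16 logical processors, 8 devices: 16/8/2 - 2 = -1, clamped to 1.
  std::printf("%d\n", HostThreads(/*processor_count=*/16, /*device_count=*/8));
  // 64 logical processors, 2 devices: 64/2/2 - 2 = 14, capped at 4.
  std::printf("%d\n", HostThreads(/*processor_count=*/64, /*device_count=*/2));
  return 0;
}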
@@ -29,7 +29,6 @@ struct ExecutionConfig {
   size_t host_num_threads;
   size_t deivce_num_threads;
-  size_t prepare_num_threads;
   std::set<std::string> skip_gc_vars;
......
@@ -321,7 +321,6 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
     async_work_queue_ = std::make_shared<interpreter::AsyncWorkQueue>(
         execution_config_.host_num_threads,
         execution_config_.deivce_num_threads,
-        execution_config_.prepare_num_threads,
         &main_thread_blocker_);
   }
   return async_work_queue_;
@@ -601,17 +600,13 @@ void InterpreterCore::Convert(
     BuildInplace();
   }
-  // prepare for the first time.
-  std::promise<std::unique_ptr<AtomicVectorSizeT>> deps_promise =
-      std::promise<std::unique_ptr<AtomicVectorSizeT>>();
-  atomic_deps_ = deps_promise.get_future();
-  deps_promise.set_value(interpreter::PrepareAtomicDeps(dependecy_count_));
-  std::promise<std::unique_ptr<AtomicVectorSizeT>> var_ref_promise =
-      std::promise<std::unique_ptr<AtomicVectorSizeT>>();
-  atomic_var_ref_ = var_ref_promise.get_future();
-  var_ref_promise.set_value(
-      interpreter::PrepareAtomicVarRef(var_scope_.VecMetaInfo()));
+  for (auto& dep : dependecy_count_) {
+    deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
+  }
+  for (size_t i = 0; i < vec_meta_info.size(); ++i) {
+    refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
+        vec_meta_info[i].var_ref_count_, var_scope_.VarRef(i)));
+  }
 }
 void InterpreterCore::BuildSkipShareLoDInfo() {
@@ -804,45 +799,19 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
 void InterpreterCore::ExecuteInstructionList(
     const std::vector<Instruction>& vec_instr) {
+  interpreter::ResetAtomicGuard guard(&deps_, &refs_);
   unfinished_op_numer_ = vec_instr.size();
   if (unfinished_op_numer_ == 0) {
     VLOG(4) << "No op to run, return";
     return;
   }
-  platform::RecordEvent record_prepare(
-      "PrepareAtomic", platform::TracerEventType::UserDefined, 1);
-  std::unique_ptr<std::vector<std::atomic<size_t>>> atomic_deps = nullptr;
-  std::unique_ptr<std::vector<std::atomic<size_t>>> atomic_var_ref = nullptr;
-  if (async_work_queue_->QueueNumThreads(kPrepareWorkQueueIdx)) {
-    // NOTE(zhiqiu): get the prepared deps from std::future, and async prepare
-    // those for the next step
-    atomic_deps = atomic_deps_.get();
-    atomic_var_ref = atomic_var_ref_.get();
-    atomic_deps_ = async_work_queue_->PrepareAtomicDeps(dependecy_count_);
-    atomic_var_ref_ =
-        async_work_queue_->PrepareAtomicVarRef(var_scope_.VecMetaInfo());
-  } else {
-    atomic_deps = interpreter::PrepareAtomicDeps(dependecy_count_);
-    atomic_var_ref = interpreter::PrepareAtomicVarRef(var_scope_.VecMetaInfo());
-  }
-  record_prepare.End();
   exception_holder_.Clear();
   for (size_t i = 0; i < dependecy_count_.size(); ++i) {
     if (dependecy_count_[i] == 0) {
-      async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
-                                 [this,
-                                  i,
-                                  atomic_deps = atomic_deps.get(),
-                                  atomic_var_ref = atomic_var_ref.get()] {
-                                   RunInstructionAsync(
-                                       i, atomic_deps, atomic_var_ref);
-                                 });
+      async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
+                                 [this, i] { RunInstructionAsync(i); });
     }
   }
@@ -869,19 +838,15 @@ void InterpreterCore::ExecuteInstructionList(
 }
 void InterpreterCore::RunNextInstructions(
-    const Instruction& instr,
-    std::queue<size_t>* reserved_next_ops,
-    std::vector<std::atomic<size_t>>* atomic_deps,
-    std::vector<std::atomic<size_t>>* atomic_var_ref) {
+    const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
   platform::RecordEvent record(
       "RunNextInstructions", platform::TracerEventType::UserDefined, 10);
-  VLOG(4) << "atomic 1:" << atomic_deps;
   auto& next_instr = instr.NextInstructions();
-  auto IsReady = [atomic_deps](size_t next_id) {
-    VLOG(4) << "atomic:" << atomic_deps << " op_id: " << next_id
-            << ", remain deps: " << (*atomic_deps)[next_id];
-    return (*atomic_deps)[next_id].fetch_sub(1, std::memory_order_relaxed) == 1;
+  auto IsReady = [this](size_t next_id) {
+    VLOG(4) << "op_id: " << next_id
+            << ", remain deps: " << deps_[next_id]->DynamicDep();
+    return deps_[next_id]->CheckAndDecrease();
   };
   if (instr.KernelType() == OpFuncType::kQueueAsync) {
@@ -890,9 +855,7 @@ void InterpreterCore::RunNextInstructions(
       if (IsReady(next_id)) {
         async_work_queue_->AddTask(
             vec_instruction_[next_id].KernelType(),
-            [this, next_id, atomic_deps, atomic_var_ref]() {
-              RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
-            });
+            [this, next_id]() { RunInstructionAsync(next_id); });
       }
     }
     // keep all async_ops running in current thread
@@ -912,9 +875,7 @@ void InterpreterCore::RunNextInstructions(
      if (IsReady(next_id)) {
        async_work_queue_->AddTask(
            vec_instruction_[next_id].KernelType(),
-            [this, next_id, atomic_deps, atomic_var_ref] {
-              RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
-            });
+            [this, next_id] { RunInstructionAsync(next_id); });
      }
    }
    auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
@@ -930,19 +891,14 @@ void InterpreterCore::RunNextInstructions(
        // move rest ops into other threads
        async_work_queue_->AddTask(
            vec_instruction_[next_id].KernelType(),
-            [this, next_id, atomic_deps, atomic_var_ref] {
-              RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
-            });
+            [this, next_id] { RunInstructionAsync(next_id); });
      }
    }
    if (first_op != -1) reserved_next_ops->push(first_op);
  }
 }
-void InterpreterCore::RunInstructionAsync(
-    size_t instr_id,
-    std::vector<std::atomic<size_t>>* atomic_deps,
-    std::vector<std::atomic<size_t>>* atomic_var_ref) {
+void InterpreterCore::RunInstructionAsync(size_t instr_id) {
   std::queue<size_t> ready_ops;
   ready_ops.push(instr_id);
   while (!ready_ops.empty()) {
@@ -965,7 +921,7 @@ void InterpreterCore::RunInstructionAsync(
       RunInstruction(instr_node);
-      CheckGC(instr_node, atomic_var_ref);
+      CheckGC(instr_node);
       interpreter::LogDeviceMemoryStats(place_);
@@ -1001,7 +957,7 @@ void InterpreterCore::RunInstructionAsync(
       }
     }
-    RunNextInstructions(instr_node, &ready_ops, atomic_deps, atomic_var_ref);
+    RunNextInstructions(instr_node, &ready_ops);
   }
 }
@@ -1100,9 +1056,7 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
 }
 #endif
-void InterpreterCore::CheckGC(
-    const Instruction& instr,
-    std::vector<std::atomic<size_t>>* atomic_var_ref) {
+void InterpreterCore::CheckGC(const Instruction& instr) {
   platform::RecordEvent record(
       "CheckGC", platform::TracerEventType::UserDefined, 10);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -1111,12 +1065,9 @@ void InterpreterCore::CheckGC(
   auto& var_scope = var_scope_;
   for (auto var_id : instr.GCCheckVars()) {
-    VLOG(4) << "GC " << var_scope_.GetNameById(var_id) << " "
-            << var_scope.VarDesc(var_id);
-    VLOG(4) << "atomic:" << atomic_var_ref << " " << &(*atomic_var_ref)[var_id]
-            << " " << var_id;
-    bool is_ready =
-        (*atomic_var_ref)[var_id].fetch_sub(1, std::memory_order_relaxed) == 1;
+    VLOG(4) << "GC:" << var_scope_.GetNameById(var_id) << ", id:" << var_id
+            << ", ref:" << refs_[var_id]->DynamicRef();
+    bool is_ready = refs_[var_id]->CheckAndDecrease();
     // ignore all persistable var while GC
     if (var_scope.VarDesc(var_id) && var_scope.VarDesc(var_id)->Persistable()) {
       continue;
@@ -1124,7 +1075,7 @@ void InterpreterCore::CheckGC(
     if (is_ready) {
       VLOG(6) << "Async delete variable with name : "
               << var_scope.GetNameById(var_id);
-      gc_->Add(var_scope_.VarRef(var_id), instr);
+      gc_->Add(refs_[var_id]->Var(), instr);
     }
   }
 }
......
@@ -95,21 +95,17 @@ class InterpreterCore {
   void RecordStreamForGC(const Instruction& instr);
 #endif
-  void CheckGC(const Instruction& instr,
-               std::vector<std::atomic<size_t>>* atomic_var_ref);
+  void CheckGC(const Instruction& instr);
-  void RunInstructionAsync(size_t instr_id,
-                           std::vector<std::atomic<size_t>>* atomic_deps,
-                           std::vector<std::atomic<size_t>>* atomic_var_ref);
+  void RunInstructionAsync(size_t instr_id);
   void RunNextInstructions(const Instruction& instr_id,
-                           std::queue<size_t>* reserved_next_ops,
-                           std::vector<std::atomic<size_t>>* atomic_deps,
-                           std::vector<std::atomic<size_t>>* atomic_var_ref);
+                           std::queue<size_t>* reserved_next_ops);
   void BuildSkipShareLoDInfo();
   void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
 private:
   bool is_build_;
   platform::Place place_;
@@ -142,6 +138,7 @@ class InterpreterCore {
   StreamAnalyzer stream_analyzer_;
   EventsWaiter main_thread_blocker_;
   std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
+
   details::ExceptionHolder exception_holder_;
   std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
   std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
@@ -150,6 +147,9 @@ class InterpreterCore {
   std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
   std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
+
+  std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
+  std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
 };
 std::shared_ptr<InterpreterCore> CreateInterpreterCore(
......
@@ -51,10 +51,7 @@ namespace interpreter {
 using VariableIdMap = std::map<std::string, std::vector<int>>;
 const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
-    size_t host_num_threads,
-    size_t device_num_threads,
-    size_t prepare_num_threads,
-    EventsWaiter* waiter) {
+    size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
   std::vector<WorkQueueOptions> group_options;
   // for execute host Kernel
   group_options.emplace_back(/*name*/ "HostTasks",
@@ -72,24 +69,15 @@ const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
                              /*track_task*/ false,
                              /*detached*/ true,
                              /*events_waiter*/ waiter);
-  // for prepare deps and others
-  group_options.emplace_back(/*name*/ "Prepare",
-                             /*num_threads*/ prepare_num_threads,
-                             /*allow_spinning*/ true,
-                             /*always_spinning*/ false,
-                             /*track_task*/ false,
-                             /*detached*/ true,
-                             /*events_waiter*/ waiter);
   return group_options;
 }
 AsyncWorkQueue::AsyncWorkQueue(size_t host_num_threads,
                                size_t device_num_threads,
-                               size_t prepare_num_threads,
                                EventsWaiter* waiter)
     : host_num_thread_(host_num_threads) {
-  queue_group_ = CreateWorkQueueGroup(ConstructWorkQueueOptions(
-      host_num_threads, device_num_threads, prepare_num_threads, waiter));
+  queue_group_ = CreateWorkQueueGroup(
+      ConstructWorkQueueOptions(host_num_threads, device_num_threads, waiter));
 }
 void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
@@ -104,44 +92,6 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
   }
 }
-std::future<std::unique_ptr<AtomicVectorSizeT>>
-AsyncWorkQueue::PrepareAtomicDeps(const std::vector<size_t>& dependecy_count) {
-  VLOG(4) << "PrepareAtomicDeps";
-  return queue_group_->AddAwaitableTask(
-      kPrepareWorkQueueIdx, interpreter::PrepareAtomicDeps, dependecy_count);
-}
-
-std::future<std::unique_ptr<AtomicVectorSizeT>>
-AsyncWorkQueue::PrepareAtomicVarRef(
-    const std::vector<VariableMetaInfo>& vec_meta_info) {
-  VLOG(4) << "PrepareAtomicVarRef";
-  return queue_group_->AddAwaitableTask(
-      kPrepareWorkQueueIdx, interpreter::PrepareAtomicVarRef, vec_meta_info);
-}
-
-std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
-    const std::vector<size_t>& dependecy_count) {
-  VLOG(4) << "PrepareAtomicDeps";
-  auto op_deps = std::make_unique<AtomicVectorSizeT>(dependecy_count.size());
-  for (size_t i = 0; i < dependecy_count.size(); ++i) {
-    (*op_deps)[i] = dependecy_count[i];
-  }
-  VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
-  return op_deps;
-}
-
-std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
-    const std::vector<VariableMetaInfo>& vec_meta_info) {
-  VLOG(4) << "PrepareAtomicVarRef";
-  auto var_ref = std::make_unique<AtomicVectorSizeT>(vec_meta_info.size());
-  for (size_t i = 0; i < vec_meta_info.size(); ++i) {
-    (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
-  }
-  VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
-  return var_ref;
-}
 void LogDeviceMemoryStats(const platform::Place& place) {
   if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) {
     VLOG(0) << "memory_allocated: "
......
@@ -39,7 +39,6 @@
 #include "paddle/fluid/platform/init.h"
 using AtomicVectorSizeT = std::vector<std::atomic<size_t>>;
-constexpr size_t kPrepareWorkQueueIdx = 2;
 namespace paddle {
 namespace framework {
@@ -48,14 +47,8 @@ class AsyncWorkQueue {
  public:
   AsyncWorkQueue(size_t host_num_threads,
                  size_t deivce_num_threads,
-                 size_t prepare_num_threads,
                  EventsWaiter* waiter);
-  std::future<std::unique_ptr<AtomicVectorSizeT>> PrepareAtomicDeps(
-      const std::vector<size_t>& dependecy_count);
-  std::future<std::unique_ptr<AtomicVectorSizeT>> PrepareAtomicVarRef(
-      const std::vector<VariableMetaInfo>& vec_meta_info);
   // void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
   void AddTask(const OpFuncType& op_func_type, std::function<void()> fn);
@@ -71,11 +64,6 @@ class AsyncWorkQueue {
   std::unique_ptr<WorkQueueGroup> queue_group_;
 };
-std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
-    const std::vector<size_t>& dependecy_count);
-std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
-    const std::vector<VariableMetaInfo>& vec_meta_info);
 void LogDeviceMemoryStats(const platform::Place& place);
 void build_variable_scope(const framework::BlockDesc& block,
......
@@ -720,16 +720,18 @@ OperatorBase* Instruction::OpBase() const {
   return op_base.get();
 }
-NextInstruction& Instruction::NextInstructions() { return next_instruction_; }
+NextInstructionList& Instruction::NextInstructions() {
+  return next_instruction_;
+}
-const NextInstruction& Instruction::NextInstructions() const {
+const NextInstructionList& Instruction::NextInstructions() const {
   return next_instruction_;
 }
-void Instruction::AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
+void Instruction::AddGCCheckVar(size_t id) { gc_check_vars_.push_back(id); }
 const std::vector<size_t>& Instruction::GCCheckVars() const {
-  return gc_check_var_list_;
+  return gc_check_vars_;
 }
 void Instruction::ResetContext(const VariableValueMap& in_vars,
......
@@ -31,8 +31,6 @@ namespace paddle {
 namespace framework {
 using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
-using OpKernelMap =
-    std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
 constexpr int kEmptyVarIndex = 0;
@@ -237,7 +235,7 @@ class VariableScope {
   std::vector<std::pair<std::string, int>> data_transfer_added_vars_;
 };
-class NextInstruction {
+class NextInstructionList {
  public:
   void AddDirectRun(size_t id) { direct_run_.push_back(id); }
@@ -267,10 +265,6 @@ struct EventInter {
   platform::DeviceType waiter_type_;
 };
-struct InstructionInfo {
-  std::vector<size_t> dependecy_count_;
-};
 enum class OpFuncType {
   kQueueSync = 0,   // CPU kernel, block host
   kQueueAsync = 1,  // GPU、XPU Kernel or d2h, h2d, send, recv, broadcast
@@ -319,9 +313,9 @@ class Instruction {
   OperatorBase* OpBase() const;
-  NextInstruction& NextInstructions();
-  const NextInstruction& NextInstructions() const;
+  NextInstructionList& NextInstructions();
+  const NextInstructionList& NextInstructions() const;
   void AddGCCheckVar(size_t id);
@@ -370,8 +364,9 @@ class Instruction {
   std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
   std::shared_ptr<ExecutionContext> execution_ctx_;
-  std::vector<size_t> gc_check_var_list_;
-  NextInstruction next_instruction_;
+  std::vector<size_t> gc_check_vars_;
+
+  NextInstructionList next_instruction_;
   std::vector<EventInter> intput_events_;
   std::vector<EventInter> output_events_;
@@ -403,6 +398,86 @@ static bool IsSupportedHetePlace(const phi::Place& place) {
          platform::is_custom_place(place);
 }
+// static_ref_ is the numer of last live ops calculated to statically after
+// `build` the Instructions. dynamic_ref_ is the runtime version ref which will
+// be decreased by one dynamiclly after the execution of an op (in last ops
+// list). var_ is the related variable
+// The dynamic_ref_ is initialized to static_ref_ first, and is decreased to 1
+// during interpretercore's execution, after the interpretercore run, it `reset`
+// all dynamic_ref_, i.e., dynamic_ref_ = static_ref_ see ResetAtomicGuard for
+// details
+class VarRefInfo {
+ public:
+  explicit VarRefInfo(size_t ref, Variable* var)
+      : static_ref_(ref), dynamic_ref_(ref), var_(var) {}
+  size_t DynamicRef() { return dynamic_ref_; }
+  Variable* Var() { return var_; }
+  void ResetDynamicRef() {
+    if (static_ref_ != 1) {
+      dynamic_ref_ = static_ref_;
+    }
+  }
+  bool CheckAndDecrease() {
+    return static_ref_ == 1 || (dynamic_ref_.fetch_sub(1) == 1);
+  }
+
+ private:
+  const size_t static_ref_;
+  std::atomic<size_t> dynamic_ref_;
+  Variable* var_;
+};
+
+// static_dep_ is the numer of dependencies (ops that must run before it) of
+// each op which is calculated to statically. static_dep_ is the runtime
+// version dep which will be decreased by one dynamiclly after the execution of
+// one dependency op.
+// The dynamic_dep_ is initialized to static_dep_ first, and is decreased to 1
+// during interpretercore's execution, after the interpretercore run, it `reset`
+// all dynamic_dep_, i.e., dynamic_dep_ = static_dep_ see ResetAtomicGuard for
+// details
+class OpDepInfo {
+ public:
+  explicit OpDepInfo(size_t dep) : static_dep_(dep), dynamic_dep_(dep) {}
+  size_t DynamicDep() { return dynamic_dep_; }
+  void ResetDynamicDep() {
+    if (static_dep_ != 1) {
+      dynamic_dep_ = static_dep_;
+    }
+  }
+  bool CheckAndDecrease() {
+    return static_dep_ == 1 || (dynamic_dep_.fetch_sub(1) == 1);
+  }
+
+ private:
+  const size_t static_dep_;
+  std::atomic<size_t> dynamic_dep_;
+};
+
+class ResetAtomicGuard {
+ public:
+  ResetAtomicGuard(std::vector<std::shared_ptr<OpDepInfo>>* deps,
+                   std::vector<std::shared_ptr<VarRefInfo>>* refs)
+      : deps_(deps), refs_(refs) {}
+
+  ~ResetAtomicGuard() {
+    VLOG(10) << "Reset DynamicDep";
+    for (auto&& dep : *deps_) {
+      dep->ResetDynamicDep();
+    }
+    VLOG(10) << "Reset DynamicRef";
+    for (auto&& ref : *refs_) {
+      ref->ResetDynamicRef();
+    }
+  }
+
+ private:
+  std::vector<std::shared_ptr<OpDepInfo>>* deps_;
+  std::vector<std::shared_ptr<VarRefInfo>>* refs_;
+};
+
 }  // namespace interpreter
 }  // namespace framework
......
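Taken together, the pattern the diff moves to can be summarized by a small self-contained example (illustrative only, independent of the classes above): dependency counters are built once, every run decrements them in place, and a scope guard restores them for the next run instead of reallocating fresh atomic vectors before each execution.

#include <atomic>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

class DepCounter {
 public:
  explicit DepCounter(size_t dep) : static_dep_(dep), dynamic_dep_(dep) {}
  bool CheckAndDecrease() {
    return static_dep_ == 1 || dynamic_dep_.fetch_sub(1) == 1;
  }
  void Reset() {
    if (static_dep_ != 1) dynamic_dep_ = static_dep_;
  }

 private:
  const size_t static_dep_;
  std::atomic<size_t> dynamic_dep_;
};

class ResetGuard {
 public:
  explicit ResetGuard(std::vector<std::unique_ptr<DepCounter>>* deps)
      : deps_(deps) {}
  ~ResetGuard() {
    for (auto& dep : *deps_) dep->Reset();  // ready for the next run
  }

 private:
  std::vector<std::unique_ptr<DepCounter>>* deps_;
};

int main() {
  // Built once (as Convert() does for deps_): this op has two predecessors.
  std::vector<std::unique_ptr<DepCounter>> deps;
  deps.push_back(std::make_unique<DepCounter>(2));

  for (int run = 0; run < 2; ++run) {
    ResetGuard guard(&deps);  // restores the counters when the run ends
    bool ready_after_first = deps[0]->CheckAndDecrease();
    bool ready_after_second = deps[0]->CheckAndDecrease();
    // Prints "0 1" on both runs thanks to the guard's reset.
    std::printf("run %d: %d %d\n", run, ready_after_first, ready_after_second);
  }
  return 0;
}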