“f444865bde167e4396f70627915f859430e7b157”上不存在“...doc/howto/usage/cmd_parameter/detail_introduction_en.html”
未验证 提交 dd3d45de 编写于 作者: L Leo Chen 提交者: GitHub

reduce time cost on atomic in interpretercore (#46688)

* reduce time cost on atomic in interpretercore

* clear code of PrepareAtomic in interpretercore

* refine threadpool cache
上级 c333af2f
......@@ -326,7 +326,9 @@ inline void RunProgramAPI(
paddle::framework::InterpreterCoreInfoCache::Instance();
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/false)) {
VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
VLOG(2) << "No interpretercore cahce, so create a new interpretercore "
"for program: "
<< program_id;
// Step 1. share input_vars & parameters into scope
details::ShareTensorsIntoScope(x, global_inner_scope);
details::ShareTensorsIntoScope(params, global_inner_scope);
......@@ -545,19 +547,14 @@ inline void RunProgramGradAPI(
// share threadpool
// NOTE(zhiqiu): this only works interpreter_core is executed strictly
// after the related fwd_interpreter_core.
PADDLE_ENFORCE_EQ(
interpretercore_info_cache.Has(program_id, false),
true,
paddle::platform::errors::NotFound(
"The forward interpretercore of program %d is not found",
program_id));
auto fwd_interpreter_core =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
.core_;
interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
<< interpreter_core.get();
if (interpretercore_info_cache.Has(program_id, false)) {
auto fwd_interpreter_core =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
.core_;
interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get()
<< " to " << interpreter_core.get();
}
// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
// all out_vars are skip_eager_var
......
......@@ -295,6 +295,8 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
auto &interpretercore_info_cache =
framework::InterpreterCoreInfoCache::Instance();
if (interpretercore_info_cache.Size() > 4u /* max_cached_size*/) {
VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
"all cache!";
interpretercore_info_cache.Finalize();
}
auto core = std::make_shared<InterpreterCore>(
......
......@@ -21,11 +21,13 @@ namespace framework {
namespace interpreter {
void WaitEvent(const Instruction& instruction, const platform::Place& place) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return;
if (platform::is_cpu_place(place)) {
return;
}
VLOG(3) << "Deal StreamWaitEventOrSync for " << instruction.OpBase()->Type();
for (auto& event_iter : instruction.InputEvents()) {
for (const auto& event_iter : instruction.InputEvents()) {
platform::RecordEvent record(
"WaitStreamEvent", platform::TracerEventType::UserDefined, 10);
VLOG(3) << "wait var_id: " << event_iter.var_id_
......@@ -37,9 +39,11 @@ void WaitEvent(const Instruction& instruction, const platform::Place& place) {
void RecordEvent(const Instruction& instruction, const platform::Place& place) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return;
if (platform::is_cpu_place(place)) {
return;
}
for (auto& event : instruction.OutputEvents()) {
for (const auto& event : instruction.OutputEvents()) {
platform::RecordEvent record(
"RecordStreamEvent", platform::TracerEventType::UserDefined, 10);
VLOG(3) << "Record event in out_var_id: " << event.var_id_;
......
......@@ -32,32 +32,19 @@ namespace interpreter {
static constexpr size_t kHostNumThreads = 4;
static constexpr size_t kDeviceNumThreads = 1;
static constexpr size_t kNumGcThreads = 1;
static constexpr size_t kNumPrepareThreads = 0;
static constexpr size_t kMinOpNumForAsyncPrepare = 1000;
// By default, one interpretercore contains:
// 1-size thread pool for device kernel launch (or 0 for cpu execution),
// 1-size thread pool for host kernel launch (or more if the system contains
// enough processors).
// And it may contain:
// 1-size thread pool for gc if it is can not use FastGC,
// 1-size thread pool for preparation if the program contains two many ops
// (1000+).
// Note that the purpose of the config is to limit the total 'possible'
// threads introduced by interpretercore to avoid hurting performance.
inline std::tuple<int, int, int> GetThreadPoolConfig(const phi::Place& place,
size_t op_num) {
inline std::tuple<int, int> GetThreadPoolConfig(const phi::Place& place,
size_t op_num) {
int num_device_threads = kDeviceNumThreads,
num_host_threads = kHostNumThreads,
num_prepare_threads = kNumPrepareThreads;
if (op_num > kMinOpNumForAsyncPrepare) {
num_prepare_threads = 1;
}
num_host_threads = kHostNumThreads;
int device_count = 0, processor_count = 0;
if (platform::is_cpu_place(place)) {
......@@ -109,7 +96,7 @@ inline std::tuple<int, int, int> GetThreadPoolConfig(const phi::Place& place,
if (device_count) {
auto num = processor_count / device_count / 2 -
(kNumGcThreads + kNumPrepareThreads + num_device_threads);
(kNumGcThreads + num_device_threads);
num_host_threads =
num > 0 ? (num > kHostNumThreads ? kHostNumThreads : num) : 1;
}
......@@ -126,14 +113,13 @@ inline std::tuple<int, int, int> GetThreadPoolConfig(const phi::Place& place,
<< ", device_count:" << device_count
<< ", serial_run:" << FLAGS_new_executor_serial_run
<< ", num_host_threads:" << num_host_threads
<< ", num_device_threads:" << num_device_threads
<< ", num_prepare_threads:" << num_prepare_threads;
return std::make_tuple(
num_host_threads, num_device_threads, num_prepare_threads);
<< ", num_device_threads:" << num_device_threads;
return std::make_tuple(num_host_threads, num_device_threads);
}
ExecutionConfig::ExecutionConfig(const phi::Place& place, size_t op_num) {
std::tie(host_num_threads, deivce_num_threads, prepare_num_threads) =
std::tie(host_num_threads, deivce_num_threads) =
GetThreadPoolConfig(place, op_num);
}
......@@ -143,7 +129,6 @@ void ExecutionConfig::Log(int log_level) {
VLOG(log_level) << "create_local_scope = " << create_local_scope;
VLOG(log_level) << "host_num_threads = " << host_num_threads;
VLOG(log_level) << "deivce_num_threads = " << deivce_num_threads;
VLOG(log_level) << "prepare_num_threads = " << prepare_num_threads;
VLOG(log_level) << "skip_gc_vars = ";
for (const std::string& var : skip_gc_vars) {
VLOG(log_level) << var;
......
......@@ -29,7 +29,6 @@ struct ExecutionConfig {
size_t host_num_threads;
size_t deivce_num_threads;
size_t prepare_num_threads;
std::set<std::string> skip_gc_vars;
......
......@@ -321,7 +321,6 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
async_work_queue_ = std::make_shared<interpreter::AsyncWorkQueue>(
execution_config_.host_num_threads,
execution_config_.deivce_num_threads,
execution_config_.prepare_num_threads,
&main_thread_blocker_);
}
return async_work_queue_;
......@@ -601,17 +600,13 @@ void InterpreterCore::Convert(
BuildInplace();
}
// prepare for the first time.
std::promise<std::unique_ptr<AtomicVectorSizeT>> deps_promise =
std::promise<std::unique_ptr<AtomicVectorSizeT>>();
atomic_deps_ = deps_promise.get_future();
deps_promise.set_value(interpreter::PrepareAtomicDeps(dependecy_count_));
std::promise<std::unique_ptr<AtomicVectorSizeT>> var_ref_promise =
std::promise<std::unique_ptr<AtomicVectorSizeT>>();
atomic_var_ref_ = var_ref_promise.get_future();
var_ref_promise.set_value(
interpreter::PrepareAtomicVarRef(var_scope_.VecMetaInfo()));
for (auto& dep : dependecy_count_) {
deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
}
for (size_t i = 0; i < vec_meta_info.size(); ++i) {
refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
vec_meta_info[i].var_ref_count_, var_scope_.VarRef(i)));
}
}
void InterpreterCore::BuildSkipShareLoDInfo() {
......@@ -804,45 +799,19 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) {
interpreter::ResetAtomicGuard guard(&deps_, &refs_);
unfinished_op_numer_ = vec_instr.size();
if (unfinished_op_numer_ == 0) {
VLOG(4) << "No op to run, return";
return;
}
platform::RecordEvent record_prepare(
"PrepareAtomic", platform::TracerEventType::UserDefined, 1);
std::unique_ptr<std::vector<std::atomic<size_t>>> atomic_deps = nullptr;
std::unique_ptr<std::vector<std::atomic<size_t>>> atomic_var_ref = nullptr;
if (async_work_queue_->QueueNumThreads(kPrepareWorkQueueIdx)) {
// NOTE(zhiqiu): get the prepared deps from std::future, and async prepare
// those for the next step
atomic_deps = atomic_deps_.get();
atomic_var_ref = atomic_var_ref_.get();
atomic_deps_ = async_work_queue_->PrepareAtomicDeps(dependecy_count_);
atomic_var_ref_ =
async_work_queue_->PrepareAtomicVarRef(var_scope_.VecMetaInfo());
} else {
atomic_deps = interpreter::PrepareAtomicDeps(dependecy_count_);
atomic_var_ref = interpreter::PrepareAtomicVarRef(var_scope_.VecMetaInfo());
}
record_prepare.End();
exception_holder_.Clear();
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[this,
i,
atomic_deps = atomic_deps.get(),
atomic_var_ref = atomic_var_ref.get()] {
RunInstructionAsync(
i, atomic_deps, atomic_var_ref);
});
[this, i] { RunInstructionAsync(i); });
}
}
......@@ -869,19 +838,15 @@ void InterpreterCore::ExecuteInstructionList(
}
void InterpreterCore::RunNextInstructions(
const Instruction& instr,
std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
platform::RecordEvent record(
"RunNextInstructions", platform::TracerEventType::UserDefined, 10);
VLOG(4) << "atomic 1:" << atomic_deps;
auto& next_instr = instr.NextInstructions();
auto IsReady = [atomic_deps](size_t next_id) {
VLOG(4) << "atomic:" << atomic_deps << " op_id: " << next_id
<< ", remain deps: " << (*atomic_deps)[next_id];
return (*atomic_deps)[next_id].fetch_sub(1, std::memory_order_relaxed) == 1;
auto IsReady = [this](size_t next_id) {
VLOG(4) << "op_id: " << next_id
<< ", remain deps: " << deps_[next_id]->DynamicDep();
return deps_[next_id]->CheckAndDecrease();
};
if (instr.KernelType() == OpFuncType::kQueueAsync) {
......@@ -890,9 +855,7 @@ void InterpreterCore::RunNextInstructions(
if (IsReady(next_id)) {
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[this, next_id, atomic_deps, atomic_var_ref]() {
RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
});
[this, next_id]() { RunInstructionAsync(next_id); });
}
}
// keep all async_ops running in current thread
......@@ -912,9 +875,7 @@ void InterpreterCore::RunNextInstructions(
if (IsReady(next_id)) {
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[this, next_id, atomic_deps, atomic_var_ref] {
RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
});
[this, next_id] { RunInstructionAsync(next_id); });
}
}
auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
......@@ -930,19 +891,14 @@ void InterpreterCore::RunNextInstructions(
// move rest ops into other threads
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[this, next_id, atomic_deps, atomic_var_ref] {
RunInstructionAsync(next_id, atomic_deps, atomic_var_ref);
});
[this, next_id] { RunInstructionAsync(next_id); });
}
}
if (first_op != -1) reserved_next_ops->push(first_op);
}
}
void InterpreterCore::RunInstructionAsync(
size_t instr_id,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
void InterpreterCore::RunInstructionAsync(size_t instr_id) {
std::queue<size_t> ready_ops;
ready_ops.push(instr_id);
while (!ready_ops.empty()) {
......@@ -965,7 +921,7 @@ void InterpreterCore::RunInstructionAsync(
RunInstruction(instr_node);
CheckGC(instr_node, atomic_var_ref);
CheckGC(instr_node);
interpreter::LogDeviceMemoryStats(place_);
......@@ -1001,7 +957,7 @@ void InterpreterCore::RunInstructionAsync(
}
}
RunNextInstructions(instr_node, &ready_ops, atomic_deps, atomic_var_ref);
RunNextInstructions(instr_node, &ready_ops);
}
}
......@@ -1100,9 +1056,7 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
}
#endif
void InterpreterCore::CheckGC(
const Instruction& instr,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
void InterpreterCore::CheckGC(const Instruction& instr) {
platform::RecordEvent record(
"CheckGC", platform::TracerEventType::UserDefined, 10);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -1111,12 +1065,9 @@ void InterpreterCore::CheckGC(
auto& var_scope = var_scope_;
for (auto var_id : instr.GCCheckVars()) {
VLOG(4) << "GC " << var_scope_.GetNameById(var_id) << " "
<< var_scope.VarDesc(var_id);
VLOG(4) << "atomic:" << atomic_var_ref << " " << &(*atomic_var_ref)[var_id]
<< " " << var_id;
bool is_ready =
(*atomic_var_ref)[var_id].fetch_sub(1, std::memory_order_relaxed) == 1;
VLOG(4) << "GC:" << var_scope_.GetNameById(var_id) << ", id:" << var_id
<< ", ref:" << refs_[var_id]->DynamicRef();
bool is_ready = refs_[var_id]->CheckAndDecrease();
// ignore all persistable var while GC
if (var_scope.VarDesc(var_id) && var_scope.VarDesc(var_id)->Persistable()) {
continue;
......@@ -1124,7 +1075,7 @@ void InterpreterCore::CheckGC(
if (is_ready) {
VLOG(6) << "Async delete variable with name : "
<< var_scope.GetNameById(var_id);
gc_->Add(var_scope_.VarRef(var_id), instr);
gc_->Add(refs_[var_id]->Var(), instr);
}
}
}
......
......@@ -95,21 +95,17 @@ class InterpreterCore {
void RecordStreamForGC(const Instruction& instr);
#endif
void CheckGC(const Instruction& instr,
std::vector<std::atomic<size_t>>* atomic_var_ref);
void CheckGC(const Instruction& instr);
void RunInstructionAsync(size_t instr_id,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref);
void RunInstructionAsync(size_t instr_id);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref);
std::queue<size_t>* reserved_next_ops);
void BuildSkipShareLoDInfo();
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
private:
bool is_build_;
platform::Place place_;
......@@ -142,6 +138,7 @@ class InterpreterCore {
StreamAnalyzer stream_analyzer_;
EventsWaiter main_thread_blocker_;
std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
......@@ -150,6 +147,9 @@ class InterpreterCore {
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
};
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
......
......@@ -51,10 +51,7 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads,
size_t device_num_threads,
size_t prepare_num_threads,
EventsWaiter* waiter) {
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
std::vector<WorkQueueOptions> group_options;
// for execute host Kernel
group_options.emplace_back(/*name*/ "HostTasks",
......@@ -72,24 +69,15 @@ const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
// for prepare deps and others
group_options.emplace_back(/*name*/ "Prepare",
/*num_threads*/ prepare_num_threads,
/*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
return group_options;
}
AsyncWorkQueue::AsyncWorkQueue(size_t host_num_threads,
size_t device_num_threads,
size_t prepare_num_threads,
EventsWaiter* waiter)
: host_num_thread_(host_num_threads) {
queue_group_ = CreateWorkQueueGroup(ConstructWorkQueueOptions(
host_num_threads, device_num_threads, prepare_num_threads, waiter));
queue_group_ = CreateWorkQueueGroup(
ConstructWorkQueueOptions(host_num_threads, device_num_threads, waiter));
}
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
......@@ -104,44 +92,6 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
}
}
std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicDeps(const std::vector<size_t>& dependecy_count) {
VLOG(4) << "PrepareAtomicDeps";
return queue_group_->AddAwaitableTask(
kPrepareWorkQueueIdx, interpreter::PrepareAtomicDeps, dependecy_count);
}
std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info) {
VLOG(4) << "PrepareAtomicVarRef";
return queue_group_->AddAwaitableTask(
kPrepareWorkQueueIdx, interpreter::PrepareAtomicVarRef, vec_meta_info);
}
std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count) {
VLOG(4) << "PrepareAtomicDeps";
auto op_deps = std::make_unique<AtomicVectorSizeT>(dependecy_count.size());
for (size_t i = 0; i < dependecy_count.size(); ++i) {
(*op_deps)[i] = dependecy_count[i];
}
VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
return op_deps;
}
std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info) {
VLOG(4) << "PrepareAtomicVarRef";
auto var_ref = std::make_unique<AtomicVectorSizeT>(vec_meta_info.size());
for (size_t i = 0; i < vec_meta_info.size(); ++i) {
(*var_ref)[i] = vec_meta_info[i].var_ref_count_;
}
VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
return var_ref;
}
void LogDeviceMemoryStats(const platform::Place& place) {
if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) {
VLOG(0) << "memory_allocated: "
......
......@@ -39,7 +39,6 @@
#include "paddle/fluid/platform/init.h"
using AtomicVectorSizeT = std::vector<std::atomic<size_t>>;
constexpr size_t kPrepareWorkQueueIdx = 2;
namespace paddle {
namespace framework {
......@@ -48,14 +47,8 @@ class AsyncWorkQueue {
public:
AsyncWorkQueue(size_t host_num_threads,
size_t deivce_num_threads,
size_t prepare_num_threads,
EventsWaiter* waiter);
std::future<std::unique_ptr<AtomicVectorSizeT>> PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count);
std::future<std::unique_ptr<AtomicVectorSizeT>> PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info);
// void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); }
void AddTask(const OpFuncType& op_func_type, std::function<void()> fn);
......@@ -71,11 +64,6 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_;
};
std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count);
std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
const std::vector<VariableMetaInfo>& vec_meta_info);
void LogDeviceMemoryStats(const platform::Place& place);
void build_variable_scope(const framework::BlockDesc& block,
......
......@@ -720,16 +720,18 @@ OperatorBase* Instruction::OpBase() const {
return op_base.get();
}
NextInstruction& Instruction::NextInstructions() { return next_instruction_; }
NextInstructionList& Instruction::NextInstructions() {
return next_instruction_;
}
const NextInstruction& Instruction::NextInstructions() const {
const NextInstructionList& Instruction::NextInstructions() const {
return next_instruction_;
}
void Instruction::AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
void Instruction::AddGCCheckVar(size_t id) { gc_check_vars_.push_back(id); }
const std::vector<size_t>& Instruction::GCCheckVars() const {
return gc_check_var_list_;
return gc_check_vars_;
}
void Instruction::ResetContext(const VariableValueMap& in_vars,
......
......@@ -31,8 +31,6 @@ namespace paddle {
namespace framework {
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
constexpr int kEmptyVarIndex = 0;
......@@ -237,7 +235,7 @@ class VariableScope {
std::vector<std::pair<std::string, int>> data_transfer_added_vars_;
};
class NextInstruction {
class NextInstructionList {
public:
void AddDirectRun(size_t id) { direct_run_.push_back(id); }
......@@ -267,10 +265,6 @@ struct EventInter {
platform::DeviceType waiter_type_;
};
struct InstructionInfo {
std::vector<size_t> dependecy_count_;
};
enum class OpFuncType {
kQueueSync = 0, // CPU kernel, block host
kQueueAsync = 1, // GPU、XPU Kernel or d2h, h2d, send, recv, broadcast
......@@ -319,9 +313,9 @@ class Instruction {
OperatorBase* OpBase() const;
NextInstruction& NextInstructions();
NextInstructionList& NextInstructions();
const NextInstruction& NextInstructions() const;
const NextInstructionList& NextInstructions() const;
void AddGCCheckVar(size_t id);
......@@ -370,8 +364,9 @@ class Instruction {
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
std::shared_ptr<ExecutionContext> execution_ctx_;
std::vector<size_t> gc_check_var_list_;
NextInstruction next_instruction_;
std::vector<size_t> gc_check_vars_;
NextInstructionList next_instruction_;
std::vector<EventInter> intput_events_;
std::vector<EventInter> output_events_;
......@@ -403,6 +398,86 @@ static bool IsSupportedHetePlace(const phi::Place& place) {
platform::is_custom_place(place);
}
// static_ref_ is the numer of last live ops calculated to statically after
// `build` the Instructions. dynamic_ref_ is the runtime version ref which will
// be decreased by one dynamiclly after the execution of an op (in last ops
// list). var_ is the related variable
// The dynamic_ref_ is initialized to static_ref_ first, and is decreased to 1
// during interpretercore's execution, after the interpretercore run, it `reset`
// all dynamic_ref_, i.e., dynamic_ref_ = static_ref_ see ResetAtomicGuard for
// details
class VarRefInfo {
public:
explicit VarRefInfo(size_t ref, Variable* var)
: static_ref_(ref), dynamic_ref_(ref), var_(var) {}
size_t DynamicRef() { return dynamic_ref_; }
Variable* Var() { return var_; }
void ResetDynamicRef() {
if (static_ref_ != 1) {
dynamic_ref_ = static_ref_;
}
}
bool CheckAndDecrease() {
return static_ref_ == 1 || (dynamic_ref_.fetch_sub(1) == 1);
}
private:
const size_t static_ref_;
std::atomic<size_t> dynamic_ref_;
Variable* var_;
};
// static_dep_ is the numer of dependencies (ops that must run before it) of
// each op which is calculated to statically. static_dep_ is the runtime
// version dep which will be decreased by one dynamiclly after the execution of
// one dependency op.
// The dynamic_dep_ is initialized to static_dep_ first, and is decreased to 1
// during interpretercore's execution, after the interpretercore run, it `reset`
// all dynamic_dep_, i.e., dynamic_dep_ = static_dep_ see ResetAtomicGuard for
// details
class OpDepInfo {
public:
explicit OpDepInfo(size_t dep) : static_dep_(dep), dynamic_dep_(dep) {}
size_t DynamicDep() { return dynamic_dep_; }
void ResetDynamicDep() {
if (static_dep_ != 1) {
dynamic_dep_ = static_dep_;
}
}
bool CheckAndDecrease() {
return static_dep_ == 1 || (dynamic_dep_.fetch_sub(1) == 1);
}
private:
const size_t static_dep_;
std::atomic<size_t> dynamic_dep_;
};
class ResetAtomicGuard {
public:
ResetAtomicGuard(std::vector<std::shared_ptr<OpDepInfo>>* deps,
std::vector<std::shared_ptr<VarRefInfo>>* refs)
: deps_(deps), refs_(refs) {}
~ResetAtomicGuard() {
VLOG(10) << "Reset DynamicDep";
for (auto&& dep : *deps_) {
dep->ResetDynamicDep();
}
VLOG(10) << "Reset DynamicRef";
for (auto&& ref : *refs_) {
ref->ResetDynamicRef();
}
}
private:
std::vector<std::shared_ptr<OpDepInfo>>* deps_;
std::vector<std::shared_ptr<VarRefInfo>>* refs_;
};
} // namespace interpreter
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册