Unverified commit bb1fffd6, authored by Z zhangbo9674, committed by GitHub

Add trace mode for interpretercore (#48370)

* add trace mode for interpretercore

* fix bug

* add a ctrl flag

* add record for memcpyd2h

* polish code

* polish code
Parent: ab18644c
...@@ -206,7 +206,12 @@ paddle::framework::FetchList InterpreterCore::Run(
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
}
if (execution_config_.used_for_jit && (sync_op_num_ == 0)) {
VLOG(4) << "Tracing Instruction List";
TraceInstructionList(vec_instruction_);
} else {
ExecuteInstructionList(vec_instruction_);
}
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(place_)) {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
...@@ -257,6 +262,7 @@ paddle::framework::FetchList InterpreterCore::Run(
// convert vec func_list to graph
Convert(&op_func_nodes);
is_build_ = true;
UpdateSyncOpNum();
} else {
// For the program that only run once, it is no need to
// create work_queue, so the async_work_queue_ is created
...@@ -268,7 +274,12 @@ paddle::framework::FetchList InterpreterCore::Run(
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
}
if (execution_config_.used_for_jit && (sync_op_num_ == 0)) {
VLOG(4) << "Tracing Instruction List";
TraceInstructionList(vec_instruction_);
} else {
ExecuteInstructionList(vec_instruction_);
}
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(place_)) {
platform::DeviceContextPool::Instance().Get(place_)->Wait();
...@@ -719,6 +730,8 @@ void InterpreterCore::Convert(
refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
vec_meta_info[i].var_ref_count_, var_scope_.VarRef(i)));
}
AnalyseExecuteOrderForTrace();
}
void InterpreterCore::BuildSkipShareLoDInfo() {
...@@ -741,7 +754,7 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
}
}
void InterpreterCore::RunOperator(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
...@@ -865,6 +878,45 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
}
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type() << " type:"
<< (instr_node.KernelType() == OpFuncType::kCpuSync
? "kCpuSync"
: (instr_node.KernelType() == OpFuncType::kGpuSync
? "kGpuSync"
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
auto* op = instr_node.OpBase();
platform::RecordEvent instruction_event(
op->Type(), platform::TracerEventType::Operator, 1);
try {
instr_node.WaitEvent(place_);
if (!instr_node.IsArtificial()) {
RunOperator(instr_node);
CheckGC(instr_node);
interpreter::LogDeviceMemoryStats(place_);
}
instr_node.RecordEvent(place_);
} catch (platform::EnforceNotMet& ex) {
framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
} catch (platform::EOFException&) {
exception_holder_.Catch(std::current_exception());
} catch (std::exception& ex) {
LOG(WARNING) << op->Type() << " raises an exception "
<< platform::demangle(typeid(ex).name()) << ", " << ex.what();
exception_holder_.Catch(std::current_exception());
} catch (...) {
LOG(WARNING) << op->Type() << " raises an unknown exception";
exception_holder_.Catch(std::current_exception());
}
}
void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) {
interpreter::ResetAtomicGuard guard(&deps_, &refs_);
...@@ -879,18 +931,7 @@ void InterpreterCore::ExecuteInstructionList(
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
// NOTE(zhiqiu): hot fix for jit input var
RecordMemcpyD2H(vec_instr.at(i));
async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[this, i] { RunInstructionAsync(i); });
}
...@@ -955,43 +996,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
instr_id = ready_ops.front();
ready_ops.pop_front();
auto& instr_node = vec_instruction_.at(instr_id);
RunInstruction(instr_node);
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(4) << "Exception caught";
...@@ -1176,6 +1182,7 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
Convert(&op_func_nodes);
UpdateSyncOpNum();
is_build_ = true;
}
// NOTE: Because feed_tensor will be GC after
...@@ -1213,5 +1220,140 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore(
return core;
}
// Note(zhangbo):
// (1) What is "Trace"?
// By default, InterpreterCore schedules OP execution in a multi-threaded mode
// (see ExecuteInstructionList): it maintains a high-performance thread pool
// and distributes each OP's execution to the pool's worker threads, while the
// main thread itself has no tasks. In Trace mode, the executor instead runs
// the OPs directly on the main thread, following a pre-computed OP sequence
// (trace_execute_order_), rather than dispatching them to the thread pool.
// (2) When do we use "Trace"?
// In dygraph-to-static, the default scheduling means the forward and backward
// OPs and the dygraph optimizer cannot run on the same thread, and the
// resulting thread switches may cause CPU cache misses. Moreover, when a
// model consists entirely of async (kGpuAsync) OPs, every OP is distributed
// to a DeviceThread for execution anyway, so multi-threaded scheduling brings
// no benefit. Therefore, in dygraph-to-static, when the number of sync OPs
// (sync_op_num_) is 0, we choose Trace mode.
void InterpreterCore::TraceInstructionList(
const std::vector<Instruction>& vec_instr) {
unfinished_op_number_ = vec_instr.size();
if (unfinished_op_number_ == 0) {
VLOG(4) << "No op to run, return";
return;
}
exception_holder_.Clear();
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
// NOTE(zhiqiu): hot fix for jit input var
RecordMemcpyD2H(vec_instr.at(i));
}
}
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
auto instr_id = trace_execute_order_[idx];
auto& instr_node = vec_instruction_.at(instr_id);
RunInstruction(instr_node);
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(4) << "Exception caught";
break;
}
}
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(1) << "Exception caught " << exception_holder_.Type();
PADDLE_ENFORCE_EQ(
main_thread_blocker_.Clear(),
0,
platform::errors::PreconditionNotMet(
"main_thread_blocker_.Clear() return -1, clear failed"));
VLOG(4) << "clear ok";
exception_holder_.ReThrow();
}
}
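To make the note above concrete, here is a minimal, self-contained sketch (hypothetical ExecuteAsync/ExecuteTrace helpers and a bare std::function op type, not Paddle's real classes) contrasting thread-pool-style dispatch with trace-mode execution on the main thread:

```cpp
#include <cstdio>
#include <functional>
#include <future>
#include <vector>

using Op = std::function<void()>;

// Async-style scheduling: every op is handed to another thread, roughly the
// role the interpreter's work queue plays in ExecuteInstructionList.
void ExecuteAsync(const std::vector<Op>& ops) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(ops.size());
  for (const Op& op : ops) {
    tasks.push_back(std::async(std::launch::async, op));
  }
  for (auto& t : tasks) t.get();  // block until every worker finishes
}

// Trace-style scheduling: the main thread walks a pre-analysed order itself,
// avoiding thread switches (and the cache misses they can cause).
void ExecuteTrace(const std::vector<size_t>& trace_order,
                  const std::vector<Op>& ops) {
  for (size_t id : trace_order) ops[id]();
}

int main() {
  std::vector<Op> ops = {
      [] { std::puts("op 0"); },
      [] { std::puts("op 1"); },
      [] { std::puts("op 2"); },
  };
  ExecuteAsync(ops);             // multi-threaded dispatch
  ExecuteTrace({0, 1, 2}, ops);  // main-thread trace execution
  return 0;
}
```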
void InterpreterCore::RecordMemcpyD2H(const Instruction& instr_node) {
// NOTE(zhiqiu): hot fix for jit input var
if (instr_node.OpBase()->Type() == interpreter::kMemcpyD2H) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* default_dev_ctx = pool.Get(place_);
for (auto& event : instr_node.EventsToWait()) {
platform::RecordEvent record(
"RecordStreamEvent", platform::TracerEventType::UserDefined, 10);
VLOG(3) << "Record event on default stream in jit_input_var at op: "
<< instr_node.OpBase()->Type();
event.event_->Record(default_dev_ctx);
}
}
}
void InterpreterCore::UpdateSyncOpNum() {
int64_t sync_op_num = 0;
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
if (vec_instruction_[i].KernelType() == OpFuncType::kCpuSync ||
vec_instruction_[i].KernelType() == OpFuncType::kGpuSync) {
sync_op_num = sync_op_num + 1;
}
}
sync_op_num_ = sync_op_num;
VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_;
}
// Note(zhangbo):
// When there is a KQueueSync type OP in the model, breadth-first traversal is
// better than depth-first traversal. For example:
//   OP(O) ->(direct_run)-> OP(A) ->(sync_run)-> OP(B)
//   OP(O) ->(direct_run)-> OP(C) ->(direct_run)-> OP(D)
// If B is run before C, B may block waiting for A to finish executing, while
// in fact C could have been executed first during that time.
void InterpreterCore::AnalyseExecuteOrderForTrace() {
VLOG(4) << "Analyze the execution order of Trace scheduling mode.";
interpreter::ResetAtomicGuard guard(&deps_, &refs_);
auto op_downstream_map = dependency_builder_.OpDownstreamMap();
auto IsReady = [this](size_t next_id) {
VLOG(4) << "op_id: " << next_id
<< ", remain deps: " << deps_[next_id]->DynamicDep();
return deps_[next_id]->CheckAndDecrease();
};
std::vector<size_t> trace_order;
std::deque<size_t> ready_ops;
for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) {
if (dependecy_count_[instr_id] == 0) {
ready_ops.push_back(instr_id);
}
}
while (!ready_ops.empty()) {
auto now_id = ready_ops.front();
ready_ops.pop_front();
trace_order.push_back(now_id);
auto next_op_set = op_downstream_map[now_id];
for (size_t next_op_id : next_op_set) {
if (IsReady(next_op_id)) {
ready_ops.push_back(next_op_id);
}
}
}
PADDLE_ENFORCE_EQ(
trace_order.size(),
dependecy_count_.size(),
platform::errors::PreconditionNotMet(
"trace_order size should be equal to dependecy_count_."));
trace_execute_order_ = trace_order;
}
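As an illustration of the breadth-first ordering described in the Note above, the following standalone sketch (a toy graph with hand-written indegrees, not Paddle code) reproduces the example O->A->B, O->C->D and prints the order O A C B D, so C is scheduled before the potentially blocking sync op B:

```cpp
#include <cstdio>
#include <deque>
#include <map>
#include <vector>

int main() {
  // Downstream map and indegrees for the example graph.
  // Indices: 0 = O, 1 = A, 2 = B, 3 = C, 4 = D.
  std::map<int, std::vector<int>> downstream = {
      {0, {1, 3}},  // O -> A, O -> C (direct_run)
      {1, {2}},     // A -> B (sync_run)
      {3, {4}},     // C -> D (direct_run)
  };
  std::vector<int> indegree = {0, 1, 1, 1, 1};
  const char* name[] = {"O", "A", "B", "C", "D"};

  // Ops with no remaining dependencies are ready to be scheduled.
  std::deque<int> ready;
  for (int i = 0; i < static_cast<int>(indegree.size()); ++i) {
    if (indegree[i] == 0) ready.push_back(i);
  }

  // Breadth-first topological order, mirroring the queue-based loop in
  // AnalyseExecuteOrderForTrace().
  std::vector<int> order;
  while (!ready.empty()) {
    int now = ready.front();
    ready.pop_front();
    order.push_back(now);
    for (int next : downstream[now]) {
      if (--indegree[next] == 0) ready.push_back(next);
    }
  }

  for (int id : order) std::printf("%s ", name[id]);  // prints: O A C B D
  std::printf("\n");
  return 0;
}
```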
} // namespace framework
} // namespace paddle
...@@ -83,6 +83,8 @@ class InterpreterCore {
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
// inplace
void BuildInplace();
...@@ -96,11 +98,17 @@ class InterpreterCore {
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
std::deque<size_t>* reserved_next_ops);
void RunOperator(const Instruction& instr_node);
// Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
void RecordMemcpyD2H(const Instruction& instr_node);
// gc
void RecordStreamForGC(const Instruction& instr);
void CheckGC(const Instruction& instr);
...@@ -159,7 +167,9 @@ class InterpreterCore {
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
// used for Trace
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;
};
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
...