Unverified commit 5303b66b, authored by Leo Chen, committed by GitHub

clean code of interpretercore (#46891)

* refactor

* refine code
Parent 21fab90d
@@ -378,7 +378,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
          "Required src_place shall be different with dst_place, "
          "but received same place: %s",
          src_place));
-  if (IsSupportedHetePlace(dst_place)) {
+  if (IsSupportedHeterPlace(dst_place)) {
    op_type = kMemcpyH2D;
    int dst_place_type = platform::is_gpu_place(dst_place)   ? 0
                         : platform::is_npu_place(dst_place) ? 1
@@ -387,7 +387,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
                         : platform::is_custom_place(dst_place) ? 6
                                                                 : -1;
    attr_map = {{"dst_place_type", dst_place_type}};
-  } else if (IsSupportedHetePlace(src_place)) {
+  } else if (IsSupportedHeterPlace(src_place)) {
    op_type = kMemcpyD2H;
    int dst_place_type = platform::is_cpu_place(dst_place)             ? 0
                         : platform::is_cuda_pinned_place(dst_place)   ? 1
......
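For readers skimming the hunk above, here is a small standalone sketch (plain C++, not Paddle code; the PlaceKind enum and helper names are made up for illustration) of the dispatch it implements: a heterogeneous destination place selects the host-to-device memcpy op, a heterogeneous source selects the device-to-host op, and an integer place code is stored in the "dst_place_type" attribute.

#include <iostream>
#include <string>

enum class PlaceKind { kCPU, kCUDAPinned, kGPU, kNPU, kXPU, kCustom };

// Mirrors the idea of IsSupportedHeterPlace: anything that is not plain host memory.
bool IsHeterogeneous(PlaceKind p) {
  return p != PlaceKind::kCPU && p != PlaceKind::kCUDAPinned;
}

std::string ChooseMemcpyOp(PlaceKind src, PlaceKind dst, int* dst_place_type) {
  if (IsHeterogeneous(dst)) {
    // Destination is a device: copy host -> device.
    *dst_place_type = dst == PlaceKind::kGPU      ? 0
                      : dst == PlaceKind::kNPU    ? 1
                      : dst == PlaceKind::kCustom ? 6
                                                  : -1;
    return "memcpy_h2d";
  }
  if (IsHeterogeneous(src)) {
    // Source is a device: copy device -> host (or pinned memory).
    *dst_place_type = dst == PlaceKind::kCPU ? 0 : 1;
    return "memcpy_d2h";
  }
  return "memcpy";  // host-to-host fallback
}

int main() {
  int code = -1;
  std::cout << ChooseMemcpyOp(PlaceKind::kCPU, PlaceKind::kGPU, &code)
            << " dst_place_type=" << code << "\n";  // memcpy_h2d dst_place_type=0
}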
@@ -57,6 +57,50 @@ constexpr const char* kTaskCompletion = "TaskCompletion";
namespace paddle {
namespace framework {

+inline void SetDeviceId(const platform::Place& place) {
+  // TODO(zhiqiu): reduce the cost
+  if (platform::is_gpu_place(place)) {
+#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Cannot run operator on place %s, please recompile paddle or "
+        "reinstall Paddle with CUDA support.",
+        place));
+#else
+    auto dev_id = place.device;
+    platform::SetDeviceId(dev_id);
+#endif
+  } else if (platform::is_xpu_place(place)) {
+#ifndef PADDLE_WITH_XPU
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Cannot run operator on place %s, please recompile paddle or "
+        "reinstall Paddle with XPU support.",
+        place));
+#else
+    auto dev_id = place.device;
+    platform::SetXPUDeviceId(dev_id);
+#endif
+  } else if (platform::is_npu_place(place)) {
+#ifndef PADDLE_WITH_ASCEND_CL
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Cannot run operator on place %s, please recompile paddle or "
+        "reinstall Paddle with NPU support.",
+        place));
+#else
+    auto dev_id = place.device;
+    platform::SetNPUDeviceId(dev_id);
+#endif
+  } else if (platform::is_custom_place(place)) {
+#ifndef PADDLE_WITH_CUSTOM_DEVICE
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Cannot run operator on place %s, please recompile paddle or "
+        "reinstall Paddle with CustomDevice support.",
+        place));
+#else
+    phi::DeviceManager::SetDevice(place);
+#endif
+  }
+}
+
// TODO(Ruibia): Pass skip_gc_vars, used_for_jit, and other config messages by
// constructing an interpreter::ExecutionConfig
InterpreterCore::InterpreterCore(const platform::Place& place,
@@ -71,8 +115,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
      stream_analyzer_(place) {
  VLOG(4) << "InterpreterCore(): " << this << " on " << place_;

-  is_build_ = false;
  exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
  completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
@@ -87,12 +129,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
    local_scope_ = local_scope;
  }
  var_scope_.SetLocalScope(local_scope_);
-
-  // prune
-  // optmize graph pass
-  // convert to run graph
}

InterpreterCore::~InterpreterCore() {
@@ -111,11 +147,8 @@ InterpreterCore::~InterpreterCore() {
interpreter::CostInfo InterpreterCore::DryRun(
    const std::vector<std::string>& feed_names,
    const std::vector<phi::DenseTensor>& feed_tensors) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (platform::is_gpu_place(place_)) {
-    platform::SetDeviceId(place_.device);
-  }
-#endif
+  SetDeviceId(place_);

  Prepare(feed_names, feed_tensors, true);
  interpreter::CostInfo cost_info;
  {
@@ -135,7 +168,7 @@ interpreter::CostInfo InterpreterCore::DryRun(
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
  }

-  if (execution_config_.create_local_scope) {
+  if (HasLocalScope()) {
    ClearLoDTensorArrayInLocalScope();
  }
@@ -145,11 +178,7 @@ interpreter::CostInfo InterpreterCore::DryRun(
paddle::framework::FetchList InterpreterCore::Run(
    const std::vector<std::string>& feed_names,
    const std::vector<phi::DenseTensor>& feed_tensors) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (platform::is_gpu_place(place_)) {
-    platform::SetDeviceId(place_.device);
-  }
-#endif
+  SetDeviceId(place_);

#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
@@ -181,7 +210,7 @@ paddle::framework::FetchList InterpreterCore::Run(
    }
#endif
  }

-  if (execution_config_.create_local_scope) {
+  if (HasLocalScope()) {
    ClearLoDTensorArrayInLocalScope();
  }
@@ -196,11 +225,7 @@ paddle::framework::FetchList InterpreterCore::Run(
paddle::framework::FetchList InterpreterCore::Run(
    const std::vector<std::string>& feed_names) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (platform::is_gpu_place(place_)) {
-    platform::SetDeviceId(place_.device);
-  }
-#endif
+  SetDeviceId(place_);

#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
@@ -208,17 +233,17 @@ paddle::framework::FetchList InterpreterCore::Run(
  if (!is_build_) {
    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    paddle::framework::interpreter::build_variable_scope(
-        block_, &var_scope_, execution_config_.create_local_scope);
+    paddle::framework::interpreter::BuildVariableScope(
+        block_, &var_scope_, HasLocalScope());

    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
-    paddle::framework::interpreter::build_op_func_list(
+    paddle::framework::interpreter::BuildOpFuncList(
        place_,
        block_,
        execution_config_.skip_gc_vars,
        &op_func_nodes,
        &var_scope_,
-        execution_config_.create_local_scope,
+        HasLocalScope(),
        execution_config_.used_for_jit);
    is_build_ = true;
    SetFeedVarsInplaceSkip(feed_names);
@@ -248,13 +273,13 @@ paddle::framework::FetchList InterpreterCore::Run(
#endif
  }

-  if (execution_config_.create_local_scope) {
+  if (HasLocalScope()) {
    ClearLoDTensorArrayInLocalScope();
  }

  // return Fetch Tensors
-  Scope* inner_scope = execution_config_.create_local_scope
-                           ? local_scope_
-                           : var_scope_.GetMutableScope();
+  Scope* inner_scope =
+      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
  if (fetch_var) {
    return std::move(*fetch_var->GetMutable<framework::FetchList>());
@@ -327,9 +352,8 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
}

void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
-  Scope* inner_scope = execution_config_.create_local_scope
-                           ? local_scope_
-                           : var_scope_.GetMutableScope();
+  Scope* inner_scope =
+      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
  VariableValueMap ins_map;
  for (auto& var_name_item : instr_node->Inputs()) {
    std::vector<Variable*> input_vars;
@@ -355,9 +379,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
  // set runtime_ctx and infershape_ctx_
  if (instr_node->OpBase()->Type() == "cinn_launch") {  // OP use scope in
                                                        // kernel
-    Scope* local_scope = execution_config_.create_local_scope
-                             ? var_scope_.GetMutableLocalScope()
-                             : var_scope_.GetMutableScope();
+    Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
+                                         : var_scope_.GetMutableScope();
    instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
  } else {
    instr_node->ResetContext(ins_map, outs_map);
@@ -387,9 +410,8 @@ void InterpreterCore::BuildInplace() {
    }
  }

-  Scope* local_scope = execution_config_.create_local_scope
-                           ? var_scope_.GetMutableLocalScope()
-                           : var_scope_.GetMutableScope();
+  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
+                                       : var_scope_.GetMutableScope();
  std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
  for (Instruction& instr : vec_instruction_) {
    for (auto& item : instr.Inputs()) {
@@ -524,9 +546,8 @@ void InterpreterCore::Convert(
  }

  for (auto var_id : gc_check_vars) {
-    Scope* inner_scope = execution_config_.create_local_scope
-                             ? local_scope_
-                             : var_scope_.GetMutableScope();
+    Scope* inner_scope =
+        HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
    paddle::framework::Variable* var =
        inner_scope->FindVar(var_scope_.GetNameById(var_id));
    if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
@@ -629,56 +650,11 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
  }
}

-inline void SetDeviceId(const platform::Place& place) {
-  // TODO(zhiqiu): reduce the cost
-  if (platform::is_gpu_place(place)) {
-#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Cannot run operator on place %s, please recompile paddle or "
-        "reinstall Paddle with CUDA support.",
-        place));
-#else
-    auto dev_id = place.device;
-    platform::SetDeviceId(dev_id);
-#endif
-  } else if (platform::is_xpu_place(place)) {
-#ifndef PADDLE_WITH_XPU
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Cannot run operator on place %s, please recompile paddle or "
-        "reinstall Paddle with XPU support.",
-        place));
-#else
-    auto dev_id = place.device;
-    platform::SetXPUDeviceId(dev_id);
-#endif
-  } else if (platform::is_npu_place(place)) {
-#ifndef PADDLE_WITH_ASCEND_CL
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Cannot run operator on place %s, please recompile paddle or "
-        "reinstall Paddle with NPU support.",
-        place));
-#else
-    auto dev_id = place.device;
-    platform::SetNPUDeviceId(dev_id);
-#endif
-  } else if (platform::is_custom_place(place)) {
-#ifndef PADDLE_WITH_CUSTOM_DEVICE
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Cannot run operator on place %s, please recompile paddle or "
-        "reinstall Paddle with CustomDevice support.",
-        place));
-#else
-    phi::DeviceManager::SetDevice(place);
-#endif
-  }
-}
-
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
  auto* op = instr_node.OpBase();
  auto place = instr_node.DeviceContext().GetPlace();
-  Scope* local_scope = execution_config_.create_local_scope
-                           ? var_scope_.GetMutableLocalScope()
-                           : var_scope_.GetMutableScope();
+  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
+                                       : var_scope_.GetMutableScope();
  VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);

  SetDeviceId(place);
@@ -800,8 +776,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
void InterpreterCore::ExecuteInstructionList(
    const std::vector<Instruction>& vec_instr) {
  interpreter::ResetAtomicGuard guard(&deps_, &refs_);
-  unfinished_op_numer_ = vec_instr.size();
-  if (unfinished_op_numer_ == 0) {
+  unfinished_op_number_ = vec_instr.size();
+  if (unfinished_op_number_ == 0) {
    VLOG(4) << "No op to run, return";
    return;
  }
@@ -878,8 +854,12 @@ void InterpreterCore::RunNextInstructions(
          [this, next_id] { RunInstructionAsync(next_id); });
    }
  }
-  auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
-                                                  next_instr.DirectRunIds());
+
+  std::vector<size_t> direct_run_ops = next_instr.SyncRunIds();
+  direct_run_ops.insert(direct_run_ops.end(),
+                        next_instr.DirectRunIds().begin(),
+                        next_instr.DirectRunIds().end());
+
  int64_t first_op = -1;
  for (auto next_id : direct_run_ops) {
    if (IsReady(next_id)) {
@@ -949,9 +929,9 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
      return;
    }

-    VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_;
-    if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) ==
-                 1)) {
+    VLOG(4) << "unfinished_op_number_: " << unfinished_op_number_;
+    if (UNLIKELY(unfinished_op_number_.fetch_sub(
+                     1, std::memory_order_relaxed) == 1)) {
      if (completion_notifier_ != nullptr) {
        completion_notifier_->NotifyEvent();
      }
@@ -961,8 +941,11 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
  }
}

-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
+#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "RecordStreamForGC is only implemented when compiled with GPU."));
+#else
  if (!IsInterpretercoreFastGCEnabled() ||
      instr.KernelType() != OpFuncType::kQueueAsync) {
    return;
@@ -1053,8 +1036,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
                                        framework::ToTypeName(var->Type())));
    }
  }
-}
#endif
+}

void InterpreterCore::CheckGC(const Instruction& instr) {
  platform::RecordEvent record(
@@ -1106,17 +1089,17 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
  };

  if (!is_build_) {
-    paddle::framework::interpreter::build_variable_scope(
-        block_, &var_scope_, execution_config_.create_local_scope);
+    paddle::framework::interpreter::BuildVariableScope(
+        block_, &var_scope_, HasLocalScope());
    FeedInput();
    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
-    paddle::framework::interpreter::build_op_func_list(
+    paddle::framework::interpreter::BuildOpFuncList(
        place_,
        block_,
        execution_config_.skip_gc_vars,
        &op_func_nodes,
        &var_scope_,
-        execution_config_.create_local_scope,
+        HasLocalScope(),
        execution_config_.used_for_jit);
    is_build_ = true;
    SetFeedVarsInplaceSkip(feed_names);
@@ -1124,7 +1107,7 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
    Convert(&op_func_nodes);
  }
  // NOTE: Because feed_tensor will be GC after
-  // paddle::framework::build_op_func_list, so we should
+  // paddle::framework::BuildOpFuncList, so we should
  // call FeedInput again.
  if (prepare_feed) {
    FeedInput();
@@ -1138,6 +1121,8 @@ void InterpreterCore::SetFeedVarsInplaceSkip(
  }
}

+bool InterpreterCore::HasLocalScope() const { return local_scope_ != nullptr; }
+
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
    const platform::Place& place,
    const ProgramDesc& prog,
@@ -1145,11 +1130,11 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore(
    const std::vector<std::string>& fetch_names,
    const std::set<std::string>& skip_gc_vars) {
  std::shared_ptr<InterpreterCore> core = nullptr;
-  // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
+  // NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy
  // a new program.
  auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
  auto* block = new_prog->MutableBlock(0);
-  interpreter::add_fetch(fetch_names, block);
+  interpreter::AddFetch(fetch_names, block);
  core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, scope);
  core->SetCopyProgram(new_prog);
......
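The rename from unfinished_op_numer_ to unfinished_op_number_ above touches the counter that detects when the last instruction finishes. As a minimal, self-contained sketch (plain C++, not Paddle code) of that completion-detection pattern: each worker decrements an atomic counter when its op finishes, and the one thread that observes the counter dropping from 1 to 0 fires the completion event.

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  constexpr size_t kNumOps = 8;
  std::atomic<size_t> unfinished_op_number{kNumOps};

  std::vector<std::thread> workers;
  for (size_t i = 0; i < kNumOps; ++i) {
    workers.emplace_back([&unfinished_op_number] {
      // ... run one instruction ...
      // fetch_sub returns the previous value, so exactly one thread sees 1.
      if (unfinished_op_number.fetch_sub(1, std::memory_order_relaxed) == 1) {
        std::cout << "all ops finished, notify completion\n";
      }
    });
  }
  for (auto& t : workers) t.join();
}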
@@ -68,45 +68,42 @@ class InterpreterCore {
  void reset_scope(Scope* new_scope);

 private:
-  bool BuildInplaceCheckVarIsOnlyInput(
-      const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
-  std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
+  // build graph
+  void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
+  void BuildOperatorDependences();
  void BuildAndCacheInstructionCtx(Instruction* instr_node);
+  void BuildSkipShareLoDInfo();

+  // inplace
  void BuildInplace();
+  bool BuildInplaceCheckVarIsOnlyInput(
+      const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
+  void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);

-  void BuildOperatorDependences();
-  void ClearLoDTensorArrayInLocalScope();
-  void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
+  // execution
+  void RunInstruction(const Instruction& instr_node);
  void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
-  void RunInstruction(const Instruction& instr_node);
+  void RunInstructionAsync(size_t instr_id);
+  void RunNextInstructions(const Instruction& instr_id,
+                           std::queue<size_t>* reserved_next_ops);
+  // only used when program contains no feed op
  void Prepare(const std::vector<std::string>& feed_names,
               const std::vector<phi::DenseTensor>& feed_tensors,
               bool prepare_feed);

-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  // gc
  void RecordStreamForGC(const Instruction& instr);
-#endif
  void CheckGC(const Instruction& instr);
+  void ClearLoDTensorArrayInLocalScope();

-  void RunInstructionAsync(size_t instr_id);
-  void RunNextInstructions(const Instruction& instr_id,
-                           std::queue<size_t>* reserved_next_ops);
-  void BuildSkipShareLoDInfo();
-  void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
+  // workqueue
+  std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
+
+  // scope
+  bool HasLocalScope() const;

 private:
-  bool is_build_;
+  bool is_build_{false};

  platform::Place place_;
  const BlockDesc& block_;  // not owned
@@ -127,11 +124,7 @@ class InterpreterCore {
  std::vector<Instruction> vec_instruction_;  // deconstruct before OpFuncNode

-  // last_live_ops_[i] contains the id of operators that last access var[i]
-  std::map<size_t, std::set<size_t>> last_live_ops_;
-  std::vector<size_t> dependecy_count_;
-  std::atomic<size_t> unfinished_op_numer_{0};
+  std::atomic<size_t> unfinished_op_number_{0};

  VariableScope var_scope_;
  Scope* local_scope_{nullptr};  // not owned
@@ -145,8 +138,13 @@ class InterpreterCore {
  std::unique_ptr<InterpreterCoreGarbageCollector> gc_;

-  std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
-  std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
+  // last_live_ops_[i] contains the id of operators that last access the i-th
+  // var
+  std::map<size_t, std::set<size_t>> last_live_ops_;
+
+  // dependecy_count_[i] contains the number of dependencies that the i-th op
+  // need to wait
+  std::vector<size_t> dependecy_count_;

  std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
  std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
......
@@ -122,8 +122,8 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {

std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
-get_unused_vars(const BlockDesc& block,
-                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
+GetUnusedVars(const BlockDesc& block,
+              const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
@@ -166,17 +166,17 @@ get_unused_vars(const BlockDesc& block,
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;
-    result[ops[op_idx].get()].emplace_back(name);
-    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
+    auto op = ops[op_idx].get();
+    result[op].emplace_back(name);
+    VLOG(4) << op->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}

-void build_variable_scope(const framework::BlockDesc& block,
+void BuildVariableScope(const framework::BlockDesc& block,
                        VariableScope* var_scope,
                        bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();
@@ -214,8 +214,8 @@ void build_variable_scope(const framework::BlockDesc& block,
  }
}

-void create_all_ops(const framework::BlockDesc& block,
+void CreateAllOps(const framework::BlockDesc& block,
                  std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    auto op_type = op->Type();
    VLOG(8) << "CreateOp from : " << op_type;
@@ -289,9 +289,9 @@ std::tuple<VariableValueMap, VariableIdMap> BuildVariableMap(
  return std::make_tuple(name2var, name2id);
}

-void apply_device_guard(const OperatorBase* op_base,
+void ApplyDeviceGuard(const OperatorBase* op_base,
                      const platform::Place& place,
                      OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
@@ -352,7 +352,7 @@ void apply_device_guard(const OperatorBase* op_base,
  }
}

-void deal_operator_base(const platform::Place& place,
+void HandleOperatorBase(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node,
@@ -361,7 +361,7 @@ void deal_operator_base(const platform::Place& place,
  auto* dev_ctx = pool.Get(place);
  // input, output is prepared. set the other attributes.
  op_func_node->operator_base_ = op_base;
-  if (IsSupportedHetePlace(place)) {
+  if (IsSupportedHeterPlace(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
@@ -382,19 +382,19 @@ void deal_operator_base(const platform::Place& place,
  op_func_node->dev_ctx_ = dev_ctx;
}
-void build_op_func_list(const platform::Place& place,
+void BuildOpFuncList(const platform::Place& place,
                     const framework::BlockDesc& block,
                     const std::set<std::string>& skip_gc_vars,
                     std::vector<OpFuncNode>* vec_func_list,
                     VariableScope* var_scope,
                     bool use_local_scope,
                     bool used_for_jit) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
-  create_all_ops(block, &ops_unique);
+  CreateAllOps(block, &ops_unique);

  if (!used_for_jit) {
    // If gc is enabled and block size > 1
@@ -415,7 +415,7 @@ void build_op_func_list(const platform::Place& place,
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
-  auto unused_var_map = get_unused_vars(block, ops);
+  auto unused_var_map = GetUnusedVars(block, ops);

  bool flag_log_is_printed = false;
  for (size_t i = 0; i < ops.size(); ++i) {
@@ -485,10 +485,10 @@ void build_op_func_list(const platform::Place& place,
    try {
      if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
+        VLOG(4) << "HandleOperatorBase";
        // op is not a operatorwithkernel, so direcly run OperatorBase::Run()
-        deal_operator_base(
+        HandleOperatorBase(
            place, var_scope, ops[i], &op_func_node, local_scope);
-        VLOG(4) << "deal_operator_base";
      } else {
        VLOG(4) << "OP is not null";
        auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
@@ -522,7 +522,7 @@ void build_op_func_list(const platform::Place& place,
            op_with_kernel->GetExpectedKernelType(exec_ctx);
        VLOG(4) << "get expected_kernel_key";
        // change device by the device_guard()
-        apply_device_guard(op, place, &expected_kernel_key);
+        ApplyDeviceGuard(op, place, &expected_kernel_key);
        VLOG(4) << "expected_kernel_key : " << expected_kernel_key;

        // step 2. select op kernel
@@ -565,7 +565,7 @@ void build_op_func_list(const platform::Place& place,
          dev_ctx = pool.Get(kernel_type.place_);
        }
        op_func_node.dev_ctx_ = dev_ctx;
-        if (IsSupportedHetePlace(kernel_type.place_)) {
+        if (IsSupportedHeterPlace(kernel_type.place_)) {
          op_func_node.type_ = OpFuncType::kQueueAsync;
        } else if (platform::is_cpu_place(kernel_type.place_)) {
          op_func_node.type_ = OpFuncType::kQueueSync;
@@ -667,7 +667,7 @@ void build_op_func_list(const platform::Place& place,
    vec_func_list->emplace_back(op_func_node);

-    // gc---------------------------------------------------------------------------
+    // gc---------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      interpreter::LogDeviceMemoryStats(place);
@@ -702,8 +702,8 @@ void build_op_func_list(const platform::Place& place,
  memory::Release(place);
}

-void add_fetch(const std::vector<std::string>& fetch_names,
+void AddFetch(const std::vector<std::string>& fetch_names,
              framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);
@@ -721,20 +721,6 @@ void add_fetch(const std::vector<std::string>& fetch_names,
  }
}

-std::vector<size_t> merge_vector(const std::vector<size_t>& first,
-                                 const std::vector<size_t>& second) {
-  std::vector<size_t> out(first.size() + second.size());
-  std::merge(
-      first.begin(), first.end(), second.begin(), second.end(), out.begin());
-
-  std::vector<size_t>::iterator it;
-  it = std::unique(out.begin(), out.end());
-
-  out.resize(std::distance(out.begin(), it));
-
-  return out;
-}
-
}  // namespace interpreter
}  // namespace framework
}  // namespace paddle
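The removed merge_vector helper and the plain concatenation that now replaces it at its single call site in RunNextInstructions behave slightly differently: std::merge assumes both inputs are sorted, and std::unique then drops adjacent duplicates, while concatenation keeps the original order and any duplicates. A small standalone comparison (plain C++, not Paddle code):

#include <algorithm>
#include <iostream>
#include <vector>

// The removed helper: merge two sorted id lists and drop adjacent duplicates.
std::vector<size_t> MergeVector(const std::vector<size_t>& first,
                                const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(
      first.begin(), first.end(), second.begin(), second.end(), out.begin());
  out.erase(std::unique(out.begin(), out.end()), out.end());
  return out;
}

// The replacement pattern: simple concatenation via insert.
std::vector<size_t> Concatenate(const std::vector<size_t>& first,
                                const std::vector<size_t>& second) {
  std::vector<size_t> out = first;
  out.insert(out.end(), second.begin(), second.end());
  return out;
}

int main() {
  std::vector<size_t> a{1, 3, 5}, b{3, 4};
  for (size_t v : MergeVector(a, b)) std::cout << v << ' ';  // 1 3 4 5
  std::cout << '\n';
  for (size_t v : Concatenate(a, b)) std::cout << v << ' ';  // 1 3 5 3 4
  std::cout << '\n';
}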
@@ -66,23 +66,20 @@ class AsyncWorkQueue {

void LogDeviceMemoryStats(const platform::Place& place);

-void build_variable_scope(const framework::BlockDesc& block,
+void BuildVariableScope(const framework::BlockDesc& block,
                        VariableScope* var_scope,
                        bool use_local_scope = true);

-void build_op_func_list(const platform::Place& place,
+void BuildOpFuncList(const platform::Place& place,
                     const framework::BlockDesc& block,
                     const std::set<std::string>& skip_gc_vars,
                     std::vector<OpFuncNode>* vec_func_list,
                     VariableScope* scope,
                     bool use_local_scope = true,
                     bool used_for_jit = false);

-void add_fetch(const std::vector<std::string>& fetch_names,
+void AddFetch(const std::vector<std::string>& fetch_names,
              framework::BlockDesc* block);

-std::vector<size_t> merge_vector(const std::vector<size_t>& first,
-                                 const std::vector<size_t>& second);
-
}  // namespace interpreter
}  // namespace framework
......
@@ -392,7 +392,7 @@ static bool IsCpuOp(const Instruction& instr) {
}

// is supported heterogeneous place
-static bool IsSupportedHetePlace(const phi::Place& place) {
+static bool IsSupportedHeterPlace(const phi::Place& place) {
  return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
         platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
         platform::is_custom_place(place);
......