Unverified commit 5303b66b authored by Leo Chen, committed by GitHub

clean code of interpretercore (#46891)

* refactor

* refine code
Parent 21fab90d
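As a quick orientation before the hunks: most of this cleanup replaces the per-backend device-selection blocks, previously repeated at every entry point of InterpreterCore, with one file-local SetDeviceId(place) helper, and renames several interpreter helpers to CamelCase. A condensed before/after sketch of the device-selection change, paraphrasing the diff below rather than quoting the Paddle sources verbatim:

// Before: each entry point (DryRun / Run) carried its own guarded block.
//   #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
//     if (platform::is_gpu_place(place_)) {
//       platform::SetDeviceId(place_.device);
//     }
//   #endif
//
// After: the same call sites reduce to a single helper call that also
// covers XPU, NPU, and custom-device places.
//   SetDeviceId(place_);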
......@@ -378,7 +378,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
"Required src_place shall be different with dst_place, "
"but received same place: %s",
src_place));
if (IsSupportedHetePlace(dst_place)) {
if (IsSupportedHeterPlace(dst_place)) {
op_type = kMemcpyH2D;
int dst_place_type = platform::is_gpu_place(dst_place) ? 0
: platform::is_npu_place(dst_place) ? 1
......@@ -387,7 +387,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
: platform::is_custom_place(dst_place) ? 6
: -1;
attr_map = {{"dst_place_type", dst_place_type}};
} else if (IsSupportedHetePlace(src_place)) {
} else if (IsSupportedHeterPlace(src_place)) {
op_type = kMemcpyD2H;
int dst_place_type = platform::is_cpu_place(dst_place) ? 0
: platform::is_cuda_pinned_place(dst_place) ? 1
......
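The rename in the hunk above (IsSupportedHetePlace to IsSupportedHeterPlace) leaves the transfer-direction logic unchanged: TransferDevice still emits a host-to-device copy when the destination is the heterogeneous place and a device-to-host copy when the source is. A condensed sketch of that dispatch, using only names visible above:

// Illustrative condensation of TransferDevice's direction choice.
std::string op_type;
if (IsSupportedHeterPlace(dst_place)) {
  op_type = kMemcpyH2D;  // copy toward the GPU/NPU/XPU/custom place
} else if (IsSupportedHeterPlace(src_place)) {
  op_type = kMemcpyD2H;  // copy back toward the host
}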
......@@ -57,6 +57,50 @@ constexpr const char* kTaskCompletion = "TaskCompletion";
namespace paddle {
namespace framework {
inline void SetDeviceId(const platform::Place& place) {
// TODO(zhiqiu): reduce the cost
if (platform::is_gpu_place(place)) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CUDA support.",
place));
#else
auto dev_id = place.device;
platform::SetDeviceId(dev_id);
#endif
} else if (platform::is_xpu_place(place)) {
#ifndef PADDLE_WITH_XPU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with XPU support.",
place));
#else
auto dev_id = place.device;
platform::SetXPUDeviceId(dev_id);
#endif
} else if (platform::is_npu_place(place)) {
#ifndef PADDLE_WITH_ASCEND_CL
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with NPU support.",
place));
#else
auto dev_id = place.device;
platform::SetNPUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
phi::DeviceManager::SetDevice(place);
#endif
}
}
// TODO(Ruibia): Pass skip_gc_vars, used_for_jit, and other config messages by
// constructing an interpreter::ExecutionConfig
InterpreterCore::InterpreterCore(const platform::Place& place,
......@@ -71,8 +115,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
stream_analyzer_(place) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
is_build_ = false;
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
......@@ -87,12 +129,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
local_scope_ = local_scope;
}
var_scope_.SetLocalScope(local_scope_);
// prune
  // optimize graph pass
// convert to run graph
}
InterpreterCore::~InterpreterCore() {
......@@ -111,11 +147,8 @@ InterpreterCore::~InterpreterCore() {
interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(place_.device);
}
#endif
SetDeviceId(place_);
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
{
......@@ -135,7 +168,7 @@ interpreter::CostInfo InterpreterCore::DryRun(
platform::DeviceContextPool::Instance().Get(place_)->Wait();
}
if (execution_config_.create_local_scope) {
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}
......@@ -145,11 +178,7 @@ interpreter::CostInfo InterpreterCore::DryRun(
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(place_.device);
}
#endif
SetDeviceId(place_);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
......@@ -181,7 +210,7 @@ paddle::framework::FetchList InterpreterCore::Run(
}
#endif
}
if (execution_config_.create_local_scope) {
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}
......@@ -196,11 +225,7 @@ paddle::framework::FetchList InterpreterCore::Run(
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(place_.device);
}
#endif
SetDeviceId(place_);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
......@@ -208,17 +233,17 @@ paddle::framework::FetchList InterpreterCore::Run(
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
paddle::framework::interpreter::build_variable_scope(
block_, &var_scope_, execution_config_.create_local_scope);
paddle::framework::interpreter::BuildVariableScope(
block_, &var_scope_, HasLocalScope());
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
paddle::framework::interpreter::BuildOpFuncList(
place_,
block_,
execution_config_.skip_gc_vars,
&op_func_nodes,
&var_scope_,
execution_config_.create_local_scope,
HasLocalScope(),
execution_config_.used_for_jit);
is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
......@@ -248,13 +273,13 @@ paddle::framework::FetchList InterpreterCore::Run(
#endif
}
if (execution_config_.create_local_scope) {
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}
// return Fetch Tensors
Scope* inner_scope = execution_config_.create_local_scope
? local_scope_
: var_scope_.GetMutableScope();
Scope* inner_scope =
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
......@@ -327,9 +352,8 @@ std::shared_ptr<interpreter::AsyncWorkQueue> InterpreterCore::GetWorkQueue() {
}
void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
Scope* inner_scope = execution_config_.create_local_scope
? local_scope_
: var_scope_.GetMutableScope();
Scope* inner_scope =
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
......@@ -355,8 +379,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
// set runtime_ctx and infershape_ctx_
if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = execution_config_.create_local_scope
? var_scope_.GetMutableLocalScope()
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
......@@ -387,8 +410,7 @@ void InterpreterCore::BuildInplace() {
}
}
Scope* local_scope = execution_config_.create_local_scope
? var_scope_.GetMutableLocalScope()
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
for (Instruction& instr : vec_instruction_) {
......@@ -524,9 +546,8 @@ void InterpreterCore::Convert(
}
for (auto var_id : gc_check_vars) {
Scope* inner_scope = execution_config_.create_local_scope
? local_scope_
: var_scope_.GetMutableScope();
Scope* inner_scope =
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
paddle::framework::Variable* var =
inner_scope->FindVar(var_scope_.GetNameById(var_id));
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
......@@ -629,55 +650,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
}
}
inline void SetDeviceId(const platform::Place& place) {
// TODO(zhiqiu): reduce the cost
if (platform::is_gpu_place(place)) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CUDA support.",
place));
#else
auto dev_id = place.device;
platform::SetDeviceId(dev_id);
#endif
} else if (platform::is_xpu_place(place)) {
#ifndef PADDLE_WITH_XPU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with XPU support.",
place));
#else
auto dev_id = place.device;
platform::SetXPUDeviceId(dev_id);
#endif
} else if (platform::is_npu_place(place)) {
#ifndef PADDLE_WITH_ASCEND_CL
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with NPU support.",
place));
#else
auto dev_id = place.device;
platform::SetNPUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
phi::DeviceManager::SetDevice(place);
#endif
}
}
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
Scope* local_scope = execution_config_.create_local_scope
? var_scope_.GetMutableLocalScope()
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
......@@ -800,8 +776,8 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
void InterpreterCore::ExecuteInstructionList(
const std::vector<Instruction>& vec_instr) {
interpreter::ResetAtomicGuard guard(&deps_, &refs_);
unfinished_op_numer_ = vec_instr.size();
if (unfinished_op_numer_ == 0) {
unfinished_op_number_ = vec_instr.size();
if (unfinished_op_number_ == 0) {
VLOG(4) << "No op to run, return";
return;
}
......@@ -878,8 +854,12 @@ void InterpreterCore::RunNextInstructions(
[this, next_id] { RunInstructionAsync(next_id); });
}
}
auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
next_instr.DirectRunIds());
std::vector<size_t> direct_run_ops = next_instr.SyncRunIds();
direct_run_ops.insert(direct_run_ops.end(),
next_instr.DirectRunIds().begin(),
next_instr.DirectRunIds().end());
int64_t first_op = -1;
for (auto next_id : direct_run_ops) {
if (IsReady(next_id)) {
......@@ -949,9 +929,9 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
return;
}
VLOG(4) << "unfinished_op_numer_: " << unfinished_op_numer_;
if (UNLIKELY(unfinished_op_numer_.fetch_sub(1, std::memory_order_relaxed) ==
1)) {
VLOG(4) << "unfinished_op_number_: " << unfinished_op_number_;
if (UNLIKELY(unfinished_op_number_.fetch_sub(
1, std::memory_order_relaxed) == 1)) {
if (completion_notifier_ != nullptr) {
completion_notifier_->NotifyEvent();
}
......@@ -961,8 +941,11 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::Unimplemented(
"RecordStreamForGC is only implemented when compiled with GPU."));
#else
if (!IsInterpretercoreFastGCEnabled() ||
instr.KernelType() != OpFuncType::kQueueAsync) {
return;
......@@ -1053,8 +1036,8 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
framework::ToTypeName(var->Type())));
}
}
}
#endif
}
void InterpreterCore::CheckGC(const Instruction& instr) {
platform::RecordEvent record(
......@@ -1106,17 +1089,17 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
};
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(
block_, &var_scope_, execution_config_.create_local_scope);
paddle::framework::interpreter::BuildVariableScope(
block_, &var_scope_, HasLocalScope());
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
paddle::framework::interpreter::BuildOpFuncList(
place_,
block_,
execution_config_.skip_gc_vars,
&op_func_nodes,
&var_scope_,
execution_config_.create_local_scope,
HasLocalScope(),
execution_config_.used_for_jit);
is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
......@@ -1124,7 +1107,7 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
Convert(&op_func_nodes);
}
// NOTE: Because feed_tensor will be GC after
// paddle::framework::build_op_func_list, so we should
// paddle::framework::BuildOpFuncList, so we should
// call FeedInput again.
if (prepare_feed) {
FeedInput();
......@@ -1138,6 +1121,8 @@ void InterpreterCore::SetFeedVarsInplaceSkip(
}
}
bool InterpreterCore::HasLocalScope() const { return local_scope_ != nullptr; }
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place,
const ProgramDesc& prog,
......@@ -1145,11 +1130,11 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const std::vector<std::string>& fetch_names,
const std::set<std::string>& skip_gc_vars) {
std::shared_ptr<InterpreterCore> core = nullptr;
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
// NOTE(Aurelius84): `AddFetch` will modify BlockDesc, so we should copy
// a new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);
interpreter::AddFetch(fetch_names, block);
core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, scope);
core->SetCopyProgram(new_prog);
......
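Another recurring change in the hunks above: direct checks of execution_config_.create_local_scope give way to the new HasLocalScope() accessor, which, per the hunk near the end of the file, just tests whether local_scope_ is non-null. The scope-selection idiom that now appears throughout, condensed:

// Condensed form of the repeated pattern: run inside the local scope when one
// was created, otherwise fall back to the variable scope's root scope.
Scope* inner_scope =
    HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();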
......@@ -68,45 +68,42 @@ class InterpreterCore {
void reset_scope(Scope* new_scope);
private:
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
// inplace
void BuildInplace();
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
void BuildOperatorDependences();
void ClearLoDTensorArrayInLocalScope();
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void RunInstruction(const Instruction& instr_node);
// execution
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// gc
void RecordStreamForGC(const Instruction& instr);
#endif
void CheckGC(const Instruction& instr);
void ClearLoDTensorArrayInLocalScope();
void RunInstructionAsync(size_t instr_id);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops);
void BuildSkipShareLoDInfo();
// workqueue
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// scope
bool HasLocalScope() const;
private:
bool is_build_;
bool is_build_{false};
platform::Place place_;
const BlockDesc& block_; // not owned
......@@ -127,11 +124,7 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
// last_live_ops_[i] contains the id of operators that last access var[i]
std::map<size_t, std::set<size_t>> last_live_ops_;
std::vector<size_t> dependecy_count_;
std::atomic<size_t> unfinished_op_numer_{0};
std::atomic<size_t> unfinished_op_number_{0};
VariableScope var_scope_;
Scope* local_scope_{nullptr}; // not owned
......@@ -145,8 +138,13 @@ class InterpreterCore {
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
// last_live_ops_[i] contains the id of operators that last access the i-th
// var
std::map<size_t, std::set<size_t>> last_live_ops_;
  // dependecy_count_[i] contains the number of dependencies that the i-th op
  // needs to wait for
std::vector<size_t> dependecy_count_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
......
......@@ -122,7 +122,7 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
std::unordered_map<const paddle::framework::OperatorBase*,
std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
GetUnusedVars(const BlockDesc& block,
const std::vector<std::shared_ptr<OperatorBase>>& ops) {
std::unordered_map<std::string, size_t> var_op_idx_map;
......@@ -166,15 +166,15 @@ get_unused_vars(const BlockDesc& block,
for (auto& name_op_idx_pair : var_op_idx_map) {
auto& name = name_op_idx_pair.first;
size_t op_idx = name_op_idx_pair.second;
result[ops[op_idx].get()].emplace_back(name);
VLOG(4) << ops[op_idx].get()->Type() << " " << name;
auto op = ops[op_idx].get();
result[op].emplace_back(name);
VLOG(4) << op->Type() << " " << name;
}
VLOG(4) << "gc map size:" << result.size();
return result;
}
void build_variable_scope(const framework::BlockDesc& block,
void BuildVariableScope(const framework::BlockDesc& block,
VariableScope* var_scope,
bool use_local_scope) {
VLOG(3) << "Creating Variables";
......@@ -214,7 +214,7 @@ void build_variable_scope(const framework::BlockDesc& block,
}
}
void create_all_ops(const framework::BlockDesc& block,
void CreateAllOps(const framework::BlockDesc& block,
std::vector<std::unique_ptr<OperatorBase>>* ops) {
for (auto& op : block.AllOps()) {
auto op_type = op->Type();
......@@ -289,7 +289,7 @@ std::tuple<VariableValueMap, VariableIdMap> BuildVariableMap(
return std::make_tuple(name2var, name2id);
}
void apply_device_guard(const OperatorBase* op_base,
void ApplyDeviceGuard(const OperatorBase* op_base,
const platform::Place& place,
OpKernelType* expected_kernel_key) {
bool need_change_place =
......@@ -352,7 +352,7 @@ void apply_device_guard(const OperatorBase* op_base,
}
}
void deal_operator_base(const platform::Place& place,
void HandleOperatorBase(const platform::Place& place,
const VariableScope* var_scope,
std::shared_ptr<OperatorBase> op_base,
OpFuncNode* op_func_node,
......@@ -361,7 +361,7 @@ void deal_operator_base(const platform::Place& place,
auto* dev_ctx = pool.Get(place);
// input, output is prepared. set the other attributes.
op_func_node->operator_base_ = op_base;
if (IsSupportedHetePlace(place)) {
if (IsSupportedHeterPlace(place)) {
op_func_node->type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(place)) {
op_func_node->type_ = OpFuncType::kQueueSync;
......@@ -382,7 +382,7 @@ void deal_operator_base(const platform::Place& place,
op_func_node->dev_ctx_ = dev_ctx;
}
void build_op_func_list(const platform::Place& place,
void BuildOpFuncList(const platform::Place& place,
const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list,
......@@ -394,7 +394,7 @@ void build_op_func_list(const platform::Place& place,
std::vector<std::unique_ptr<OperatorBase>>
ops_unique; // its elements will be moved to vec_func_list
// Step 1: create all ops for current block.
create_all_ops(block, &ops_unique);
CreateAllOps(block, &ops_unique);
if (!used_for_jit) {
// If gc is enabled and block size > 1
......@@ -415,7 +415,7 @@ void build_op_func_list(const platform::Place& place,
for (auto& op_unique : ops_unique) {
ops.emplace_back(std::move(op_unique));
}
auto unused_var_map = get_unused_vars(block, ops);
auto unused_var_map = GetUnusedVars(block, ops);
bool flag_log_is_printed = false;
for (size_t i = 0; i < ops.size(); ++i) {
......@@ -485,10 +485,10 @@ void build_op_func_list(const platform::Place& place,
try {
if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
VLOG(4) << "HandleOperatorBase";
        // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
deal_operator_base(
HandleOperatorBase(
place, var_scope, ops[i], &op_func_node, local_scope);
VLOG(4) << "deal_operator_base";
} else {
VLOG(4) << "OP is not null";
auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
......@@ -522,7 +522,7 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->GetExpectedKernelType(exec_ctx);
VLOG(4) << "get expected_kernel_key";
// change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key);
ApplyDeviceGuard(op, place, &expected_kernel_key);
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// step 2. select op kernel
......@@ -565,7 +565,7 @@ void build_op_func_list(const platform::Place& place,
dev_ctx = pool.Get(kernel_type.place_);
}
op_func_node.dev_ctx_ = dev_ctx;
if (IsSupportedHetePlace(kernel_type.place_)) {
if (IsSupportedHeterPlace(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueSync;
......@@ -667,7 +667,7 @@ void build_op_func_list(const platform::Place& place,
vec_func_list->emplace_back(op_func_node);
// gc---------------------------------------------------------------------------
// gc---------------------------------------------
auto iter = unused_var_map.find(op);
if (iter == unused_var_map.end()) {
interpreter::LogDeviceMemoryStats(place);
......@@ -702,7 +702,7 @@ void build_op_func_list(const platform::Place& place,
memory::Release(place);
}
void add_fetch(const std::vector<std::string>& fetch_names,
void AddFetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block) {
auto* fetch_holder = block->Var(kFetchVarName);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
......@@ -721,20 +721,6 @@ void add_fetch(const std::vector<std::string>& fetch_names,
}
}
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second) {
std::vector<size_t> out(first.size() + second.size());
std::merge(
first.begin(), first.end(), second.begin(), second.end(), out.begin());
std::vector<size_t>::iterator it;
it = std::unique(out.begin(), out.end());
out.resize(std::distance(out.begin(), it));
return out;
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
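For reference, the interpreter::merge_vector helper deleted just above returned a sorted, de-duplicated union of two index vectors; its former call site in RunNextInstructions (earlier hunk) now appends DirectRunIds onto SyncRunIds directly. A standalone sketch of the deleted behavior, assuming both inputs are already sorted (a precondition of std::merge):

#include <algorithm>
#include <vector>

// Illustrative equivalent of the removed helper, not part of the diff.
std::vector<size_t> MergeSortedUnique(const std::vector<size_t>& first,
                                      const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());
  out.erase(std::unique(out.begin(), out.end()), out.end());
  return out;
}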
......@@ -66,11 +66,11 @@ class AsyncWorkQueue {
void LogDeviceMemoryStats(const platform::Place& place);
void build_variable_scope(const framework::BlockDesc& block,
void BuildVariableScope(const framework::BlockDesc& block,
VariableScope* var_scope,
bool use_local_scope = true);
void build_op_func_list(const platform::Place& place,
void BuildOpFuncList(const platform::Place& place,
const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list,
......@@ -78,12 +78,9 @@ void build_op_func_list(const platform::Place& place,
bool use_local_scope = true,
bool used_for_jit = false);
void add_fetch(const std::vector<std::string>& fetch_names,
void AddFetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block);
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second);
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -392,7 +392,7 @@ static bool IsCpuOp(const Instruction& instr) {
}
// is supported heterogeneous place
static bool IsSupportedHetePlace(const phi::Place& place) {
static bool IsSupportedHeterPlace(const phi::Place& place) {
return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
platform::is_custom_place(place);
......