Unverified commit 9516108a, authored by Aurelius84, committed by GitHub

Modify Struct into Class to improve encapsulation and Polish code exception (#36797)

* Refactor InterpreterCore code

* make tuple
Parent a7d8837b
......@@ -22,13 +22,13 @@ void EventManager::WaitEvent(const Instruction& instruction,
// If InterpreterCore is on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return;
VLOG(3) << "Deal StreamWaitEventOrSync for "
<< instruction.kernel_func_.operator_base_->Type();
VLOG(3) << "Deal StreamWaitEventOrSync for " << instruction.OpBase()->Type();
for (auto& event_iter : instruction.intput_events_) {
for (auto& event_iter : instruction.InputEvents()) {
VLOG(3) << "wait var_id: " << event_iter.var_id_
<< " 's event with waiter_type: " << event_iter.waiter_type_;
event_iter.event_->Wait(event_iter.waiter_type_, instruction.dev_ctx_);
event_iter.event_->Wait(event_iter.waiter_type_,
&instruction.DeviceContext());
}
}
......@@ -37,9 +37,9 @@ void EventManager::RecordEvent(const Instruction& instruction,
// If InterpreterCore is on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return;
for (auto& event : instruction.output_events_) {
for (auto& event : instruction.OutputEvents()) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(instruction.dev_ctx_);
event.event_->Record(&instruction.DeviceContext());
}
}
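
EventManager pairs one Record on the producer side with one Wait on each consumer side. A minimal standalone sketch of that pairing follows; this DeviceEvent is a hypothetical one-flag stand-in for platform::DeviceEvent, which in the real code also dispatches on waiter_type and a device context:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class DeviceEvent {  // hypothetical stand-in for platform::DeviceEvent
 public:
  void Record() {
    std::lock_guard<std::mutex> lock(mu_);
    recorded_ = true;
    cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return recorded_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool recorded_ = false;
};

int main() {
  DeviceEvent event;
  std::thread producer([&] { event.Record(); });  // like RecordEvent()
  event.Wait();                                   // like WaitEvent()
  producer.join();
  std::cout << "consumer unblocked\n";
}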
......
......@@ -79,11 +79,9 @@ paddle::framework::FetchList InterpreterCore::Run(
const std::vector<framework::Tensor>& feed_tensors) {
auto FeedInput = [&] {
for (size_t i = 0; i < feed_names_.size(); ++i) {
auto it = global_scope_->name2id.find(feed_names_[i]);
assert(it != global_scope_->name2id.end());
auto* feed_var = global_scope_->Var(feed_names_[i]);
auto feed_tensor = global_scope_->var_list[it->second]
->GetMutable<framework::LoDTensor>();
auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]);
}
};
......@@ -93,7 +91,7 @@ paddle::framework::FetchList InterpreterCore::Run(
global_scope_);
FeedInput();
paddle::framework::interpretercore::build_op_func_list(
place_, main_program_, &op_list_, &vec_func_list_, global_scope_);
place_, main_program_, &vec_func_list_, global_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert();
......@@ -103,42 +101,39 @@ paddle::framework::FetchList InterpreterCore::Run(
}
// return Fetch Tensors
return *(global_scope_->var_list[global_scope_->name2id["fetch_vars"]]
->GetMutable<framework::FetchList>());
auto* fetch_var = global_scope_->Var("fetch_vars");
return *(fetch_var->GetMutable<framework::FetchList>());
}
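
Run() feeds the input tensors, builds the op list only on the first call (guarded by is_build_), and reads results back out of the "fetch_vars" variable. A compressed sketch of that build-once caching flow; MiniCore and its members are illustrative stand-ins rather than framework types:

#include <iostream>
#include <vector>

class MiniCore {
 public:
  std::vector<int> Run(const std::vector<int>& feeds) {
    if (!is_build_) {
      program_ = feeds;  // stands in for build_op_func_list() + Convert()
      is_build_ = true;
      std::cout << "built once\n";
    }
    return program_;  // stands in for fetch_var->GetMutable<FetchList>()
  }

 private:
  bool is_build_ = false;
  std::vector<int> program_;
};

int main() {
  MiniCore core;
  core.Run({1, 2, 3});  // first call builds
  core.Run({1, 2, 3});  // later calls reuse the cached program
}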
void InterpreterCore::Convert() {
input_var2op_info_.resize(global_scope_->var_list.size());
vec_instruction_.reserve(vec_func_list_.size());
dependecy_count_.resize(vec_func_list_.size());
vec_meta_info_.resize(global_scope_->var_list.size());
for (size_t i = 0; i < vec_func_list_.size(); ++i) {
Instruction temp_inst;
auto* op_base = op_list_[i];
temp_inst.dev_ctx_ =
stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
temp_inst.kernel_func_.operator_base_ = op_base;
temp_inst.input_index_ = vec_func_list_[i].input_index;
temp_inst.output_index_ = vec_func_list_[i].output_index;
temp_inst.type_ = vec_func_list_[i].type_;
temp_inst.no_data_transform_index_ =
vec_func_list_[i].no_data_transform_index;
auto var_nums = global_scope_->VarSize();
input_var2op_info_.resize(var_nums);
vec_meta_info_.resize(var_nums);
OpInOutInfo info;
auto op_nums = vec_func_list_.size();
vec_instruction_.reserve(op_nums);
dependecy_count_.resize(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = vec_func_list_[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, op_func_node, *dev_ctx_);
auto& instr = vec_instruction_.back();
OpInOutInfo info;
std::vector<size_t> gc_check_input_list;
for (auto& item : vec_func_list_[i].input_index) {
for (auto& item : op_func_node.input_index) {
for (auto id : item.second) {
input_var2op_info_[id].push_back(i);
input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed
if (!info.IsBuilt()) {
info.Build(op_list_[i]);
info.Build(op_func_node.operator_base_);
}
if (global_scope_->vec_meta_info_[id].vardesc_) {
if (info.IsInArgBufferNeeded(
global_scope_->vec_meta_info_[id].vardesc_->Name())) {
auto* var_desc = global_scope_->VarDesc(id);
if (var_desc) {
if (info.IsInArgBufferNeeded(var_desc->Name())) {
gc_check_input_list.push_back(id);
}
} else {
......@@ -150,22 +145,20 @@ void InterpreterCore::Convert() {
auto last =
std::unique(gc_check_input_list.begin(), gc_check_input_list.end());
gc_check_input_list.erase(last, gc_check_input_list.end());
for (auto var_id : gc_check_input_list) {
vec_meta_info_[var_id].var_ref_count_++;
instr.AddGCCheckVar(var_id);
}
temp_inst.gc_check_var_list.swap(gc_check_input_list);
vec_instruction_.push_back(temp_inst);
}
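
The gc_check_input_list de-duplication above uses the erase-unique idiom. Worth noting when reading it: std::unique only collapses adjacent duplicates, so non-adjacent repeats survive, as the trailing 3 shows in this standalone example:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> ids = {3, 3, 5, 5, 5, 7, 3};
  auto last = std::unique(ids.begin(), ids.end());
  ids.erase(last, ids.end());  // same two-step idiom as in Convert()
  for (auto id : ids) std::cout << id << ' ';  // prints: 3 5 7 3
  std::cout << '\n';
}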
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// check output
for (auto& item : vec_instruction_[i].output_index_) {
for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) {
if (input_var2op_info_[id].size() == 0) {
if (input_var2op_info_.at(id).size() == 0) {
// output var is not used by any kernel
vec_instruction_[i].gc_check_var_list.push_back(id);
vec_instruction_[i].AddGCCheckVar(id);
vec_meta_info_[id].var_ref_count_++;
}
}
......@@ -174,7 +167,7 @@ void InterpreterCore::Convert() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].output_index_) {
for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) {
vec_temp =
interpretercore::merge_vector(vec_temp, input_var2op_info_[id]);
......@@ -205,7 +198,7 @@ void InterpreterCore::Convert() {
BuildSkipShareLoDInfo();
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
gc_event_.emplace_back(vec_instruction_[i].execution_ctx_.get()->GetPlace(),
gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
}
......@@ -215,15 +208,14 @@ void InterpreterCore::Convert() {
}
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
if (!global_scope_->vec_meta_info_[var_index].vardesc_) {
return input_var2op_info_[var_index].size() == 1;
if (!global_scope_->VarDesc(var_index)) {
return input_var2op_info_.at(var_index).size() == 1;
} else {
int is_input_cnt = 0;
for (auto inst_id : input_var2op_info_[var_index]) {
for (auto inst_id : input_var2op_info_.at(var_index)) {
OpInOutInfo info;
info.Build(vec_instruction_[inst_id].kernel_func_.operator_base_);
if (info.IsInArgBufferNeeded(
global_scope_->vec_meta_info_[var_index].vardesc_->Name())) {
info.Build(vec_instruction_.at(inst_id).OpBase());
if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) {
is_input_cnt++;
}
}
......@@ -233,35 +225,31 @@ bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
void InterpreterCore::BuildInplace() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
if (!vec_instruction_[i]
.kernel_func_.operator_base_->Info()
.infer_inplace_) {
auto& instr = vec_instruction_[i];
auto* op_base = instr.OpBase();
if (!op_base->Info().infer_inplace_) {
continue;
}
auto in_to_outs =
vec_instruction_[i].kernel_func_.operator_base_->Info().infer_inplace_(
platform::is_gpu_place(vec_instruction_[i].dev_ctx_->GetPlace()));
auto in_to_outs = op_base->Info().infer_inplace_(
platform::is_gpu_place(instr.DeviceContext().GetPlace()));
auto& inputs = instr.Inputs();
auto& outputs = instr.Outputs();
for (auto& pair : in_to_outs) {
auto iter = vec_instruction_[i].input_index_.find(pair.first);
if (iter != vec_instruction_[i].input_index_.end()) {
auto iter = inputs.find(pair.first);
if (iter != inputs.end()) {
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = vec_instruction_[i].output_index_.find(pair.second);
if (iterout != vec_instruction_[i].output_index_.end()) {
auto invar = global_scope_->var_list[iter->second[0]];
auto outvar = global_scope_->var_list[iterout->second[0]];
auto iterout = outputs.find(pair.second);
if (iterout != outputs.end()) {
auto invar = global_scope_->Var(iter->second[0]);
auto outvar = global_scope_->Var(iterout->second[0]);
if (invar && outvar) {
vec_instruction_[i].vec_inplace_in_to_out_.emplace_back(invar,
outvar);
VLOG(3) << "inplace "
<< vec_instruction_[i].kernel_func_.operator_base_->Type()
<< " "
<< global_scope_->vec_meta_info_[iter->second[0]]
.vardesc_->Name()
instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << op_base->Type() << " "
<< global_scope_->VarDesc(iter->second[0])->Name()
<< " -> "
<< global_scope_->vec_meta_info_[iterout->second[0]]
.vardesc_->Name()
<< global_scope_->VarDesc(iterout->second[0])->Name()
<< std::endl;
}
}
......@@ -274,48 +262,35 @@ void InterpreterCore::BuildInplace() {
void InterpreterCore::BuildAndCacheInstructionCtx(
Instruction* instr_node, const VariableScope& var_scope,
const platform::Place& place) {
auto op_base = instr_node->kernel_func_.operator_base_;
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->input_index_) {
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
input_vars.emplace_back(var_scope.var_list[id]);
input_vars.emplace_back(var_scope.Var(id));
}
ins_map.emplace(var_name_item.first, std::move(input_vars));
}
VariableValueMap outs_map;
for (auto& var_name_item : instr_node->output_index_) {
for (auto& var_name_item : instr_node->Outputs()) {
std::vector<Variable*> out_vars;
out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
out_vars.emplace_back(var_scope.var_list[id]);
out_vars.emplace_back(var_scope.Var(id));
}
outs_map.emplace(var_name_item.first, std::move(out_vars));
}
instr_node->runtime_ctx_.reset(new RuntimeContext({}, {}));
instr_node->runtime_ctx_->inputs.swap(ins_map);
instr_node->runtime_ctx_->outputs.swap(outs_map);
instr_node->infershape_ctx_.reset(new InterpretercoreInferShapeContext(
*op_base, *instr_node->runtime_ctx_.get()));
auto* dev_ctx = instr_node->dev_ctx_;
Scope scope;
instr_node->execution_ctx_.reset(new ExecutionContext(
*op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get()));
// set runtime_ctx and infershape_ctx_
instr_node->ResetContext(ins_map, outs_map);
}
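
BuildAndCacheInstructionCtx translates the id-based Inputs()/Outputs() maps into pointer-based VariableValueMaps before caching the contexts. A standalone sketch of that translation; Variable and the flat var_list are simplified stand-ins for the framework types:

#include <map>
#include <string>
#include <vector>

struct Variable { int payload = 0; };

using VariableIdMap = std::map<std::string, std::vector<int>>;
using VariableValueMap = std::map<std::string, std::vector<Variable*>>;

// Resolve each id through the scope's flat variable table.
VariableValueMap Materialize(const VariableIdMap& id_map,
                             const std::vector<Variable*>& var_list) {
  VariableValueMap value_map;
  for (auto& item : id_map) {
    std::vector<Variable*> vars;
    vars.reserve(item.second.size());
    for (auto id : item.second) vars.push_back(var_list.at(id));
    value_map.emplace(item.first, std::move(vars));
  }
  return value_map;
}

int main() {
  Variable a, b;
  std::vector<Variable*> var_list = {&a, &b};
  auto ins_map = Materialize({{"X", {0, 1}}}, var_list);
  return ins_map["X"].size() == 2 ? 0 : 1;
}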
void InterpreterCore::BuildSkipShareLoDInfo() {
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
bool can_skip_lod = true;
for (auto& input : vec_instruction_[i].runtime_ctx_.get()->inputs) {
for (auto& input : vec_instruction_[i].InnerRuntimeContext()->inputs) {
for (auto& var : input.second) {
if (var->IsType<LoDTensor>()) {
if (var->Get<LoDTensor>().lod().size() != 0) {
......@@ -328,23 +303,21 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
}
}
}
vec_instruction_[i].infershape_ctx_.get()->SetSkipLoD(can_skip_lod);
vec_instruction_[i].InnerInferShapeContext()->SetSkipLoD(can_skip_lod);
}
}
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(3) << "RunInstruction: "
<< instr_node.kernel_func_.operator_base_->Type();
VLOG(3) << "RunInstruction: " << instr_node.OpBase()->Type();
{
platform::RecordEvent infershape_event("InferShape");
static_cast<const framework::OperatorWithKernel*>(
instr_node.kernel_func_.operator_base_)
->InferShape(instr_node.infershape_ctx_.get());
static_cast<const framework::OperatorWithKernel*>(instr_node.OpBase())
->InferShape(instr_node.InnerInferShapeContext().get());
}
if (FLAGS_new_executor_use_inplace) {
for (auto& pair : instr_node.vec_inplace_in_to_out_) {
for (auto& pair : instr_node.InplaceInfo()) {
const auto& in = paddle::framework::details::GetTensorFromVar(pair.first);
auto* out =
paddle::framework::details::GetMutableTensorFromVar(pair.second);
......@@ -355,7 +328,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
{
platform::RecordEvent compute_event("Compute");
instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
}
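
Both phases of RunInstruction are wrapped in platform::RecordEvent scopes, which profile a region for exactly the lifetime of a local object. A standalone RAII sketch of that pattern; ScopedTimer is a hypothetical stand-in for platform::RecordEvent:

#include <chrono>
#include <iostream>
#include <string>

class ScopedTimer {
 public:
  explicit ScopedTimer(std::string name)
      : name_(std::move(name)), start_(std::chrono::steady_clock::now()) {}
  ~ScopedTimer() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::cout << name_ << " took " << us << " us\n";
  }

 private:
  std::string name_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  {
    ScopedTimer infershape_event("InferShape");
    // ... InferShape work happens inside this scope ...
  }
  {
    ScopedTimer compute_event("Compute");
    // ... kernel compute happens inside this scope ...
  }
}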
......@@ -369,7 +342,7 @@ void InterpreterCore::ExecuteInstructionList(
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
async_work_queue_.AddTask(vec_instr[i].type_,
async_work_queue_.AddTask(vec_instr.at(i).KernelType(),
[&, i] { RunInstructionAsync(i); });
}
}
......@@ -391,43 +364,43 @@ void InterpreterCore::ExecuteInstructionList(
void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
auto& next_instr = instr.next_instruction_;
auto& next_instr = instr.NextInstructions();
auto& atomic_deps = async_work_queue_.AtomicDeps();
auto IsReady = [&](size_t next_id) {
return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
};
if (instr.type_ == OpFuncType::kQueueAsync) {
if (instr.KernelType() == OpFuncType::kQueueAsync) {
// move all sync_ops into other threads
for (auto next_id : next_instr.synchronize_run_) {
for (auto next_id : next_instr.SyncRunIds()) {
if (IsReady(next_id)) {
async_work_queue_.AddTask(
vec_instruction_[next_id].type_,
vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); });
}
}
// keep all async_ops running in current thread
for (auto next_id : next_instr.direct_run_) {
for (auto next_id : next_instr.DirectRunIds()) {
if (IsReady(next_id)) {
reserved_next_ops->push(next_id);
}
}
for (auto next_id : next_instr.event_wait_run_) {
for (auto next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) {
reserved_next_ops->push(next_id);
}
}
} else {
// move async_ops into async_thread
for (auto next_id : next_instr.event_wait_run_) {
for (auto next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) {
async_work_queue_.AddTask(
vec_instruction_[next_id].type_,
vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); });
}
}
auto direct_run_ops = interpretercore::merge_vector(
next_instr.synchronize_run_, next_instr.direct_run_);
next_instr.SyncRunIds(), next_instr.DirectRunIds());
size_t first_op = 0;
for (auto next_id : direct_run_ops) {
if (IsReady(next_id)) {
......@@ -438,7 +411,7 @@ void InterpreterCore::RunNextInstructions(
}
// move rest ops into other threads
async_work_queue_.AddTask(
vec_instruction_[next_id].type_,
vec_instruction_[next_id].KernelType(),
[&, next_id] { RunInstructionAsync(next_id); });
}
}
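
The IsReady lambda above is the heart of the scheduler: each downstream op holds an atomic dependency counter, every finished predecessor decrements it, and only the caller that observes the count reaching zero enqueues the op, so each op is scheduled exactly once. A standalone sketch with three producers releasing one consumer:

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::atomic<int> deps{3};  // the consumer waits on three predecessors
  auto IsReady = [&] {
    // fetch_sub returns the previous value; 1 means we just hit zero.
    return deps.fetch_sub(1, std::memory_order_relaxed) == 1;
  };

  std::vector<std::thread> producers;
  for (int i = 0; i < 3; ++i) {
    producers.emplace_back([&] {
      if (IsReady()) std::cout << "last predecessor enqueues the consumer\n";
    });
  }
  for (auto& t : producers) t.join();  // prints the line exactly once
}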
......@@ -452,8 +425,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
while (!ready_ops.empty()) {
instr_id = ready_ops.front();
ready_ops.pop();
auto& instr_node = vec_instruction_[instr_id];
auto* op = instr_node.kernel_func_.operator_base_;
auto& instr_node = vec_instruction_.at(instr_id);
auto* op = instr_node.OpBase();
platform::RecordEvent instruction_event(op->Type());
event_manager_.WaitEvent(instr_node, place_);
......@@ -486,28 +459,27 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
op_run_number_.fetch_add(1, std::memory_order_relaxed);
// GC information
CheckGC(instr_id, instr_node.gc_check_var_list);
CheckGC(instr_node);
RunNextInstructions(instr_node, &ready_ops);
}
}
void InterpreterCore::CheckGC(size_t instr_id,
const std::vector<size_t>& gc_check_list) {
void InterpreterCore::CheckGC(const Instruction& instr) {
size_t instr_id = instr.Id();
auto& var_scope = *global_scope_;
auto& atomic_var_ref = async_work_queue_.AtomicVarRef();
for (auto var_id : gc_check_list) {
for (auto var_id : instr.GCCheckVars()) {
bool is_ready =
atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1;
if (is_ready && var_scope.vec_meta_info_[var_id].vardesc_ &&
!var_scope.vec_meta_info_[var_id].vardesc_->Persistable()) {
gc_.Add(var_scope.var_list[var_id], gc_event_[instr_id],
vec_instruction_[instr_id].dev_ctx_);
} else if (is_ready &&
var_scope.vec_meta_info_[var_id].vardesc_ == nullptr) {
gc_.Add(var_scope.var_list[var_id], gc_event_[instr_id],
vec_instruction_[instr_id].dev_ctx_);
if (is_ready && var_scope.VarDesc(var_id) &&
!var_scope.VarDesc(var_id)->Persistable()) {
gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id),
&instr.DeviceContext());
} else if (is_ready && var_scope.VarDesc(var_id) == nullptr) {
gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id),
&instr.DeviceContext());
}
}
}
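
CheckGC frees a variable only when its atomic reference count reaches zero and the variable is either descriptor-less (a temporary created at runtime) or not persistable. A minimal sketch of that decision; VarDesc here is a reduced stand-in for the framework class:

#include <atomic>
#include <iostream>

struct VarDesc {
  bool persistable = false;
  bool Persistable() const { return persistable; }
};

// Mirrors the branch structure of CheckGC: collect on last release unless
// a descriptor marks the variable persistable.
bool ShouldCollect(std::atomic<int>* ref, const VarDesc* desc) {
  bool is_ready = ref->fetch_sub(1, std::memory_order_relaxed) == 1;
  if (!is_ready) return false;
  return desc == nullptr || !desc->Persistable();
}

int main() {
  std::atomic<int> tmp_ref{1};
  std::cout << ShouldCollect(&tmp_ref, nullptr) << '\n';  // 1: temp, collect

  std::atomic<int> param_ref{1};
  VarDesc param{true};
  std::cout << ShouldCollect(&param_ref, &param) << '\n';  // 0: keep param
}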
......@@ -516,11 +488,11 @@ void InterpreterCore::DryRunPrepare(
const std::vector<framework::Tensor>& feed_tensors) {
auto FeedInput = [&] {
for (size_t i = 0; i < feed_names_.size(); ++i) {
auto it = global_scope_->name2id.find(feed_names_[i]);
assert(it != global_scope_->name2id.end());
auto* feed_var = global_scope_->FindVar(feed_names_[i]);
PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
"feed_var shall not be nullptr."));
auto feed_tensor = global_scope_->var_list[it->second]
->GetMutable<framework::LoDTensor>();
auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]);
}
};
......@@ -530,7 +502,7 @@ void InterpreterCore::DryRunPrepare(
global_scope_);
FeedInput();
paddle::framework::interpretercore::build_op_func_list(
place_, main_program_, &op_list_, &vec_func_list_, global_scope_);
place_, main_program_, &vec_func_list_, global_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert();
......
......@@ -67,7 +67,7 @@ class InterpreterCore {
void DryRunPrepare(const std::vector<framework::Tensor>& feed_tensors);
void CheckGC(size_t instr_id, const std::vector<size_t>& gc_check_list);
void CheckGC(const Instruction& instr);
void RunInstructionAsync(size_t instr_id);
void RunNextInstructions(const Instruction& instr_id,
......@@ -82,16 +82,15 @@ class InterpreterCore {
ProgramDesc main_program_;
VariableScope* global_scope_;
std::vector<Instruction> vec_instruction_;
std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
InstructionInfo instruction_info_;
std::vector<size_t> dependecy_count_;
std::vector<std::vector<size_t>> input_var2op_info_;
std::vector<VariableMetaInfo> ref_coun_info_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<paddle::framework::OperatorBase*> op_list_;
std::vector<std::string> feed_names_;
InterpreterProfiler dry_run_profiler_;
......
......@@ -19,6 +19,7 @@
namespace paddle {
namespace framework {
namespace interpretercore {
using VariableIdMap = std::map<std::string, std::vector<int>>;
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps(
const std::vector<size_t>& dependecy_count) {
......@@ -132,43 +133,29 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
for (auto& var_desc : global_block.AllVars()) {
auto var_name = var_desc->Name();
if (var_name == framework::kEmptyVarName) {
continue;
}
if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) {
var_scope->name2id[var->Name()] = var_scope->var_list.size();
auto v = new Variable();
InitializeVariable(v, var->GetType());
var_scope->var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var;
var_scope->vec_meta_info_.push_back(info);
if (nullptr == var_scope->FindVar(var_name)) {
var_scope->AddVar(var_desc->Name(), var_desc);
} else {
auto var_id = var_scope->name2id[var->Name()];
if (nullptr == var_scope->vec_meta_info_[var_id].vardesc_) {
VLOG(3) << "update var:" << var->Name() << " desc from nullptr into "
<< var;
var_scope->vec_meta_info_[var_id].vardesc_ = var;
auto* var_desc = var_scope->VarDesc(var_name);
if (nullptr == var_desc) {
VLOG(3) << "update var:" << var_name << " desc from nullptr into "
<< var_desc;
var_scope->VarMetaInfo(var_name).vardesc_ = var_desc;
}
}
}
}
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
std::vector<OperatorBase*>* op_list,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
std::vector<OperatorBase*> ops;
for (auto& op : global_block.AllOps()) {
VLOG(3) << "Build OpFuncNode from : " << op->Type();
for (auto& op : block.AllOps()) {
VLOG(3) << "CreateOp from : " << op->Type();
auto& info = OpInfoMap::Instance().Get(op->Type());
......@@ -179,64 +166,96 @@ void build_op_func_list(const platform::Place& place,
if (info.Checker() != nullptr) {
info.Checker()->Check(&op_attr_map);
}
// step 1. Prepare VariableValueMap of input/output
auto op_base =
info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
ops.push_back(op_base);
}
return ops;
}
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
const VariableNameMap& var_name_map, VariableScope* var_scope) {
VariableValueMap name2var;
VariableIdMap name2id;
for (auto& item : var_name_map) {
std::vector<Variable*> vars;
std::vector<int> ids;
vars.reserve(item.second.size());
for (auto& var_name : item.second) {
auto var_id = var_scope->VarId(var_name);
auto* in_var = var_scope->Var(var_id);
vars.push_back(in_var);
ids.push_back(var_id);
}
name2var[item.first] = std::move(vars);
name2id[item.first] = std::move(ids);
}
return std::make_tuple(name2var, name2id);
}
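
build_variable_map hands both maps back in a single std::tuple (the "make tuple" bullet in the commit message), and the call sites below unpack them with std::tie. A standalone sketch of the pattern; the variable name is hypothetical:

#include <iostream>
#include <string>
#include <tuple>

std::tuple<std::string, int> Lookup() {
  return std::make_tuple(std::string("fc_0.w_0"), 42);  // hypothetical var
}

int main() {
  std::string name;
  int id;
  std::tie(name, id) = Lookup();  // same unpacking style as the call sites
  std::cout << name << " -> " << id << '\n';
}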
void apply_device_guard(const OperatorBase* op_base,
const platform::Place& place,
OpKernelType* expected_kernel_key) {
bool need_change_place =
(op_base->HasAttr("op_device") &&
(op_base->Attr<std::string>("op_device").length() > 0));
if (need_change_place) {
auto& op_device = op_base->Attr<std::string>("op_device");
if (op_device == "cpu" || platform::is_cpu_place(place)) {
VLOG(3) << "Switch into CPUPlace by device_guard.";
expected_kernel_key->place_ = platform::CPUPlace();
} else if (op_device.find("gpu") != std::string::npos &&
platform::is_gpu_place(place)) {
VLOG(3) << "Switch into " << place << " by device_guard.";
expected_kernel_key->place_ = place;
} else {
PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", op_device));
}
}
}
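
apply_device_guard overrides the expected kernel place only when the op carries a non-empty op_device attribute. A sketch of the dispatch rule, with string tags standing in for platform::Place values and the is_cpu_place/is_gpu_place checks:

#include <iostream>
#include <stdexcept>
#include <string>

std::string ResolvePlace(const std::string& op_device,
                         const std::string& current_place) {
  if (op_device.empty()) return current_place;  // no device_guard set
  if (op_device == "cpu" || current_place == "cpu") return "cpu";
  if (op_device.find("gpu") != std::string::npos && current_place == "gpu")
    return current_place;
  throw std::runtime_error("Unsupported current place " + op_device);
}

int main() {
  std::cout << ResolvePlace("", "gpu") << '\n';       // gpu (unchanged)
  std::cout << ResolvePlace("cpu", "gpu") << '\n';    // cpu (guard wins)
  std::cout << ResolvePlace("gpu:0", "gpu") << '\n';  // gpu
}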
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
// Step 1: create all ops for global block.
auto ops = create_all_ops(global_block);
auto unused_var_map = get_unused_vars(global_block, ops);
size_t ops_index = 0;
for (auto& op : global_block.AllOps()) {
VLOG(3) << op->Type();
// << op->Type() << endl;
VLOG(3) << "Build OpFuncNode from : " << op->Type();
auto op_base = ops[ops_index++];
auto inputs_names = op->Inputs();
auto outputs_names = op->Outputs();
VariableValueMap ins_map;
std::map<std::string, std::vector<int>> ins_name2id;
for (auto& var_name_item : inputs_names) {
std::vector<Variable*> input_vars;
std::vector<int> vec_ids;
input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) {
auto it = var_scope->name2id.find(var_name);
assert(it != var_scope->name2id.end());
input_vars.push_back(var_scope->var_list[it->second]);
vec_ids.push_back(it->second);
}
ins_map[var_name_item.first] = input_vars;
ins_name2id[var_name_item.first] = vec_ids;
}
VariableIdMap ins_name2id;
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope);
VariableValueMap outs_map;
std::map<std::string, std::vector<int>> outs_name2id;
for (auto& var_name_item : outputs_names) {
std::vector<Variable*> output_vars;
std::vector<int> vec_ids;
output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) {
auto it = var_scope->name2id.find(var_name);
assert(it != var_scope->name2id.end());
output_vars.push_back(var_scope->var_list[it->second]);
vec_ids.push_back(it->second);
}
outs_map[var_name_item.first] = output_vars;
outs_name2id[var_name_item.first] = vec_ids;
}
VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) =
build_variable_map(outputs_names, var_scope);
// step 2: build OpFuncNode
OpFuncNode op_func_node;
op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id;
// step 2: construct RuntimeContext and analyze KernelType
// construct RuntimeContext and analyze KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT inherited
// from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op->Type());
......@@ -256,32 +275,18 @@ void build_op_func_list(const platform::Place& place,
->GetExpectedKernelType(
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
// consider device_guard context
bool need_change_place =
(op_base->HasAttr("op_device") &&
(op_base->Attr<std::string>("op_device").length() > 0));
if (need_change_place) {
auto& op_device = op_base->Attr<std::string>("op_device");
if (op_device == "cpu" || platform::is_cpu_place(place)) {
VLOG(3) << "Switch into CPUPlace by device_guard.";
expected_kernel_key.place_ = platform::CPUPlace();
} else if (op_device.find("gpu") != std::string::npos &&
platform::is_gpu_place(place)) {
VLOG(3) << "Switch into " << place << " by device_guard.";
expected_kernel_key.place_ = place;
} else {
PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", op_device));
}
}
// consider device_guard()
apply_device_guard(op_base, place, &expected_kernel_key);
VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
// step 3. Insert memcpy_op if needed
VariableValueMap& ins_map_temp = runtime_context.inputs;
std::unordered_set<int> no_data_transform_index;
for (auto& var_name_item : ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i];
auto& var_name = inputs_names[var_name_item.first].at(i);
auto tensor_in = static_cast<const Tensor*>(&(var->Get<LoDTensor>()));
if (!tensor_in->IsInitialized()) {
continue;
......@@ -293,32 +298,19 @@ void build_op_func_list(const platform::Place& place,
if (platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_key.place_)) {
// record the input var_id that needs no data transform
auto& var_name = inputs_names[var_name_item.first][i];
VLOG(3) << op->Type() << " found no data_transform var: " << var_name
<< " with id: " << var_scope->name2id[var_name];
no_data_transform_index.emplace(var_scope->name2id[var_name]);
<< " with id: " << var_name;
no_data_transform_index.emplace(var_scope->VarId(var_name));
} else {
if (op_base->Type() == "fetch_v2") {
op_base->SetAttr("deepcopy", false);
}
// need trans place
// 1. add var in scope
// 2. add copy op
std::string new_var_name =
"temp_1" + std::to_string(var_scope->var_list.size() + 1);
auto v = new Variable();
v->GetMutable<LoDTensor>();
var_scope->name2id[new_var_name] = var_scope->var_list.size();
var_scope->var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
var_scope->vec_meta_info_.push_back(info);
var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1);
var_scope->AddVar(new_var_name, nullptr);
VariableNameMap copy_in_map;
auto x_iter = inputs_names.find(var_name_item.first);
copy_in_map["X"] = {x_iter->second[i]};
copy_in_map["X"] = {var_name};
VariableNameMap copy_out_map;
copy_out_map["Out"] = {new_var_name};
AttributeMap attr_map;
......@@ -328,23 +320,23 @@ void build_op_func_list(const platform::Place& place,
: is_gpu_place(expected_kernel_key.place_) ? 1 : -1;
std::map<std::string, std::vector<int>> copy_ins_name2id;
copy_ins_name2id["X"] = ins_name2id[var_name_item.first];
copy_ins_name2id["X"] = ins_name2id.at(var_name_item.first);
std::map<std::string, std::vector<int>> copy_out_name2id;
copy_out_name2id["Out"] = {var_scope->name2id[new_var_name]};
copy_out_name2id["Out"] = {var_scope->VarId(new_var_name)};
op_func_node.input_index[var_name_item.first][i] =
var_scope->name2id[new_var_name];
var_scope->VarId(new_var_name);
VariableValueMap copy_ins_value_map;
copy_ins_value_map["X"] = {var};
VariableValueMap copy_outs_value_map;
copy_outs_value_map["Out"] = {v};
copy_outs_value_map["Out"] = {var_scope->Var(new_var_name)};
// memcpy_d2h, memcpy_h2d
auto memcpy_op_type = get_memcpy_type(kernel_type_for_var.place_,
expected_kernel_key.place_);
VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).",
memcpy_op_type, x_iter->second[i],
memcpy_op_type, var_name,
kernel_type_for_var.place_, new_var_name,
expected_kernel_key.place_);
auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type);
......@@ -385,16 +377,16 @@ void build_op_func_list(const platform::Place& place,
// as kQueueSync and execute them in thread pool.
copy_op_func_node.type_ = OpFuncType::kQueueSync;
copy_op_func_node.dev_ctx_ = dev_ctx;
op_list->push_back(copy_op);
copy_op_func_node.operator_base_ = copy_op;
vec_func_list->push_back(copy_op_func_node);
var_name_item.second[i] = v;
var_name_item.second[i] = var_scope->Var(new_var_name);
}
}
}
op_func_node.no_data_transform_index = std::move(no_data_transform_index);
// step 4. Run op kernel
op_list->push_back(op_base);
op_func_node.operator_base_ = op_base;
VLOG(3) << op_base->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
......@@ -436,9 +428,7 @@ void build_op_func_list(const platform::Place& place,
new std::deque<std::shared_ptr<memory::Allocation>>();
for (auto& var_name : delete_vars) {
auto it = var_scope->name2id.find(var_name);
assert(it != var_scope->name2id.end());
auto* var = var_scope->var_list[it->second];
auto* var = var_scope->FindVar(var_name);
if (var == nullptr) {
continue;
}
......
......@@ -101,7 +101,6 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
std::vector<OperatorBase*>* op_list,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope);
......
......@@ -19,6 +19,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
......@@ -463,7 +464,6 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
struct OpKernelFunc {
OpKernelComputeFunc compute_func_;
OperatorBase* operator_base_;
};
struct VariableMetaInfo {
......@@ -471,13 +471,108 @@ struct VariableMetaInfo {
paddle::framework::VarDesc* vardesc_;
};
struct VariableScope {
// TODO(Aurelius84): Consider inheriting ScopeBase to unify the interface.
class VariableScope {
public:
Variable* FindVar(const std::string& name) const {
if (!HasVar(name)) {
return nullptr;
}
auto var_id = VarId(name);
CheckExist(var_id);
return var_list[var_id];
}
bool HasVar(const std::string& name) const {
return name2id.find(name) != name2id.end();
}
int VarId(const std::string& name) const {
CheckExist(name);
return name2id.at(name);
}
Variable* Var(int id) const { return var_list.at(id); }
Variable* Var(const std::string& name) const {
return var_list.at(VarId(name));
}
size_t VarSize() const { return var_list.size(); }
void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT
name2id[name] = VarSize();
auto v = new Variable();
if (nullptr == var_desc) {
v->GetMutable<LoDTensor>();
} else {
InitializeVariable(v, var_desc->GetType());
}
var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var_desc;
vec_meta_info_.push_back(info);
}
void AddVar(const std::string& name, Variable& var) { // NOLINT
name2id[name] = VarSize();
var_list.push_back(&var);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
vec_meta_info_.push_back(info);
}
paddle::framework::VarDesc* VarDesc(const std::string& name) const {
return VarDesc(VarId(name));
}
paddle::framework::VarDesc* VarDesc(int id) const {
CheckExist(id);
return vec_meta_info_[id].vardesc_;
}
VariableMetaInfo& VarMetaInfo(const std::string& name) {
return vec_meta_info_[VarId(name)];
}
void CheckExist(int id) const {
PADDLE_ENFORCE_LT(id, var_list.size(),
platform::errors::PreconditionNotMet(
"Required var_id < %d, but received var_id = %d.",
var_list.size(), id));
}
void CheckExist(const std::string& name) const {
PADDLE_ENFORCE_EQ(
HasVar(name), true,
platform::errors::NotFound("%s not in VariableScope.", name));
}
private:
std::vector<Variable*> var_list;
std::map<std::string, int> name2id;
std::vector<VariableMetaInfo> vec_meta_info_;
};
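
The new VariableScope hides the old public name2id/var_list/vec_meta_info_ triple behind checked accessors. A compressed standalone analogue showing the same name-to-id-to-pointer indirection (ownership and metadata handling omitted for brevity):

#include <cassert>
#include <map>
#include <string>
#include <vector>

struct Variable { int payload = 0; };

class MiniScope {
 public:
  int AddVar(const std::string& name) {
    name2id_[name] = static_cast<int>(vars_.size());
    vars_.push_back(new Variable());  // cleanup omitted for brevity
    return name2id_[name];
  }
  bool HasVar(const std::string& name) const {
    return name2id_.count(name) > 0;
  }
  int VarId(const std::string& name) const { return name2id_.at(name); }
  Variable* Var(int id) const { return vars_.at(id); }
  Variable* FindVar(const std::string& name) const {
    return HasVar(name) ? Var(VarId(name)) : nullptr;
  }

 private:
  std::map<std::string, int> name2id_;
  std::vector<Variable*> vars_;
};

int main() {
  MiniScope scope;
  scope.AddVar("fetch_vars");
  assert(scope.FindVar("fetch_vars") != nullptr);
  assert(scope.FindVar("missing") == nullptr);
  return 0;
}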
struct NextInstruction {
class NextInstruction {
public:
void AddDirectRun(size_t id) { direct_run_.push_back(id); }
void ADDEventRun(size_t id) { event_wait_run_.push_back(id); }
void AddSyncRun(size_t id) { synchronize_run_.push_back(id); }
const std::vector<size_t>& DirectRunIds() const { return direct_run_; }
const std::vector<size_t>& EventRunIds() const { return event_wait_run_; }
const std::vector<size_t>& SyncRunIds() const { return synchronize_run_; }
private:
std::vector<size_t> direct_run_;
std::vector<size_t> event_wait_run_;
std::vector<size_t> synchronize_run_;
......@@ -503,49 +598,138 @@ enum class OpFuncType {
};
class RuntimeInferShapeContext;
struct Instruction {
OpKernelFunc kernel_func_;
struct OpFuncNode {
OperatorBase* operator_base_;
std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index;
std::unordered_set<int> no_data_transform_index;
OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned
OpFuncType type_;
};
class Instruction {
public:
Instruction(size_t id, const OpFuncNode& op_func_node,
const platform::DeviceContext& dev_ctx)
: id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
"Required id >= 0, but received id = %d", id));
}
size_t Id() const { return id_; }
const std::map<std::string, std::vector<int>>& Inputs() const {
return op_func_node_.input_index;
}
const std::map<std::string, std::vector<int>>& Outputs() const {
return op_func_node_.output_index;
}
const std::unordered_set<int>& NoDataTransformVars() const {
return op_func_node_.no_data_transform_index;
}
OpKernelComputeFunc KernelFunc() const { return op_func_node_.kernel_func_; }
OpFuncType KernelType() const { return op_func_node_.type_; }
OperatorBase* OpBase() const {
auto* op_base = op_func_node_.operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base shall not be nullptr."));
return op_base;
}
NextInstruction& NextInstructions() { return next_instruction_; }
const NextInstruction& NextInstructions() const { return next_instruction_; }
void AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
const std::vector<size_t>& GCCheckVars() const { return gc_check_var_list_; }
void ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
// NOTE: Because execution_ctx_ is constructed with a `scope&`, we fake an
// empty scope here to avoid a dangling local reference.
static framework::Scope scope_;
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
}
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const {
return runtime_ctx_;
}
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
const {
return infershape_ctx_;
}
std::shared_ptr<ExecutionContext> InnerExecutionContext() const {
return execution_ctx_;
}
const platform::DeviceContext& DeviceContext() const { return dev_ctx_; }
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const {
return vec_inplace_in_to_out_;
}
void AddInplace(Variable* in, Variable* out) {
vec_inplace_in_to_out_.emplace_back(in, out);
}
const std::vector<EventInter>& InputEvents() const { return intput_events_; }
const std::vector<EventInter>& OutputEvents() const { return output_events_; }
void AddInputEvent(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
intput_events_.emplace_back(var_id, event, waiter_type);
}
void AddOutputEvent(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
output_events_.emplace_back(var_id, event, waiter_type);
}
private:
size_t id_;
const OpFuncNode& op_func_node_; // not owned
const platform::DeviceContext& dev_ctx_; // not owned
std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
std::shared_ptr<ExecutionContext> execution_ctx_;
std::map<std::string, std::vector<int>> input_index_;
std::map<std::string, std::vector<int>> output_index_;
std::unordered_set<int> no_data_transform_index_;
std::vector<size_t> gc_check_var_list;
std::vector<size_t> gc_check_var_list_;
NextInstruction next_instruction_;
std::vector<EventInter> intput_events_;
std::vector<EventInter> output_events_;
platform::DeviceContext* dev_ctx_; // not owned
OpFuncType type_;
std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
};
struct OpFuncNode {
// int unsed;
std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index;
std::unordered_set<int> no_data_transform_index;
OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned
OpFuncType type_;
};
namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static bool IsMemcpyH2D(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyH2D;
return instr.OpBase()->Type() == kMemcpyH2D;
}
static bool IsMemcpyD2H(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyD2H;
return instr.OpBase()->Type() == kMemcpyD2H;
}
} // namespace interpretercore
......
......@@ -33,23 +33,16 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
auto name_list = outer_scope_->LocalVarNames();
for (auto name : name_list) {
auto v = outer_scope_->Var(name);
if (global_scope_.name2id.find(name) == global_scope_.name2id.end()) {
global_scope_.name2id[name] = global_scope_.var_list.size();
global_scope_.var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
global_scope_.vec_meta_info_.push_back(info);
if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v);
}
}
}
// run startup program
std::vector<paddle::framework::OpFuncNode> vec_func_list;
std::vector<paddle::framework::OperatorBase*> op_list;
paddle::framework::interpretercore::build_op_func_list(
place_, startup_prog, &op_list, &vec_func_list, &global_scope_);
place_, startup_prog, &vec_func_list, &global_scope_);
}
paddle::framework::FetchList StandaloneExecutor::Run(
......@@ -80,16 +73,8 @@ void StandaloneExecutor::BuildVariableOuterScope(
continue;
}
if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) {
var_scope->name2id[var->Name()] = var_scope->var_list.size();
auto v = outer_scope->Var(var->Name());
InitializeVariable(v, var->GetType());
var_scope->var_list.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var;
var_scope->vec_meta_info_.push_back(info);
if (!var_scope->HasVar(var->Name())) {
var_scope->AddVar(var->Name(), var);
}
}
}
......
......@@ -31,15 +31,15 @@ namespace framework {
std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
const Instruction& cur_instr, const Instruction& next_instr) {
std::unordered_set<size_t> unique_var_ids;
for (auto& item : cur_instr.output_index_) {
for (auto& item : cur_instr.Outputs()) {
unique_var_ids.insert(item.second.begin(), item.second.end());
}
std::vector<size_t> new_event_var_ids;
for (auto& item : next_instr.input_index_) {
for (auto& item : next_instr.Inputs()) {
for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0 &&
next_instr.no_data_transform_index_.count(var_id) == 0) {
next_instr.NoDataTransformVars().count(var_id) == 0) {
new_event_var_ids.push_back(var_id);
}
}
......@@ -57,8 +57,7 @@ void StreamAnalyzer::AssociateInputWithEvents(
var_id2event_.emplace(var_id, std::move(device_event));
}
// Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id),
waiter_type);
next_instr->AddInputEvent(var_id, var_id2event_.at(var_id), waiter_type);
}
}
......@@ -66,13 +65,13 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
std::vector<Instruction>* instructions,
size_t op_index) {
auto& cur_instr = instructions->at(op_index);
auto& next_instruction = cur_instr.next_instruction_;
auto& next_instruction = cur_instr.NextInstructions();
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
if (IsDirectRun(cur_instr, next_instr)) {
next_instruction.direct_run_.emplace_back(next_op_id);
next_instruction.AddDirectRun(next_op_id);
} else {
// Always insert events between different streams
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
......@@ -83,24 +82,24 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
if (waiter_type == platform::kCPU) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id);
next_instruction.AddSyncRun(next_op_id);
} else { // GPU -> GPU(different stream)
next_instruction.event_wait_run_.emplace_back(next_op_id);
next_instruction.ADDEventRun(next_op_id);
}
}
}
// Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
VLOG(3) << cur_instr.OpBase()->Type()
<< " event_var_ids.size: " << event_var_ids.size();
for (auto var_id : event_var_ids) {
cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id),
platform::kCUDA /*not used*/);
cur_instr.AddOutputEvent(var_id, var_id2event_.at(var_id),
platform::kCUDA /*not used*/);
}
}
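
Schedule() buckets every downstream op into one of three run lists: the same device context means direct run, a CPU waiter after a GPU op means synchronize run, and GPU-to-GPU across streams means event-wait run. A standalone sketch of that three-way decision:

#include <iostream>

enum class RunKind { kDirect, kSync, kEventWait };

// Same device context -> direct run; otherwise the waiter type decides.
RunKind Classify(bool same_device_context, bool waiter_is_cpu) {
  if (same_device_context) return RunKind::kDirect;
  return waiter_is_cpu ? RunKind::kSync : RunKind::kEventWait;
}

int main() {
  std::cout << (Classify(true, false) == RunKind::kDirect) << '\n';      // 1
  std::cout << (Classify(false, true) == RunKind::kSync) << '\n';        // 1
  std::cout << (Classify(false, false) == RunKind::kEventWait) << '\n';  // 1
}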
platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node, const OperatorBase& op_base) {
auto& op_type = op_base.Type();
const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpretercore::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
......@@ -122,13 +121,13 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
*/
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
return (cur_instr.dev_ctx_ == next_instr.dev_ctx_ ||
return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpretercore::IsMemcpyD2H(cur_instr) ||
interpretercore::IsMemcpyH2D(next_instr));
}
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
if (instr.type_ == OpFuncType::kQueueSync) {
if (instr.KernelType() == OpFuncType::kQueueSync) {
return platform::kCPU;
} else {
return platform::kCUDA;
......
......@@ -32,8 +32,7 @@ class StreamAnalyzer {
void Schedule(const std::vector<size_t>& downstream_ops,
std::vector<Instruction>* instructions, size_t op_index);
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
const OperatorBase& op_base);
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);
private:
std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
......