diff --git a/paddle/fluid/framework/new_executor/event_manager.cc b/paddle/fluid/framework/new_executor/event_manager.cc index bd83f49db1d0e3bc1b8d111c32048ebb9df5b930..87caff8c572f8c407127a344e77c2d8253ee7290 100644 --- a/paddle/fluid/framework/new_executor/event_manager.cc +++ b/paddle/fluid/framework/new_executor/event_manager.cc @@ -22,13 +22,13 @@ void EventManager::WaitEvent(const Instruction& instruction, // If InterpreterCore in on CPUPlace, do nothing. if (platform::is_cpu_place(place)) return; - VLOG(3) << "Deal StreamWaitEventOrSync for " - << instruction.kernel_func_.operator_base_->Type(); + VLOG(3) << "Deal StreamWaitEventOrSync for " << instruction.OpBase()->Type(); - for (auto& event_iter : instruction.intput_events_) { + for (auto& event_iter : instruction.InputEvents()) { VLOG(3) << "wait var_id: " << event_iter.var_id_ << " 's event with waiter_type: " << event_iter.waiter_type_; - event_iter.event_->Wait(event_iter.waiter_type_, instruction.dev_ctx_); + event_iter.event_->Wait(event_iter.waiter_type_, + &instruction.DeviceContext()); } } @@ -37,9 +37,9 @@ void EventManager::RecordEvent(const Instruction& instruction, // If InterpreterCore in on CPUPlace, do nothing. 
if (platform::is_cpu_place(place)) return; - for (auto& event : instruction.output_events_) { + for (auto& event : instruction.OutputEvents()) { VLOG(3) << "Record event in out_var_id: " << event.var_id_; - event.event_->Record(instruction.dev_ctx_); + event.event_->Record(&instruction.DeviceContext()); } } diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index d6ea840362e7ef938dd8c12f3f248839ad39ce9e..a8976cca7c79f791d9bd6602fc64fa389af5e8f4 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -79,11 +79,9 @@ paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_tensors) { auto FeedInput = [&] { for (size_t i = 0; i < feed_names_.size(); ++i) { - auto it = global_scope_->name2id.find(feed_names_[i]); - assert(it != global_scope_->name2id.end()); + auto* feed_var = global_scope_->Var(feed_names_[i]); - auto feed_tensor = global_scope_->var_list[it->second] - ->GetMutable(); + auto feed_tensor = feed_var->GetMutable(); feed_tensor->ShareDataWith(feed_tensors[i]); } }; @@ -93,7 +91,7 @@ paddle::framework::FetchList InterpreterCore::Run( global_scope_); FeedInput(); paddle::framework::interpretercore::build_op_func_list( - place_, main_program_, &op_list_, &vec_func_list_, global_scope_); + place_, main_program_, &vec_func_list_, global_scope_); is_build_ = true; // convert vec func_list to graph Convert(); @@ -103,42 +101,39 @@ paddle::framework::FetchList InterpreterCore::Run( } // return Fetch Tensors - return *(global_scope_->var_list[global_scope_->name2id["fetch_vars"]] - ->GetMutable()); + auto* fetch_var = global_scope_->Var("fetch_vars"); + return *(fetch_var->GetMutable()); } void InterpreterCore::Convert() { - input_var2op_info_.resize(global_scope_->var_list.size()); - - vec_instruction_.reserve(vec_func_list_.size()); - dependecy_count_.resize(vec_func_list_.size()); - 
vec_meta_info_.resize(global_scope_->var_list.size()); - for (size_t i = 0; i < vec_func_list_.size(); ++i) { - Instruction temp_inst; - auto* op_base = op_list_[i]; - temp_inst.dev_ctx_ = - stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base); - temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_; - temp_inst.kernel_func_.operator_base_ = op_base; - temp_inst.input_index_ = vec_func_list_[i].input_index; - temp_inst.output_index_ = vec_func_list_[i].output_index; - temp_inst.type_ = vec_func_list_[i].type_; - temp_inst.no_data_transform_index_ = - vec_func_list_[i].no_data_transform_index; + auto var_nums = global_scope_->VarSize(); + input_var2op_info_.resize(var_nums); + vec_meta_info_.resize(var_nums); - OpInOutInfo info; + auto op_nums = vec_func_list_.size(); + vec_instruction_.reserve(op_nums); + dependecy_count_.resize(op_nums); + + for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { + auto& op_func_node = vec_func_list_[op_idx]; + auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); + vec_instruction_.emplace_back(op_idx, op_func_node, *dev_ctx_); + auto& instr = vec_instruction_.back(); + + OpInOutInfo info; std::vector gc_check_input_list; - for (auto& item : vec_func_list_[i].input_index) { + + for (auto& item : op_func_node.input_index) { for (auto id : item.second) { - input_var2op_info_[id].push_back(i); + input_var2op_info_.at(id).push_back(op_idx); // var can be gc-ed if (!info.IsBuilt()) { - info.Build(op_list_[i]); + info.Build(op_func_node.operator_base_); } - if (global_scope_->vec_meta_info_[id].vardesc_) { - if (info.IsInArgBufferNeeded( - global_scope_->vec_meta_info_[id].vardesc_->Name())) { + auto* var_desc = global_scope_->VarDesc(id); + if (var_desc) { + if (info.IsInArgBufferNeeded(var_desc->Name())) { gc_check_input_list.push_back(id); } } else { @@ -150,22 +145,20 @@ void InterpreterCore::Convert() { auto last = std::unique(gc_check_input_list.begin(), gc_check_input_list.end()); 
gc_check_input_list.erase(last, gc_check_input_list.end()); + for (auto var_id : gc_check_input_list) { vec_meta_info_[var_id].var_ref_count_++; + instr.AddGCCheckVar(var_id); } - - temp_inst.gc_check_var_list.swap(gc_check_input_list); - - vec_instruction_.push_back(temp_inst); } for (size_t i = 0; i < vec_instruction_.size(); ++i) { // checkout ouput - for (auto& item : vec_instruction_[i].output_index_) { + for (auto& item : vec_instruction_[i].Outputs()) { for (auto id : item.second) { - if (input_var2op_info_[id].size() == 0) { + if (input_var2op_info_.at(id).size() == 0) { // output var not be used by any kernel - vec_instruction_[i].gc_check_var_list.push_back(id); + vec_instruction_[i].AddGCCheckVar(id); vec_meta_info_[id].var_ref_count_++; } } @@ -174,7 +167,7 @@ void InterpreterCore::Convert() { for (size_t i = 0; i < vec_instruction_.size(); ++i) { std::vector vec_temp; - for (auto& item : vec_instruction_[i].output_index_) { + for (auto& item : vec_instruction_[i].Outputs()) { for (auto id : item.second) { vec_temp = interpretercore::merge_vector(vec_temp, input_var2op_info_[id]); @@ -205,7 +198,7 @@ void InterpreterCore::Convert() { BuildSkipShareLoDInfo(); for (size_t i = 0; i < vec_instruction_.size(); ++i) { - gc_event_.emplace_back(vec_instruction_[i].execution_ctx_.get()->GetPlace(), + gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(), platform::GenerateDeviceEventFlag()); } @@ -215,15 +208,14 @@ void InterpreterCore::Convert() { } bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) { - if (!global_scope_->vec_meta_info_[var_index].vardesc_) { - return input_var2op_info_[var_index].size() == 1; + if (!global_scope_->VarDesc(var_index)) { + return input_var2op_info_.at(var_index).size() == 1; } else { int is_input_cnt = 0; - for (auto inst_id : input_var2op_info_[var_index]) { + for (auto inst_id : input_var2op_info_.at(var_index)) { OpInOutInfo info; - 
info.Build(vec_instruction_[inst_id].kernel_func_.operator_base_); - if (info.IsInArgBufferNeeded( - global_scope_->vec_meta_info_[var_index].vardesc_->Name())) { + info.Build(vec_instruction_.at(inst_id).OpBase()); + if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) { is_input_cnt++; } } @@ -233,35 +225,31 @@ bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) { void InterpreterCore::BuildInplace() { for (size_t i = 0; i < vec_instruction_.size(); ++i) { - if (!vec_instruction_[i] - .kernel_func_.operator_base_->Info() - .infer_inplace_) { + auto& instr = vec_instruction_[i]; + auto* op_base = instr.OpBase(); + if (!op_base->Info().infer_inplace_) { continue; } - auto in_to_outs = - vec_instruction_[i].kernel_func_.operator_base_->Info().infer_inplace_( - platform::is_gpu_place(vec_instruction_[i].dev_ctx_->GetPlace())); + auto in_to_outs = op_base->Info().infer_inplace_( + platform::is_gpu_place(instr.DeviceContext().GetPlace())); + auto& inputs = instr.Inputs(); + auto& outputs = instr.Outputs(); for (auto& pair : in_to_outs) { - auto iter = vec_instruction_[i].input_index_.find(pair.first); - if (iter != vec_instruction_[i].input_index_.end()) { + auto iter = inputs.find(pair.first); + if (iter != inputs.end()) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { - auto iterout = vec_instruction_[i].output_index_.find(pair.second); - if (iterout != vec_instruction_[i].output_index_.end()) { - auto invar = global_scope_->var_list[iter->second[0]]; - auto outvar = global_scope_->var_list[iterout->second[0]]; + auto iterout = outputs.find(pair.second); + if (iterout != outputs.end()) { + auto invar = global_scope_->Var(iter->second[0]); + auto outvar = global_scope_->Var(iterout->second[0]); if (invar && outvar) { - vec_instruction_[i].vec_inplace_in_to_out_.emplace_back(invar, - outvar); - VLOG(3) << "inplace " - << vec_instruction_[i].kernel_func_.operator_base_->Type() - << " " - << 
global_scope_->vec_meta_info_[iter->second[0]] - .vardesc_->Name() + instr.AddInplace(invar, outvar); + VLOG(3) << "inplace " << op_base->Type() << " " + << global_scope_->VarDesc(iter->second[0])->Name() << " -> " - << global_scope_->vec_meta_info_[iterout->second[0]] - .vardesc_->Name() + << global_scope_->VarDesc(iterout->second[0])->Name() << std::endl; } } @@ -274,48 +262,35 @@ void InterpreterCore::BuildInplace() { void InterpreterCore::BuildAndCacheInstructionCtx( Instruction* instr_node, const VariableScope& var_scope, const platform::Place& place) { - auto op_base = instr_node->kernel_func_.operator_base_; - VariableValueMap ins_map; - for (auto& var_name_item : instr_node->input_index_) { + for (auto& var_name_item : instr_node->Inputs()) { std::vector input_vars; input_vars.reserve(var_name_item.second.size()); for (auto& id : var_name_item.second) { - input_vars.emplace_back(var_scope.var_list[id]); + input_vars.emplace_back(var_scope.Var(id)); } ins_map.emplace(var_name_item.first, std::move(input_vars)); } VariableValueMap outs_map; - for (auto& var_name_item : instr_node->output_index_) { + for (auto& var_name_item : instr_node->Outputs()) { std::vector out_vars; out_vars.reserve(var_name_item.second.size()); for (auto& id : var_name_item.second) { - out_vars.emplace_back(var_scope.var_list[id]); + out_vars.emplace_back(var_scope.Var(id)); } outs_map.emplace(var_name_item.first, std::move(out_vars)); } - - instr_node->runtime_ctx_.reset(new RuntimeContext({}, {})); - instr_node->runtime_ctx_->inputs.swap(ins_map); - instr_node->runtime_ctx_->outputs.swap(outs_map); - - instr_node->infershape_ctx_.reset(new InterpretercoreInferShapeContext( - *op_base, *instr_node->runtime_ctx_.get())); - - auto* dev_ctx = instr_node->dev_ctx_; - Scope scope; - - instr_node->execution_ctx_.reset(new ExecutionContext( - *op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get())); + // set runtime_ctx and infershape_ctx_ + instr_node->ResetContext(ins_map, outs_map); } 
void InterpreterCore::BuildSkipShareLoDInfo() { for (size_t i = 0; i < vec_instruction_.size(); ++i) { bool can_skip_lod = true; - for (auto& input : vec_instruction_[i].runtime_ctx_.get()->inputs) { + for (auto& input : vec_instruction_[i].InnerRuntimeContext()->inputs) { for (auto& var : input.second) { if (var->IsType()) { if (var->Get().lod().size() != 0) { @@ -328,23 +303,21 @@ void InterpreterCore::BuildSkipShareLoDInfo() { } } } - vec_instruction_[i].infershape_ctx_.get()->SetSkipLoD(can_skip_lod); + vec_instruction_[i].InnerInferShapeContext()->SetSkipLoD(can_skip_lod); } } void InterpreterCore::RunInstruction(const Instruction& instr_node) { - VLOG(3) << "RunInstruction: " - << instr_node.kernel_func_.operator_base_->Type(); + VLOG(3) << "RunInstruction: " << instr_node.OpBase()->Type(); { platform::RecordEvent infershape_event("InferShape"); - static_cast( - instr_node.kernel_func_.operator_base_) - ->InferShape(instr_node.infershape_ctx_.get()); + static_cast(instr_node.OpBase()) + ->InferShape(instr_node.InnerInferShapeContext().get()); } if (FLAGS_new_executor_use_inplace) { - for (auto& pair : instr_node.vec_inplace_in_to_out_) { + for (auto& pair : instr_node.InplaceInfo()) { const auto& in = paddle::framework::details::GetTensorFromVar(pair.first); auto* out = paddle::framework::details::GetMutableTensorFromVar(pair.second); @@ -355,7 +328,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { } { platform::RecordEvent compute_event("Compute"); - instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get()); + instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); } } @@ -369,7 +342,7 @@ void InterpreterCore::ExecuteInstructionList( for (size_t i = 0; i < dependecy_count_.size(); ++i) { if (dependecy_count_[i] == 0) { - async_work_queue_.AddTask(vec_instr[i].type_, + async_work_queue_.AddTask(vec_instr.at(i).KernelType(), [&, i] { RunInstructionAsync(i); }); } } @@ -391,43 +364,43 @@ void 
InterpreterCore::ExecuteInstructionList( void InterpreterCore::RunNextInstructions( const Instruction& instr, std::queue* reserved_next_ops) { - auto& next_instr = instr.next_instruction_; + auto& next_instr = instr.NextInstructions(); auto& atomic_deps = async_work_queue_.AtomicDeps(); auto IsReady = [&](size_t next_id) { return atomic_deps[next_id]->fetch_sub(1, std::memory_order_relaxed) == 1; }; - if (instr.type_ == OpFuncType::kQueueAsync) { + if (instr.KernelType() == OpFuncType::kQueueAsync) { // move all sync_ops into other threads - for (auto next_id : next_instr.synchronize_run_) { + for (auto next_id : next_instr.SyncRunIds()) { if (IsReady(next_id)) { async_work_queue_.AddTask( - vec_instruction_[next_id].type_, + vec_instruction_[next_id].KernelType(), [&, next_id] { RunInstructionAsync(next_id); }); } } // keep all async_ops running in current thread - for (auto next_id : next_instr.direct_run_) { + for (auto next_id : next_instr.DirectRunIds()) { if (IsReady(next_id)) { reserved_next_ops->push(next_id); } } - for (auto next_id : next_instr.event_wait_run_) { + for (auto next_id : next_instr.EventRunIds()) { if (IsReady(next_id)) { reserved_next_ops->push(next_id); } } } else { // move async_ops into async_thread - for (auto next_id : next_instr.event_wait_run_) { + for (auto next_id : next_instr.EventRunIds()) { if (IsReady(next_id)) { async_work_queue_.AddTask( - vec_instruction_[next_id].type_, + vec_instruction_[next_id].KernelType(), [&, next_id] { RunInstructionAsync(next_id); }); } } auto direct_run_ops = interpretercore::merge_vector( - next_instr.synchronize_run_, next_instr.direct_run_); + next_instr.SyncRunIds(), next_instr.DirectRunIds()); size_t first_op = 0; for (auto next_id : direct_run_ops) { if (IsReady(next_id)) { @@ -438,7 +411,7 @@ void InterpreterCore::RunNextInstructions( } // move rest ops into other threads async_work_queue_.AddTask( - vec_instruction_[next_id].type_, + vec_instruction_[next_id].KernelType(), [&, next_id] { 
RunInstructionAsync(next_id); }); } } @@ -452,8 +425,8 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { while (!ready_ops.empty()) { instr_id = ready_ops.front(); ready_ops.pop(); - auto& instr_node = vec_instruction_[instr_id]; - auto* op = instr_node.kernel_func_.operator_base_; + auto& instr_node = vec_instruction_.at(instr_id); + auto* op = instr_node.OpBase(); platform::RecordEvent instruction_event(op->Type()); event_manager_.WaitEvent(instr_node, place_); @@ -486,28 +459,27 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) { op_run_number_.fetch_add(1, std::memory_order_relaxed); // GC infomation - CheckGC(instr_id, instr_node.gc_check_var_list); + CheckGC(instr_node); RunNextInstructions(instr_node, &ready_ops); } } -void InterpreterCore::CheckGC(size_t instr_id, - const std::vector& gc_check_list) { +void InterpreterCore::CheckGC(const Instruction& instr) { + size_t instr_id = instr.Id(); auto& var_scope = *global_scope_; auto& atomic_var_ref = async_work_queue_.AtomicVarRef(); - for (auto var_id : gc_check_list) { + for (auto var_id : instr.GCCheckVars()) { bool is_ready = atomic_var_ref[var_id]->fetch_sub(1, std::memory_order_relaxed) == 1; - if (is_ready && var_scope.vec_meta_info_[var_id].vardesc_ && - !var_scope.vec_meta_info_[var_id].vardesc_->Persistable()) { - gc_.Add(var_scope.var_list[var_id], gc_event_[instr_id], - vec_instruction_[instr_id].dev_ctx_); - } else if (is_ready && - var_scope.vec_meta_info_[var_id].vardesc_ == nullptr) { - gc_.Add(var_scope.var_list[var_id], gc_event_[instr_id], - vec_instruction_[instr_id].dev_ctx_); + if (is_ready && var_scope.VarDesc(var_id) && + !var_scope.VarDesc(var_id)->Persistable()) { + gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id), + &instr.DeviceContext()); + } else if (is_ready && var_scope.VarDesc(var_id) == nullptr) { + gc_.Add(var_scope.Var(var_id), gc_event_.at(instr_id), + &instr.DeviceContext()); } } } @@ -516,11 +488,11 @@ void InterpreterCore::DryRunPrepare( 
const std::vector& feed_tensors) { auto FeedInput = [&] { for (size_t i = 0; i < feed_names_.size(); ++i) { - auto it = global_scope_->name2id.find(feed_names_[i]); - assert(it != global_scope_->name2id.end()); + auto* feed_var = global_scope_->FindVar(feed_names_[i]); + PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound( + "feed_var shall not be nullptr.")); - auto feed_tensor = global_scope_->var_list[it->second] - ->GetMutable(); + auto feed_tensor = feed_var->GetMutable(); feed_tensor->ShareDataWith(feed_tensors[i]); } }; @@ -530,7 +502,7 @@ void InterpreterCore::DryRunPrepare( global_scope_); FeedInput(); paddle::framework::interpretercore::build_op_func_list( - place_, main_program_, &op_list_, &vec_func_list_, global_scope_); + place_, main_program_, &vec_func_list_, global_scope_); is_build_ = true; // convert vec func_list to graph Convert(); diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 9fba5f2cdce8b9f5db174794691d31790b3413de..811843db5292a74e08a4e7ea2942335b2019643b 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -67,7 +67,7 @@ class InterpreterCore { void DryRunPrepare(const std::vector& feed_tensors); - void CheckGC(size_t instr_id, const std::vector& gc_check_list); + void CheckGC(const Instruction& instr); void RunInstructionAsync(size_t instr_id); void RunNextInstructions(const Instruction& instr_id, @@ -82,16 +82,15 @@ class InterpreterCore { ProgramDesc main_program_; VariableScope* global_scope_; - std::vector vec_instruction_; + std::vector vec_func_list_; + std::vector vec_instruction_; // deconstruct before OpFuncNode + InstructionInfo instruction_info_; std::vector dependecy_count_; std::vector> input_var2op_info_; std::vector ref_coun_info_; std::vector vec_meta_info_; - std::vector vec_func_list_; - std::vector op_list_; - std::vector feed_names_; InterpreterProfiler 
dry_run_profiler_; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 7bb0429c6228b2be2d14f2923b009367e7b4d5c3..61d1462053f4a32c5d0f3600e65fbc459ccaf39d 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -19,6 +19,7 @@ namespace paddle { namespace framework { namespace interpretercore { +using VariableIdMap = std::map>; AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps( const std::vector& dependecy_count) { @@ -132,43 +133,29 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, VariableScope* var_scope) { auto& global_block = pdesc.Block(0); - for (auto& var : global_block.AllVars()) { - if (var->Name() == framework::kEmptyVarName) { + for (auto& var_desc : global_block.AllVars()) { + auto var_name = var_desc->Name(); + if (var_name == framework::kEmptyVarName) { continue; } - if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) { - var_scope->name2id[var->Name()] = var_scope->var_list.size(); - auto v = new Variable(); - InitializeVariable(v, var->GetType()); - var_scope->var_list.push_back(v); - - VariableMetaInfo info; - info.var_ref_count_ = 0; - info.vardesc_ = var; - var_scope->vec_meta_info_.push_back(info); + if (nullptr == var_scope->FindVar(var_name)) { + var_scope->AddVar(var_desc->Name(), var_desc); } else { - auto var_id = var_scope->name2id[var->Name()]; - if (nullptr == var_scope->vec_meta_info_[var_id].vardesc_) { - VLOG(3) << "update var:" << var->Name() << " desc from nullptr into " - << var; - var_scope->vec_meta_info_[var_id].vardesc_ = var; + auto* old_desc = var_scope->VarDesc(var_name); // don't shadow var_desc + if (nullptr == old_desc) { + VLOG(3) << "update var:" << var_name << " desc from nullptr into " + << var_desc; + var_scope->VarMetaInfo(var_name).vardesc_ = var_desc; } } } } -void build_op_func_list(const platform::Place& place, - const 
framework::ProgramDesc& pdesc, - std::vector* op_list, - std::vector* vec_func_list, - VariableScope* var_scope) { - auto& global_block = pdesc.Block(0); - auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); - +std::vector create_all_ops(const framework::BlockDesc& block) { std::vector ops; - for (auto& op : global_block.AllOps()) { - VLOG(3) << "Build OpFuncNode from : " << op->Type(); + for (auto& op : block.AllOps()) { + VLOG(3) << "CreateOp from : " << op->Type(); auto& info = OpInfoMap::Instance().Get(op->Type()); @@ -179,64 +166,96 @@ void build_op_func_list(const platform::Place& place, if (info.Checker() != nullptr) { info.Checker()->Check(&op_attr_map); } - // step 1. Prepare VariableValueMap of input/output auto op_base = info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map); ops.push_back(op_base); } + return ops; +} + +std::tuple build_variable_map( + const VariableNameMap& var_name_map, VariableScope* var_scope) { + VariableValueMap name2var; + VariableIdMap name2id; + for (auto& item : var_name_map) { + std::vector vars; + std::vector ids; + vars.reserve(item.second.size()); + + for (auto& var_name : item.second) { + auto var_id = var_scope->VarId(var_name); + auto* in_var = var_scope->Var(var_id); + vars.push_back(in_var); + ids.push_back(var_id); + } + name2var[item.first] = std::move(vars); + name2id[item.first] = std::move(ids); + } + return std::make_tuple(name2var, name2id); +} + +void apply_device_guard(const OperatorBase* op_base, + const platform::Place& place, + OpKernelType* expected_kernel_key) { + bool need_change_place = + (op_base->HasAttr("op_device") && + (op_base->Attr("op_device").length() > 0)); + if (need_change_place) { + auto& op_device = op_base->Attr("op_device"); + if (op_device == "cpu" || platform::is_cpu_place(place)) { + VLOG(3) << "Switch into CPUPlace by device_guard."; + expected_kernel_key->place_ = platform::CPUPlace(); + } else if (op_device.find("gpu") != std::string::npos && + 
platform::is_gpu_place(place)) { + VLOG(3) << "Switch into " << place << " by device_guard."; + expected_kernel_key->place_ = place; + } else { + PADDLE_THROW( + platform::errors::Fatal("Unsupported current place %s", op_device)); + } + } +} + +void build_op_func_list(const platform::Place& place, + const framework::ProgramDesc& pdesc, + std::vector* vec_func_list, + VariableScope* var_scope) { + auto& global_block = pdesc.Block(0); + auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); + // Step 1: create all ops for global block. + auto ops = create_all_ops(global_block); auto unused_var_map = get_unused_vars(global_block, ops); size_t ops_index = 0; for (auto& op : global_block.AllOps()) { - VLOG(3) << op->Type(); - // << op->Type() << endl; + VLOG(3) << "Build OpFuncNode from : " << op->Type(); auto op_base = ops[ops_index++]; - auto inputs_names = op->Inputs(); auto outputs_names = op->Outputs(); VariableValueMap ins_map; - std::map> ins_name2id; - for (auto& var_name_item : inputs_names) { - std::vector input_vars; - std::vector vec_ids; - input_vars.reserve(var_name_item.second.size()); - for (auto& var_name : var_name_item.second) { - auto it = var_scope->name2id.find(var_name); - assert(it != var_scope->name2id.end()); - input_vars.push_back(var_scope->var_list[it->second]); - vec_ids.push_back(it->second); - } - ins_map[var_name_item.first] = input_vars; - ins_name2id[var_name_item.first] = vec_ids; - } + VariableIdMap ins_name2id; + std::tie(ins_map, ins_name2id) = + build_variable_map(inputs_names, var_scope); VariableValueMap outs_map; - std::map> outs_name2id; - for (auto& var_name_item : outputs_names) { - std::vector output_vars; - std::vector vec_ids; - output_vars.reserve(var_name_item.second.size()); - for (auto& var_name : var_name_item.second) { - auto it = var_scope->name2id.find(var_name); - assert(it != var_scope->name2id.end()); - output_vars.push_back(var_scope->var_list[it->second]); - vec_ids.push_back(it->second); - } - 
outs_map[var_name_item.first] = output_vars; - outs_name2id[var_name_item.first] = vec_ids; - } + VariableIdMap outs_name2id; + std::tie(outs_map, outs_name2id) = + build_variable_map(outputs_names, var_scope); + // step 2: build OpFuncNode OpFuncNode op_func_node; op_func_node.input_index = ins_name2id; op_func_node.output_index = outs_name2id; - // step 2: construct RuntimeContext and analysis KernelType + // construct RuntimeContext and analysis KernelType RuntimeContext runtime_context({}, {}); runtime_context.inputs.swap(ins_map); runtime_context.outputs.swap(outs_map); InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context); + // TODO(Aurelius84): In case of control flow ops, they are NOT inheritted + // from OperatorWithKernel. static_cast(op_base)->InferShape( &infer_shape_ctx); auto kernels_iter = all_op_kernels.find(op->Type()); @@ -256,32 +275,18 @@ void build_op_func_list(const platform::Place& place, ->GetExpectedKernelType( ExecutionContext(*op_base, scope, *dev_ctx, runtime_context)); - // consider device_guard context - bool need_change_place = - (op_base->HasAttr("op_device") && - (op_base->Attr("op_device").length() > 0)); - if (need_change_place) { - auto& op_device = op_base->Attr("op_device"); - if (op_device == "cpu" || platform::is_cpu_place(place)) { - VLOG(3) << "Switch into CPUPlace by device_guard."; - expected_kernel_key.place_ = platform::CPUPlace(); - } else if (op_device.find("gpu") != std::string::npos && - platform::is_gpu_place(place)) { - VLOG(3) << "Switch into " << place << " by device_guard."; - expected_kernel_key.place_ = place; - } else { - PADDLE_THROW( - platform::errors::Fatal("Unsupported current place %s", op_device)); - } - } + // consider device_guard() + apply_device_guard(op_base, place, &expected_kernel_key); VLOG(3) << "expected_kernel_key : " << expected_kernel_key; // step 3. 
Insert memcpy_op if needed VariableValueMap& ins_map_temp = runtime_context.inputs; std::unordered_set no_data_transform_index; + for (auto& var_name_item : ins_map_temp) { for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto var = var_name_item.second[i]; + auto& var_name = inputs_names[var_name_item.first].at(i); auto tensor_in = static_cast(&(var->Get())); if (!tensor_in->IsInitialized()) { continue; @@ -293,32 +298,19 @@ void build_op_func_list(const platform::Place& place, if (platform::is_same_place(kernel_type_for_var.place_, expected_kernel_key.place_)) { // record no need data transformer input var_id - auto& var_name = inputs_names[var_name_item.first][i]; VLOG(3) << op->Type() << " found no data_transform var: " << var_name - << " with id: " << var_scope->name2id[var_name]; - no_data_transform_index.emplace(var_scope->name2id[var_name]); + << " with id: " << var_scope->VarId(var_name); + no_data_transform_index.emplace(var_scope->VarId(var_name)); } else { if (op_base->Type() == "fetch_v2") { op_base->SetAttr("deepcopy", false); } - // need trans place - // 1. add var in scope - // 2. add copy op std::string new_var_name = - "temp_1" + std::to_string(var_scope->var_list.size() + 1); - auto v = new Variable(); - v->GetMutable(); - var_scope->name2id[new_var_name] = var_scope->var_list.size(); - var_scope->var_list.push_back(v); - - VariableMetaInfo info; - info.var_ref_count_ = 0; - info.vardesc_ = nullptr; - var_scope->vec_meta_info_.push_back(info); + var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1); + var_scope->AddVar(new_var_name, nullptr); VariableNameMap copy_in_map; - auto x_iter = inputs_names.find(var_name_item.first); - copy_in_map["X"] = {x_iter->second[i]}; + copy_in_map["X"] = {var_name}; VariableNameMap copy_out_map; copy_out_map["Out"] = {new_var_name}; AttributeMap attr_map; @@ -328,23 +320,23 @@ void build_op_func_list(const platform::Place& place, : is_gpu_place(expected_kernel_key.place_) ? 
1 : -1; std::map> copy_ins_name2id; - copy_ins_name2id["X"] = ins_name2id[var_name_item.first]; + copy_ins_name2id["X"] = ins_name2id.at(var_name_item.first); std::map> copy_out_name2id; - copy_out_name2id["Out"] = {var_scope->name2id[new_var_name]}; + copy_out_name2id["Out"] = {var_scope->VarId(new_var_name)}; op_func_node.input_index[var_name_item.first][i] = - var_scope->name2id[new_var_name]; + var_scope->VarId(new_var_name); VariableValueMap copy_ins_value_map; copy_ins_value_map["X"] = {var}; VariableValueMap copy_outs_value_map; - copy_outs_value_map["Out"] = {v}; + copy_outs_value_map["Out"] = {var_scope->Var(new_var_name)}; // memcpy_d2h, memcpy_h2d auto memcpy_op_type = get_memcpy_type(kernel_type_for_var.place_, expected_kernel_key.place_); VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", - memcpy_op_type, x_iter->second[i], + memcpy_op_type, var_name, kernel_type_for_var.place_, new_var_name, expected_kernel_key.place_); auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type); @@ -385,16 +377,16 @@ void build_op_func_list(const platform::Place& place, // as kQueueSync and execute them in thread pool. copy_op_func_node.type_ = OpFuncType::kQueueSync; copy_op_func_node.dev_ctx_ = dev_ctx; - op_list->push_back(copy_op); + copy_op_func_node.operator_base_ = copy_op; vec_func_list->push_back(copy_op_func_node); - var_name_item.second[i] = v; + var_name_item.second[i] = var_scope->Var(new_var_name); } } } op_func_node.no_data_transform_index = std::move(no_data_transform_index); // step 4. 
Run op kernel - op_list->push_back(op_base); + op_func_node.operator_base_ = op_base; VLOG(3) << op_base->Type() << " : expected_kernel_key : " << expected_kernel_key; @@ -436,9 +428,7 @@ void build_op_func_list(const platform::Place& place, new std::deque>(); for (auto& var_name : delete_vars) { - auto it = var_scope->name2id.find(var_name); - assert(it != var_scope->name2id.end()); - auto* var = var_scope->var_list[it->second]; + auto* var = var_scope->FindVar(var_name); if (var == nullptr) { continue; } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index b1e1c02ab9513b79ae34a19b8f2d6907380716ce..976826800f4a47af97e0bc9a255871758bbe38b4 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -101,7 +101,6 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, void build_op_func_list(const platform::Place& place, const framework::ProgramDesc& pdesc, - std::vector* op_list, std::vector* vec_func_list, VariableScope* var_scope); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index e6cff353a659d77f1cf0b12692e4e4cdccc9fda2..5b922281e6f1585e91efd31f4587b9cf9592dd12 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -19,6 +19,7 @@ #include #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/event.h" @@ -463,7 +464,6 @@ class InterpretercoreInferShapeContext : public InferShapeContext { struct OpKernelFunc { OpKernelComputeFunc compute_func_; - OperatorBase* operator_base_; }; struct VariableMetaInfo { @@ -471,13 +471,108 @@ struct VariableMetaInfo { paddle::framework::VarDesc* vardesc_; }; -struct 
VariableScope { +// TODO(Aurelius84): Consider inherit ScopeBase to unify interface. +class VariableScope { + public: + Variable* FindVar(const std::string& name) const { + if (!HasVar(name)) { + return nullptr; + } + auto var_id = VarId(name); + CheckExist(var_id); + return var_list[var_id]; + } + + bool HasVar(const std::string& name) const { + return name2id.find(name) != name2id.end(); + } + + int VarId(const std::string& name) const { + CheckExist(name); + return name2id.at(name); + } + + Variable* Var(int id) const { return var_list.at(id); } + + Variable* Var(const std::string& name) const { + return var_list.at(VarId(name)); + } + + size_t VarSize() const { return var_list.size(); } + + void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT + name2id[name] = VarSize(); + auto v = new Variable(); + if (nullptr == var_desc) { + v->GetMutable(); + } else { + InitializeVariable(v, var_desc->GetType()); + } + var_list.push_back(v); + + VariableMetaInfo info; + info.var_ref_count_ = 0; + info.vardesc_ = var_desc; + vec_meta_info_.push_back(info); + } + + void AddVar(const std::string& name, Variable& var) { // NOLINT + name2id[name] = VarSize(); + var_list.push_back(&var); + + VariableMetaInfo info; + info.var_ref_count_ = 0; + info.vardesc_ = nullptr; + vec_meta_info_.push_back(info); + } + + paddle::framework::VarDesc* VarDesc(const std::string& name) const { + return VarDesc(VarId(name)); + } + + paddle::framework::VarDesc* VarDesc(int id) const { + CheckExist(id); + return vec_meta_info_[id].vardesc_; + } + + VariableMetaInfo& VarMetaInfo(const std::string& name) { + return vec_meta_info_[VarId(name)]; + } + + void CheckExist(int id) const { + PADDLE_ENFORCE_LT(id, var_list.size(), + platform::errors::PreconditionNotMet( + "Required var_id < %d, but received var_id = %d.", + var_list.size(), id)); + } + + void CheckExist(const std::string& name) const { + PADDLE_ENFORCE_EQ( + HasVar(name), true, + platform::errors::NotFound("%s not in 
VariableScope.", name)); + } + + private: std::vector var_list; std::map name2id; std::vector vec_meta_info_; }; -struct NextInstruction { +class NextInstruction { + public: + void AddDirectRun(size_t id) { direct_run_.push_back(id); } + + void ADDEventRun(size_t id) { event_wait_run_.push_back(id); } + + void AddSyncRun(size_t id) { synchronize_run_.push_back(id); } + + const std::vector& DirectRunIds() const { return direct_run_; } + + const std::vector& EventRunIds() const { return event_wait_run_; } + + const std::vector& SyncRunIds() const { return synchronize_run_; } + + private: std::vector direct_run_; std::vector event_wait_run_; std::vector synchronize_run_; @@ -503,49 +598,138 @@ enum class OpFuncType { }; class RuntimeInferShapeContext; -struct Instruction { - OpKernelFunc kernel_func_; +struct OpFuncNode { + OperatorBase* operator_base_; + std::map> input_index; + std::map> output_index; + std::unordered_set no_data_transform_index; + + OpKernelComputeFunc kernel_func_; + platform::DeviceContext* dev_ctx_; // not owned + OpFuncType type_; +}; + +class Instruction { + public: + Instruction(size_t id, const OpFuncNode& op_func_node, + const platform::DeviceContext& dev_ctx) + : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) { + PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet( + "Required id >= 0, but received id = %d", id)); + } + + size_t Id() const { return id_; } + + const std::map>& Inputs() const { + return op_func_node_.input_index; + } + + const std::map>& Outputs() const { + return op_func_node_.output_index; + } + + const std::unordered_set& NoDataTransformVars() const { + return op_func_node_.no_data_transform_index; + } + + OpKernelComputeFunc KernelFunc() const { return op_func_node_.kernel_func_; } + + OpFuncType KernelType() const { return op_func_node_.type_; } + + OperatorBase* OpBase() const { + auto* op_base = op_func_node_.operator_base_; + PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet( + 
"op_base shall not be nullptr.")); + return op_base; + } + + NextInstruction& NextInstructions() { return next_instruction_; } + + const NextInstruction& NextInstructions() const { return next_instruction_; } + + void AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); } + + const std::vector& GCCheckVars() const { return gc_check_var_list_; } + + void ResetContext(const VariableValueMap& in_vars, + const VariableValueMap& out_vars) { + runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars)); + infershape_ctx_.reset( + new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get())); + // NOTE: Because execution_ctx_ is constructed by `scope&`, so we fake an + // empty here to avoid illegal local reference. + static framework::Scope scope_; + execution_ctx_.reset( + new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get())); + } + + std::shared_ptr InnerRuntimeContext() const { + return runtime_ctx_; + } + + std::shared_ptr InnerInferShapeContext() + const { + return infershape_ctx_; + } + + std::shared_ptr InnerExecutionContext() const { + return execution_ctx_; + } + + const platform::DeviceContext& DeviceContext() const { return dev_ctx_; } + + const std::vector>& InplaceInfo() const { + return vec_inplace_in_to_out_; + } + + void AddInplace(Variable* in, Variable* out) { + vec_inplace_in_to_out_.emplace_back(in, out); + } + + const std::vector& InputEvents() const { return intput_events_; } + + const std::vector& OutputEvents() const { return output_events_; } + + void AddInputEvent(size_t var_id, + std::shared_ptr event, + platform::DeviceType waiter_type) { + intput_events_.emplace_back(var_id, event, waiter_type); + } + + void AddOutputEvent(size_t var_id, + std::shared_ptr event, + platform::DeviceType waiter_type) { + output_events_.emplace_back(var_id, event, waiter_type); + } + + private: + size_t id_; + const OpFuncNode& op_func_node_; // not owned + const platform::DeviceContext& dev_ctx_; // not owned + std::shared_ptr 
runtime_ctx_; std::shared_ptr infershape_ctx_; std::shared_ptr execution_ctx_; - std::map> input_index_; - std::map> output_index_; - - std::unordered_set no_data_transform_index_; - std::vector gc_check_var_list; + std::vector gc_check_var_list_; NextInstruction next_instruction_; std::vector intput_events_; std::vector output_events_; - platform::DeviceContext* dev_ctx_; // not owned - OpFuncType type_; - std::vector> vec_inplace_in_to_out_; }; -struct OpFuncNode { - // int unsed; - std::map> input_index; - std::map> output_index; - std::unordered_set no_data_transform_index; - - OpKernelComputeFunc kernel_func_; - platform::DeviceContext* dev_ctx_; // not owned - OpFuncType type_; -}; - namespace interpretercore { static constexpr char kMemcpyH2D[] = "memcpy_h2d"; static constexpr char kMemcpyD2H[] = "memcpy_d2h"; static bool IsMemcpyH2D(const Instruction& instr) { - return instr.kernel_func_.operator_base_->Type() == kMemcpyH2D; + return instr.OpBase()->Type() == kMemcpyH2D; } static bool IsMemcpyD2H(const Instruction& instr) { - return instr.kernel_func_.operator_base_->Type() == kMemcpyD2H; + return instr.OpBase()->Type() == kMemcpyD2H; } } // namespace interpretercore diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index a7579d54616af462ba09e1d573dc7e25eb62d176..898c2d3d75e7e31e52088808d8ab51cf6da9eb00 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -33,23 +33,16 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, auto name_list = outer_scope_->LocalVarNames(); for (auto name : name_list) { auto v = outer_scope_->Var(name); - if (global_scope_.name2id.find(name) == global_scope_.name2id.end()) { - global_scope_.name2id[name] = global_scope_.var_list.size(); - global_scope_.var_list.push_back(v); - - VariableMetaInfo info; - info.var_ref_count_ = 0; - info.vardesc_ 
= nullptr; - global_scope_.vec_meta_info_.push_back(info); + if (!global_scope_.HasVar(name)) { + global_scope_.AddVar(name, *v); } } } // run startup program std::vector vec_func_list; - std::vector op_list; paddle::framework::interpretercore::build_op_func_list( - place_, startup_prog, &op_list, &vec_func_list, &global_scope_); + place_, startup_prog, &vec_func_list, &global_scope_); } paddle::framework::FetchList StandaloneExecutor::Run( @@ -80,16 +73,8 @@ void StandaloneExecutor::BuildVariableOuterScope( continue; } - if (var_scope->name2id.find(var->Name()) == var_scope->name2id.end()) { - var_scope->name2id[var->Name()] = var_scope->var_list.size(); - auto v = outer_scope->Var(var->Name()); - InitializeVariable(v, var->GetType()); - var_scope->var_list.push_back(v); - - VariableMetaInfo info; - info.var_ref_count_ = 0; - info.vardesc_ = var; - var_scope->vec_meta_info_.push_back(info); + if (!var_scope->HasVar(var->Name())) { + var_scope->AddVar(var->Name(), var); } } } diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index ffc2da499e1f7b927b5629b9194746d4fe671199..d30f27169cc43d16e237eaf42637e2ad82a638ac 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -31,15 +31,15 @@ namespace framework { std::vector StreamAnalyzer::ParseEventVarIds( const Instruction& cur_instr, const Instruction& next_instr) { std::unordered_set unique_var_ids; - for (auto& item : cur_instr.output_index_) { + for (auto& item : cur_instr.Outputs()) { unique_var_ids.insert(item.second.begin(), item.second.end()); } std::vector new_event_var_ids; - for (auto& item : next_instr.input_index_) { + for (auto& item : next_instr.Inputs()) { for (auto var_id : item.second) { if (unique_var_ids.count(var_id) > 0 && - next_instr.no_data_transform_index_.count(var_id) == 0) { + next_instr.NoDataTransformVars().count(var_id) == 0) { 
new_event_var_ids.push_back(var_id); } } @@ -57,8 +57,7 @@ void StreamAnalyzer::AssociateInputWithEvents( var_id2event_.emplace(var_id, std::move(device_event)); } // Add events for next_instr.inputs - next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id), - waiter_type); + next_instr->AddInputEvent(var_id, var_id2event_.at(var_id), waiter_type); } } @@ -66,13 +65,13 @@ void StreamAnalyzer::Schedule(const std::vector& downstream_ops, std::vector* instructions, size_t op_index) { auto& cur_instr = instructions->at(op_index); - auto& next_instruction = cur_instr.next_instruction_; + auto& next_instruction = cur_instr.NextInstructions(); std::vector event_var_ids; for (auto next_op_id : downstream_ops) { auto& next_instr = instructions->at(next_op_id); if (IsDirectRun(cur_instr, next_instr)) { - next_instruction.direct_run_.emplace_back(next_op_id); + next_instruction.AddDirectRun(next_op_id); } else { // Always insert events between different stream auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); @@ -83,24 +82,24 @@ void StreamAnalyzer::Schedule(const std::vector& downstream_ops, AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type); if (waiter_type == platform::kCPU) { // GPU -> CPU - next_instruction.synchronize_run_.emplace_back(next_op_id); + next_instruction.AddSyncRun(next_op_id); } else { // GPU -> GPU(different stream) - next_instruction.event_wait_run_.emplace_back(next_op_id); + next_instruction.ADDEventRun(next_op_id); } } } // Create events for these cross-stream vars - VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() + VLOG(3) << cur_instr.OpBase()->Type() << " event_var_ids.size: " << event_var_ids.size(); for (auto var_id : event_var_ids) { - cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id), - platform::kCUDA /*not used*/); + cur_instr.AddOutputEvent(var_id, var_id2event_.at(var_id), + platform::kCUDA /*not used*/); } } platform::DeviceContext* 
StreamAnalyzer::ParseDeviceContext( - const OpFuncNode& op_func_node, const OperatorBase& op_base) { - auto& op_type = op_base.Type(); + const OpFuncNode& op_func_node) { + auto& op_type = op_func_node.operator_base_->Type(); auto* dev_ctx = op_func_node.dev_ctx_; if (op_type == interpretercore::kMemcpyH2D) { VLOG(3) << "Get dev_ctx from d2h_context_pool_"; @@ -122,13 +121,13 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( */ bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, const Instruction& next_instr) { - return (cur_instr.dev_ctx_ == next_instr.dev_ctx_ || + return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() || interpretercore::IsMemcpyD2H(cur_instr) || interpretercore::IsMemcpyH2D(next_instr)); } platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { - if (instr.type_ == OpFuncType::kQueueSync) { + if (instr.KernelType() == OpFuncType::kQueueSync) { return platform::kCPU; } else { return platform::kCUDA; diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.h b/paddle/fluid/framework/new_executor/stream_analyzer.h index dc2af389e36b0f86e9b096f6309483d3ee0d2b2a..df74c9b933712ff557a91ef4ccc42227e9bbbf64 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/stream_analyzer.h @@ -32,8 +32,7 @@ class StreamAnalyzer { void Schedule(const std::vector& downstream_ops, std::vector* instructions, size_t op_index); - platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node, - const OperatorBase& op_base); + platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node); private: std::vector ParseEventVarIds(const Instruction& cur_instr,