From ab6a3dadecbf823dc74bed602c5761f84acd3673 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sat, 26 Nov 2022 13:00:49 +0800 Subject: [PATCH] fix jit input var not ready error (#48351) * hot fix * fix compile * merge develop * follow comments --- .../eager/to_static/run_program_op_node.h | 34 ++++++++++++ .../interpreter/execution_config.h | 1 + .../interpreter/stream_analyzer.h | 4 +- .../framework/new_executor/interpretercore.cc | 53 +++++++++++++++++++ .../framework/new_executor/interpretercore.h | 6 +++ .../new_executor/new_executor_defs.h | 4 ++ paddle/fluid/framework/operator.cc | 5 +- 7 files changed, 104 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index db3db215e2b..46635b16aed 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -316,6 +316,26 @@ inline void RunProgramAPI( auto output_names = details::GetTensorsName(out); auto dout_names = details::GetTensorsName(dout); + if (VLOG_IS_ON(6)) { + std::stringstream s; + s << "input_names: "; + for (auto name : input_names) { + s << name << " "; + } + s << std::endl; + s << "output_names: "; + for (auto name : output_names) { + s << name << " "; + } + s << std::endl; + s << "dout_names: "; + for (auto name : dout_names) { + s << name << " "; + } + s << std::endl; + VLOG(6) << s.str(); + } + auto *forward_global_block = PADDLE_GET_CONST( paddle::framework::BlockDesc *, attrs.at("forward_global_block")); auto *backward_global_block = PADDLE_GET_CONST( @@ -354,6 +374,20 @@ inline void RunProgramAPI( skip_eager_delete_vars.insert(dout_names.begin(), dout_names.end()); // update interpretercore skip_gc_var interpreter_core->SetSkipGcVars(skip_eager_delete_vars); + + std::set input_vars; + input_vars.insert(input_names.begin(), input_names.end()); + interpreter_core->SetJitInputVars(input_vars); + + if (VLOG_IS_ON(6)) { + std::stringstream s; + s << "skip_eager_delete_vars: "; + for (auto name : skip_eager_delete_vars) { + s << name << " "; + } + VLOG(6) << s.str(); + } + interpretercore_info_cache.UpdateSkipEagerDeleteVars( program_id, false, skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); diff --git a/paddle/fluid/framework/new_executor/interpreter/execution_config.h b/paddle/fluid/framework/new_executor/interpreter/execution_config.h index 3721766700a..6934723146e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/execution_config.h +++ b/paddle/fluid/framework/new_executor/interpreter/execution_config.h @@ -32,6 +32,7 @@ struct ExecutionConfig { size_t deivce_num_threads; std::set skip_gc_vars; + std::set jit_input_vars; ExecutionConfig(const phi::Place& place, size_t op_num); void Log(int log_level); diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index 1641867d670..c7634190733 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -43,6 +43,8 @@ class StreamAnalyzer { platform::DeviceContext* ParseDeviceContext( const OpFuncNode& op_func_node) const; + platform::DeviceType GetWaiterType(const Instruction& instr) const; + private: bool HasDataDependency(const Instruction& cur_instr, const Instruction& next_instr) const; @@ -70,8 +72,6 @@ class StreamAnalyzer { std::map>>* event_info_map) const; - platform::DeviceType GetWaiterType(const Instruction& instr) const; - DownstreamRunType AnalyseRunTypeForTwoInstructions( const Instruction& cur_instr, const Instruction& next_instr) const; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index c31a5e612ff..81cbea8efaa 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -311,6 +311,22 @@ void InterpreterCore::SetSkipGcVars(const std::set& skip_gc_vars) { execution_config_.skip_gc_vars = skip_gc_vars; } +void InterpreterCore::SetJitInputVars( + const std::set& jit_input_vars) { + PADDLE_ENFORCE_EQ( + execution_config_.jit_input_vars.empty(), + true, + platform::errors::PreconditionNotMet( + "execution_config_.jit_input_vars can only be initialized once, now " + "execution_config_.jit_input_vars is " + "not empty, do not call SetJitInputVars method repeatedly.")); + execution_config_.jit_input_vars = jit_input_vars; +} + +const std::set& InterpreterCore::JitInputVars() const { + return execution_config_.jit_input_vars; +} + const VariableScope* InterpreterCore::GetVariableScope() const { return &var_scope_; } @@ -563,6 +579,30 @@ void InterpreterCore::Convert( stream_analyzer_.ConstructEvents(dependency_builder_, &vec_instruction_); + // add event for the input var of jit program, since there are async copied + // from gpu_pinned place to gpu place on compute stream. + for (size_t i = 0; i < dependecy_count_.size(); ++i) { + if (dependecy_count_[i] == 0) { + auto& inst = vec_instruction_[i]; + if (inst.OpBase()->Type() == interpreter::kMemcpyD2H && + platform::is_gpu_place(place_)) { + for (auto& item : inst.Inputs()) { + for (auto var_id : item.second) { + auto name = var_scope_.GetNameById(var_id); + if (JitInputVars().count(name)) { + auto device_event = std::make_shared( + place_, platform::GenerateDeviceEventFlag()); + VLOG(4) << "Add input event for input: " << name << " of " + << inst.OpBase()->Type(); + inst.AddEventToWait( + i, device_event, stream_analyzer_.GetWaiterType(inst)); + } + } + } + } + } + } + // calculate last_live_ops_ for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { Instruction& instr = vec_instruction_[op_idx]; @@ -838,6 +878,19 @@ void InterpreterCore::ExecuteInstructionList( for (size_t i = 0; i < dependecy_count_.size(); ++i) { if (dependecy_count_[i] == 0) { + // NOTE(zhiqiu): hot fix for jit input var + if (vec_instr.at(i).OpBase()->Type() == interpreter::kMemcpyD2H) { + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + auto* default_dev_ctx = pool.Get(place_); + for (auto& event : vec_instr.at(i).EventsToWait()) { + platform::RecordEvent record( + "RecordStreamEvent", platform::TracerEventType::UserDefined, 10); + VLOG(3) << "Record event on default stream in jit_input_var at op: " + << vec_instr.at(i).OpBase()->Type(); + event.event_->Record(default_dev_ctx); + } + } async_work_queue_->AddTask(vec_instr.at(i).KernelType(), [this, i] { RunInstructionAsync(i); }); } diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index eb3c6022c09..ceb5ad2c727 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -67,6 +67,10 @@ class InterpreterCore { void SetSkipGcVars(const std::set& skip_gc_vars); + const std::set& JitInputVars() const; + + void SetJitInputVars(const std::set& jit_input_vars); + const VariableScope* GetVariableScope() const; void reset_scope(Scope* new_scope); @@ -154,6 +158,8 @@ class InterpreterCore { std::vector> deps_; std::vector> refs_; + + // for jit }; std::shared_ptr CreateInterpreterCore( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index a62cff85e59..f42270f34a2 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -310,6 +310,10 @@ class Instruction { events_to_wait_.emplace_back(instr_id, event, waiter_type); } + const std::vector& EventsToWait() const { + return events_to_wait_; + } + void AddNextInstrInDifferentThread(size_t id) { next_instrs_in_different_thread.push_back(id); } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9ac1b871bf7..d5af849a53c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -384,10 +384,13 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { std::string dtype = is_no_need_buffer_var ? "unknown_dtype" : GetDtype(*scope, var_name); + std::string place = is_no_need_buffer_var + ? "unknown_place" + : GetPlace(*scope, var_name); ss << ":" << dtype; ss << "[" << GetDimsDebug(*scope, var_name, true) << "]"; ss << "(" << GetLoDDebug(*scope, var_name) << ")"; - ss << "(" << GetPlace(*scope, var_name) << ")"; + ss << "(" << place << ")"; } } if (i != input.second.size() - 1) { -- GitLab