未验证 提交 ab6a3dad 编写于 作者: L Leo Chen 提交者: GitHub

fix jit input var not ready error (#48351)

* hot fix

* fix compile

* merge develop

* follow comments
上级 970db874
......@@ -316,6 +316,26 @@ inline void RunProgramAPI(
auto output_names = details::GetTensorsName(out);
auto dout_names = details::GetTensorsName(dout);
if (VLOG_IS_ON(6)) {
std::stringstream s;
s << "input_names: ";
for (auto name : input_names) {
s << name << " ";
}
s << std::endl;
s << "output_names: ";
for (auto name : output_names) {
s << name << " ";
}
s << std::endl;
s << "dout_names: ";
for (auto name : dout_names) {
s << name << " ";
}
s << std::endl;
VLOG(6) << s.str();
}
auto *forward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc *, attrs.at("forward_global_block"));
auto *backward_global_block = PADDLE_GET_CONST(
......@@ -354,6 +374,20 @@ inline void RunProgramAPI(
skip_eager_delete_vars.insert(dout_names.begin(), dout_names.end());
// update interpretercore skip_gc_var
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
std::set<std::string> input_vars;
input_vars.insert(input_names.begin(), input_names.end());
interpreter_core->SetJitInputVars(input_vars);
if (VLOG_IS_ON(6)) {
std::stringstream s;
s << "skip_eager_delete_vars: ";
for (auto name : skip_eager_delete_vars) {
s << name << " ";
}
VLOG(6) << s.str();
}
interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, false, skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
......
......@@ -32,6 +32,7 @@ struct ExecutionConfig {
size_t deivce_num_threads;
std::set<std::string> skip_gc_vars;
std::set<std::string> jit_input_vars;
ExecutionConfig(const phi::Place& place, size_t op_num);
void Log(int log_level);
......
......@@ -43,6 +43,8 @@ class StreamAnalyzer {
platform::DeviceContext* ParseDeviceContext(
const OpFuncNode& op_func_node) const;
platform::DeviceType GetWaiterType(const Instruction& instr) const;
private:
bool HasDataDependency(const Instruction& cur_instr,
const Instruction& next_instr) const;
......@@ -70,8 +72,6 @@ class StreamAnalyzer {
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>*
event_info_map) const;
platform::DeviceType GetWaiterType(const Instruction& instr) const;
DownstreamRunType AnalyseRunTypeForTwoInstructions(
const Instruction& cur_instr, const Instruction& next_instr) const;
......
......@@ -311,6 +311,22 @@ void InterpreterCore::SetSkipGcVars(const std::set<std::string>& skip_gc_vars) {
execution_config_.skip_gc_vars = skip_gc_vars;
}
void InterpreterCore::SetJitInputVars(
const std::set<std::string>& jit_input_vars) {
PADDLE_ENFORCE_EQ(
execution_config_.jit_input_vars.empty(),
true,
platform::errors::PreconditionNotMet(
"execution_config_.jit_input_vars can only be initialized once, now "
"execution_config_.jit_input_vars is "
"not empty, do not call SetJitInputVars method repeatedly."));
execution_config_.jit_input_vars = jit_input_vars;
}
const std::set<std::string>& InterpreterCore::JitInputVars() const {
return execution_config_.jit_input_vars;
}
const VariableScope* InterpreterCore::GetVariableScope() const {
return &var_scope_;
}
......@@ -563,6 +579,30 @@ void InterpreterCore::Convert(
stream_analyzer_.ConstructEvents(dependency_builder_, &vec_instruction_);
// add event for the input var of jit program, since there are async copied
// from gpu_pinned place to gpu place on compute stream.
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
auto& inst = vec_instruction_[i];
if (inst.OpBase()->Type() == interpreter::kMemcpyD2H &&
platform::is_gpu_place(place_)) {
for (auto& item : inst.Inputs()) {
for (auto var_id : item.second) {
auto name = var_scope_.GetNameById(var_id);
if (JitInputVars().count(name)) {
auto device_event = std::make_shared<platform::DeviceEvent>(
place_, platform::GenerateDeviceEventFlag());
VLOG(4) << "Add input event for input: " << name << " of "
<< inst.OpBase()->Type();
inst.AddEventToWait(
i, device_event, stream_analyzer_.GetWaiterType(inst));
}
}
}
}
}
}
// calculate last_live_ops_
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
Instruction& instr = vec_instruction_[op_idx];
......@@ -838,6 +878,19 @@ void InterpreterCore::ExecuteInstructionList(
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
// NOTE(zhiqiu): hot fix for jit input var
if (vec_instr.at(i).OpBase()->Type() == interpreter::kMemcpyD2H) {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* default_dev_ctx = pool.Get(place_);
for (auto& event : vec_instr.at(i).EventsToWait()) {
platform::RecordEvent record(
"RecordStreamEvent", platform::TracerEventType::UserDefined, 10);
VLOG(3) << "Record event on default stream in jit_input_var at op: "
<< vec_instr.at(i).OpBase()->Type();
event.event_->Record(default_dev_ctx);
}
}
async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
[this, i] { RunInstructionAsync(i); });
}
......
......@@ -67,6 +67,10 @@ class InterpreterCore {
void SetSkipGcVars(const std::set<std::string>& skip_gc_vars);
const std::set<std::string>& JitInputVars() const;
void SetJitInputVars(const std::set<std::string>& jit_input_vars);
const VariableScope* GetVariableScope() const;
void reset_scope(Scope* new_scope);
......@@ -154,6 +158,8 @@ class InterpreterCore {
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
// for jit
};
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
......
......@@ -310,6 +310,10 @@ class Instruction {
events_to_wait_.emplace_back(instr_id, event, waiter_type);
}
const std::vector<EventInter>& EventsToWait() const {
return events_to_wait_;
}
void AddNextInstrInDifferentThread(size_t id) {
next_instrs_in_different_thread.push_back(id);
}
......
......@@ -384,10 +384,13 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::string dtype = is_no_need_buffer_var
? "unknown_dtype"
: GetDtype(*scope, var_name);
std::string place = is_no_need_buffer_var
? "unknown_place"
: GetPlace(*scope, var_name);
ss << ":" << dtype;
ss << "[" << GetDimsDebug(*scope, var_name, true) << "]";
ss << "(" << GetLoDDebug(*scope, var_name) << ")";
ss << "(" << GetPlace(*scope, var_name) << ")";
ss << "(" << place << ")";
}
}
if (i != input.second.size() - 1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册