Unverified commit 0601c2c9, authored by Sonder, committed by GitHub

[Static Graph Performance Optimization] Share event (#55650)

* add sharing event info

* add sharing event info

* fix

* remove const

* add flag

* fix
Parent 581d05bb
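The core idea of this change: the event-analysis result is cached behind an `is_event_info_build_` flag and held in a `shared_ptr` so that interpreters running the same job type can reuse it instead of re-running the analysis. A minimal, self-contained sketch of that pattern follows; the class below is an illustrative stand-in, not the actual Paddle `StreamAnalyzer`.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <set>

struct DeviceContext {};  // stand-in for platform::DeviceContext

// event_info layout: DeviceContext -> waiter_instr_id -> recorder_instr_ids
using EventInfo =
    std::map<const DeviceContext*, std::map<std::size_t, std::set<std::size_t>>>;

class AnalyzerSketch {
 public:
  AnalyzerSketch() : event_info_(std::make_shared<EventInfo>()) {}

  void ConstructEvents() {
    if (!is_event_info_build_) {
      // The expensive dependency / run-type / event analysis would fill
      // *event_info_ here; it now runs only once per shared group.
      is_event_info_build_ = true;
    }
    // ... construct the actual device events from *event_info_ ...
  }

  std::shared_ptr<EventInfo> GetEventInfo() const { return event_info_; }

  void ShareEventInfoFrom(const AnalyzerSketch& src) {
    event_info_ = src.GetEventInfo();  // alias the same map, no copy
    is_event_info_build_ = true;       // the analysis branch is skipped later
  }

 private:
  bool is_event_info_build_{false};
  std::shared_ptr<EventInfo> event_info_;
};

int main() {
  AnalyzerSketch first, second;
  first.ConstructEvents();           // pays the analysis cost once
  second.ShareEventInfoFrom(first);  // reuses the cached result
  second.ConstructEvents();          // analysis is skipped
  std::cout << "shared entries: " << second.GetEventInfo()->size() << "\n";
}
```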
@@ -72,36 +72,36 @@ inline std::string RunTypeToString(DownstreamRunType run_type) {
}
}
void StreamAnalyzer::ConstructEvents(
std::vector<Instruction>* instructions) const {
std::vector<Instruction> cross_step_merged_instructions = *instructions;
for (const Instruction& instr : *instructions) {
cross_step_merged_instructions.emplace_back(instr);
}
void StreamAnalyzer::ConstructEvents(std::vector<Instruction>* instructions) {
if (!is_event_info_build_) {
std::vector<Instruction> cross_step_merged_instructions = *instructions;
for (const Instruction& instr : *instructions) {
cross_step_merged_instructions.emplace_back(instr);
}
DependencyBuilder dependency_builder;
dependency_builder.Build(cross_step_merged_instructions);
const std::map<size_t, std::set<size_t>>& downstream_map =
dependency_builder.OpDownstreamMap();
const size_t instr_num = cross_step_merged_instructions.size();
std::vector<std::vector<std::vector<size_t>>> run_type_info(
instr_num,
std::vector<std::vector<size_t>>(
/*number_of_run_type = */ 2)); // instr_id -> run_type ->
// next_instr_id
AnalyseAllRunType(
cross_step_merged_instructions, downstream_map, &run_type_info);
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>
event_info; // DeviceContext -> waiter_instr_id -> recorder_instr_ids
AnalyseAllEventInfo(
cross_step_merged_instructions, run_type_info, &event_info);
ShrinkEventInfo(dependency_builder, &event_info);
DependencyBuilder dependency_builder;
dependency_builder.Build(cross_step_merged_instructions);
const std::map<size_t, std::set<size_t>>& downstream_map =
dependency_builder.OpDownstreamMap();
const size_t instr_num = cross_step_merged_instructions.size();
std::vector<std::vector<std::vector<size_t>>> run_type_info(
instr_num,
std::vector<std::vector<size_t>>(
/*number_of_run_type = */ 2)); // instr_id -> run_type ->
// next_instr_id
AnalyseAllRunType(
cross_step_merged_instructions, downstream_map, &run_type_info);
AnalyseAllEventInfo(
cross_step_merged_instructions, run_type_info, event_info_.get());
ShrinkEventInfo(dependency_builder, event_info_.get());
is_event_info_build_ = true;
}
// Construct events
std::map<size_t, std::shared_ptr<DeviceEvent>> instr2event;
for (auto& context_item : event_info) {
for (auto& context_item : *event_info_) {
for (auto& waiter_item : context_item.second) {
size_t waiter_instr_id = waiter_item.first;
std::set<size_t>& recorder_instr_ids = waiter_item.second;
@@ -481,6 +481,16 @@ DownstreamRunType StreamAnalyzer::AnalyseRunTypeForTwoInstructions(
return DownstreamRunType::kDirectRun;
}
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
StreamAnalyzer::GetEventInfo() const {
return event_info_;
}
void StreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) {
event_info_ = src.GetEventInfo();
is_event_info_build_ = true;
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
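As a side note, the `event_info` shared above is keyed exactly as the comment in `ConstructEvents` describes: `DeviceContext -> waiter_instr_id -> recorder_instr_ids`. A small illustrative example of populating and traversing such a map (the context type here is a dummy, not Paddle's `platform::DeviceContext`):

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <set>

struct DeviceContext {};  // dummy stand-in for platform::DeviceContext

using EventInfo =
    std::map<const DeviceContext*, std::map<std::size_t, std::set<std::size_t>>>;

int main() {
  DeviceContext gpu_ctx;
  auto event_info = std::make_shared<EventInfo>();

  // On gpu_ctx, instruction 7 must wait for events recorded by 3 and 5.
  (*event_info)[&gpu_ctx][7] = {3, 5};

  for (const auto& context_item : *event_info) {
    for (const auto& waiter_item : context_item.second) {
      std::cout << "waiter " << waiter_item.first << " waits on "
                << waiter_item.second.size() << " recorder(s)\n";
    }
  }
}
```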
@@ -33,17 +33,26 @@ class StreamAnalyzer {
using DeviceContext = platform::DeviceContext;
using Place = platform::Place;
explicit StreamAnalyzer(const Place& place) : place_(place) {}
explicit StreamAnalyzer(const Place& place) : place_(place) {
event_info_ = std::make_shared<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>();
}
~StreamAnalyzer() {}
void ConstructEvents(std::vector<Instruction>* instructions) const;
void ConstructEvents(std::vector<Instruction>* instructions);
platform::DeviceContext* ParseDeviceContext(
const OpFuncNode& op_func_node) const;
platform::DeviceType GetWaiterType(const Instruction& instr) const;
void ShareEventInfoFrom(const StreamAnalyzer& src);
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
GetEventInfo() const;
private:
bool HasDataDependency(const Instruction& cur_instr,
const Instruction& next_instr) const;
@@ -76,6 +85,10 @@ class StreamAnalyzer {
const Instruction& cur_instr, const Instruction& next_instr) const;
const Place place_;
bool is_event_info_build_{false};
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
event_info_;
};
} // namespace interpreter
@@ -104,6 +104,10 @@ class InterpreterBaseImpl {
const = 0;
virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0;
virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0;
virtual bool IsSharedResultsBuild() const = 0;
};
inline void SetDeviceId(const platform::Place& place) {
@@ -369,6 +369,16 @@ std::shared_ptr<std::vector<size_t>> NewIRInterpreter::GetDependencyCount()
"GetDependencyCount is not implemented in NewIRInterpreter."));
}
const interpreter::StreamAnalyzer& NewIRInterpreter::GetStreamAnalyzer() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetStreamAnalyzer is not implemented in NewIRInterpreter."));
}
bool NewIRInterpreter::IsSharedResultsBuild() const {
PADDLE_THROW(platform::errors::Unimplemented(
"IsSharedResultsBuild is not implemented in NewIRInterpreter."));
}
bool NewIRInterpreter::BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index) {
if (!var_scope_.VarDesc(var_index)) {
@@ -60,6 +60,10 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
bool IsSharedResultsBuild() const override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
void SetSkipGcVars(const std::set<std::string>& skip_gc_vars) override;
@@ -193,6 +193,7 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
RunImpl();
}
is_build_ = true;
is_shared_results_build_ = true;
} else {
RunImpl();
}
@@ -291,11 +292,16 @@ void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
}
void ProgramInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
if (is_shared_results_build_ || !src.IsSharedResultsBuild()) {
return;
}
// share op dependency
dependency_builder_.ShareDependencyFrom(src.GetDependencyBuilder());
dependecy_count_ = src.GetDependencyCount();
is_shared_ = true;
VLOG(8) << "Share BuildResults from InterpreterCore(" << &src
// share event analysis
stream_analyzer_.ShareEventInfoFrom(src.GetStreamAnalyzer());
is_shared_results_build_ = true;
VLOG(8) << "Share Build Results from InterpreterCore(" << &src
<< ") to InterpreterCore(" << this << ")";
}
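The early-return guard above shares results only when the source interpreter has already finished its own build and the destination has not shared before. A hedged sketch of that control flow with plain structs (not the real `ProgramInterpreter`):

```cpp
#include <iostream>

struct InterpSketch {
  bool is_shared_results_build{false};

  void ShareBuildResultsFrom(const InterpSketch& src) {
    // Skip if this interpreter already holds shared results, or if the
    // source has not finished its own build (nothing to share yet).
    if (is_shared_results_build || !src.is_shared_results_build) return;
    // ... share dependency builder, dependency count and event info here ...
    is_shared_results_build = true;
  }
};

int main() {
  InterpSketch first, second;
  second.ShareBuildResultsFrom(first);   // no-op: first has not built yet
  first.is_shared_results_build = true;  // set once first's build finishes
  second.ShareBuildResultsFrom(first);   // now the results are shared
  std::cout << std::boolalpha << second.is_shared_results_build << "\n";
}
```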
@@ -337,6 +343,15 @@ std::shared_ptr<std::vector<size_t>> ProgramInterpreter::GetDependencyCount()
return dependecy_count_;
}
const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer()
const {
return stream_analyzer_;
}
bool ProgramInterpreter::IsSharedResultsBuild() const {
return is_shared_results_build_;
}
void ProgramInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
Scope* inner_scope =
HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
@@ -532,7 +547,7 @@ void ProgramInterpreter::BuildOperatorDependences() {
// and set the dependecy_count_
size_t instr_num = vec_instruction_.size();
dependecy_count_ = GetDependencyCount();
if (!is_shared_) {
if (!is_shared_results_build_) {
dependecy_count_->assign(instr_num, 0);
}
@@ -571,7 +586,7 @@ void ProgramInterpreter::BuildOperatorDependences() {
}
}
if (!is_shared_) {
if (!is_shared_results_build_) {
for (size_t next_instr_id : next_instr_ids) {
++(*dependecy_count_)[next_instr_id];
}
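Since `dependecy_count_` is a `std::shared_ptr<std::vector<size_t>>` (see `GetDependencyCount` above), interpreters that alias the same vector must not re-run the counting loop; the `is_shared_results_build_` guard prevents double counting. A small stand-alone sketch of that idea with plain STL types:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  // First interpreter builds the per-instruction dependency counts.
  auto dependency_count = std::make_shared<std::vector<std::size_t>>();
  bool first_is_shared = false;
  if (!first_is_shared) {
    dependency_count->assign(/*instr_num=*/3, 0);
    ++(*dependency_count)[1];
    ++(*dependency_count)[2];
  }

  // Second interpreter aliases the same vector instead of recounting.
  auto shared_count = dependency_count;  // shared_ptr copy, no data copy
  bool second_is_shared = true;          // set by the sharing call
  if (!second_is_shared) {
    ++(*shared_count)[1];  // would double-count; the flag prevents it
  }

  std::cout << (*shared_count)[1] << "\n";  // prints 1, not 2
}
```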
@@ -1336,6 +1351,7 @@ void ProgramInterpreter::Prepare(
}
BuildSkipShareLoDInfo();
is_build_ = true;
is_shared_results_build_ = true;
}
// NOTE: Because feed_tensor will be GC after
// paddle::framework::BuildOpFuncList, so we should
@@ -61,6 +61,10 @@ class ProgramInterpreter : public InterpreterBaseImpl {
std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
bool IsSharedResultsBuild() const override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
void SetSkipGcVars(const std::set<std::string>& skip_gc_vars) override;
@@ -134,9 +138,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
bool is_build_{false};
bool static_build_{false};
// Note(sonder): share the op dependency,
// event analyzer, thread scheduling and GC.
bool is_shared_{false};
// Note(sonder): share the op dependency and event analysis procedure.
bool is_shared_results_build_{false};
const platform::Place place_;
const BlockDesc& block_; // not owned
@@ -95,33 +95,24 @@ paddle::framework::FetchList StandaloneExecutor::Run(
const auto& jobs = plan_.JobList();
std::map<std::string, size_t> type_to_first_id;
if (!is_interpretercore_build_result_shared_) {
std::map<std::string, std::vector<size_t>> type_to_id;
type_to_first_id[jobs[0]->Type()] = 0;
for (size_t job_idx = 1; job_idx < jobs.size(); ++job_idx) {
interpretercores_[job_idx]->ShareWorkQueueFrom(interpretercores_[0]);
// TODO(Ruibiao): Share other build result, e.g., kernel choosing, data
// transfer, op dependency, thread scheduling, GC, event analyzer, and so
// on.
type_to_id[jobs[job_idx]->Type()].emplace_back(job_idx);
}
is_interpretercore_build_result_shared_ = true;
// Note(sonder): For the same type of job, share the build result of the
// first job to other jobs. The shared build result includes op dependency,
// event analyzer, thread scheduling and GC.
for (const auto& pair : type_to_id) {
const auto& ids = pair.second;
for (size_t i = 1; i < ids.size(); ++i) {
interpretercores_[ids[i]]->ShareBuildResultsFrom(
interpretercores_[ids[0]]);
if (type_to_first_id.count(jobs[job_idx]->Type()) == 0) {
type_to_first_id[jobs[job_idx]->Type()] = job_idx;
}
}
is_interpretercore_build_result_shared_ = true;
}
for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) {
const auto& job = jobs[job_idx];
const std::string& job_type = job->Type();
platform::RecordEvent record_event(
job_type + "-" + std::to_string(job->MicroBatchId()),
platform::TracerEventType::UserDefined,
@@ -129,7 +120,12 @@ paddle::framework::FetchList StandaloneExecutor::Run(
VLOG(6) << "Run job (" << job_idx << "), type = " << job_type
<< ", micro_batch_id =" << job->MicroBatchId();
// Note(sonder): Sharing build results does not work for the new IR executor yet.
if (type_to_first_id.count(job_type) != 0 &&
!FLAGS_enable_new_ir_in_executor) {
interpretercores_[job_idx]->ShareBuildResultsFrom(
interpretercores_[type_to_first_id[job_type]]);
}
interpretercores_[job_idx]->Run(feed_names, /*need_fetch = */ false);
}
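The loop above maps each job type to the index of the first job of that type and, unless the new IR executor flag is set, lets every later job of the same type call `ShareBuildResultsFrom` on that first interpreter. A stand-alone sketch of the grouping with made-up job types (the real code iterates `plan_.JobList()`):

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Made-up job types for illustration only.
  std::vector<std::string> job_types = {"forward", "backward", "forward",
                                        "optimizer", "backward"};

  // Remember the first job index of each type, as in the loop above.
  std::map<std::string, std::size_t> type_to_first_id;
  type_to_first_id[job_types[0]] = 0;
  for (std::size_t job_idx = 1; job_idx < job_types.size(); ++job_idx) {
    if (type_to_first_id.count(job_types[job_idx]) == 0) {
      type_to_first_id[job_types[job_idx]] = job_idx;
    }
  }

  // Later jobs of the same type reuse the first job's build results.
  for (std::size_t job_idx = 0; job_idx < job_types.size(); ++job_idx) {
    std::size_t first_id = type_to_first_id[job_types[job_idx]];
    if (first_id != job_idx) {
      std::cout << "job " << job_idx << " (" << job_types[job_idx]
                << ") shares build results from job " << first_id << "\n";
    }
  }
}
```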