未验证 提交 2386db87 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support multi-thread run && delete unused code of new_ir interpreter (#56148)

* add code

* fix bug

* fix bug

* delete unused code

* refine code

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
上级 982100ab
...@@ -664,6 +664,12 @@ void NewIrDependencyBuilder::BuildDownstreamMap() { ...@@ -664,6 +664,12 @@ void NewIrDependencyBuilder::BuildDownstreamMap() {
} }
} }
void NewIrDependencyBuilder::ShareDependencyFrom(
const NewIrDependencyBuilder& src) {
std::tie(op_downstream_map_, op_happens_before_) = src.GetDependency();
is_build_ = true;
}
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -114,6 +114,8 @@ class NewIrDependencyBuilder : public DependencyBuilder { ...@@ -114,6 +114,8 @@ class NewIrDependencyBuilder : public DependencyBuilder {
void BuildDownstreamMap(); void BuildDownstreamMap();
void ShareDependencyFrom(const NewIrDependencyBuilder& src);
private: private:
std::vector<paddle::framework::InstructionBase*> instructions_; // not_owned std::vector<paddle::framework::InstructionBase*> instructions_; // not_owned
}; };
......
...@@ -684,7 +684,7 @@ platform::DeviceType NewIrStreamAnalyzer::GetWaiterType( ...@@ -684,7 +684,7 @@ platform::DeviceType NewIrStreamAnalyzer::GetWaiterType(
} }
} }
void NewIrStreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) { void NewIrStreamAnalyzer::ShareEventInfoFrom(const NewIrStreamAnalyzer& src) {
event_info_ = src.GetEventInfo(); event_info_ = src.GetEventInfo();
is_event_info_build_ = true; is_event_info_build_ = true;
} }
......
...@@ -138,7 +138,7 @@ class NewIrStreamAnalyzer { ...@@ -138,7 +138,7 @@ class NewIrStreamAnalyzer {
platform::DeviceType GetWaiterType( platform::DeviceType GetWaiterType(
const paddle::framework::InstructionBase* instr) const; const paddle::framework::InstructionBase* instr) const;
void ShareEventInfoFrom(const StreamAnalyzer& src); void ShareEventInfoFrom(const NewIrStreamAnalyzer& src);
std::shared_ptr< std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>> std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
......
...@@ -72,12 +72,6 @@ class InterpreterBaseImpl { ...@@ -72,12 +72,6 @@ class InterpreterBaseImpl {
virtual paddle::framework::FetchList Run( virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0; const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
// NOTE(zhangbo): This interface is only used for temporary testing and only
// for testing during the iteration process of the new IR access actuator
// version. It will be deleted in the future.
virtual paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0; virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;
virtual void ShareBuildResultsFrom(const InterpreterBaseImpl& src) = 0; virtual void ShareBuildResultsFrom(const InterpreterBaseImpl& src) = 0;
...@@ -107,6 +101,12 @@ class InterpreterBaseImpl { ...@@ -107,6 +101,12 @@ class InterpreterBaseImpl {
virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0; virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0;
virtual const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const = 0;
virtual const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const = 0;
virtual bool IsSharedResultsBuild() const = 0; virtual bool IsSharedResultsBuild() const = 0;
}; };
......
...@@ -74,11 +74,6 @@ FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names, ...@@ -74,11 +74,6 @@ FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
return impl_->Run(feed_names, need_fetch); return impl_->Run(feed_names, need_fetch);
} }
FetchList InterpreterCore::BetaRun(const std::vector<std::string>& feed_names,
bool need_fetch) {
return impl_->BetaRun(feed_names, need_fetch);
}
void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) { void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
impl_->ShareWorkQueueFrom(const_cast<InterpreterBaseImpl*>(src->Impl())); impl_->ShareWorkQueueFrom(const_cast<InterpreterBaseImpl*>(src->Impl()));
} }
......
...@@ -52,9 +52,6 @@ class InterpreterCore { ...@@ -52,9 +52,6 @@ class InterpreterCore {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names, paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true); bool need_fetch = true);
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true);
void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src); void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);
void ShareBuildResultsFrom(std::shared_ptr<InterpreterCore> src); void ShareBuildResultsFrom(std::shared_ptr<InterpreterCore> src);
......
...@@ -49,10 +49,6 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -49,10 +49,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names, paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override; bool need_fetch = true) override;
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override; void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override; void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
...@@ -92,10 +88,6 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -92,10 +88,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
private: private:
// build graph // build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum(); void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace( void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map, std::map<size_t, std::set<size_t>> op_downstream_map,
...@@ -103,39 +95,13 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -103,39 +95,13 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void ConstructEventForJitInput(); void ConstructEventForJitInput();
void CalculateLastLiveOps(); void CalculateLastLiveOps();
// inplace // gc
void BuildInplace(); void ClearLoDTensorArrayInLocalScope();
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// cuda graph // cuda graph
void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names); void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names);
void PrepareForCUDAGraphCapture(); void PrepareForCUDAGraphCapture();
// execution
void RunImpl();
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
SchedulingQueue* reserved_next_ops);
void RunOperator(const Instruction& instr_node);
// Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
void RecordMemcpyD2H(const Instruction& instr_node);
// gc
void RecordStreamForGC(const Instruction& instr);
void CheckGC(const Instruction& instr);
void ClearLoDTensorArrayInLocalScope();
// workqueue // workqueue
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue(); std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
...@@ -150,23 +116,12 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -150,23 +116,12 @@ class NewIRInterpreter : public InterpreterBaseImpl {
bool is_build_{false}; bool is_build_{false};
bool static_build_{false}; bool static_build_{false};
const platform::Place place_; // Note(sonder): share the op dependency and event analysis procedure.
bool is_shared_results_build_{false};
interpreter::DependencyBuilder dependency_builder_;
interpreter::StreamAnalyzer stream_analyzer_;
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will const platform::Place place_;
// copy a new program and block, the copy_program_ here is used to
// hold the program, otherwise block_ maybe not valid after the
// new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
// from variable scope // from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::atomic<size_t> unfinished_op_number_{0}; std::atomic<size_t> unfinished_op_number_{0};
...@@ -189,9 +144,9 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -189,9 +144,9 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// var // var
std::map<size_t, std::set<size_t>> last_live_ops_; std::map<size_t, std::set<size_t>> last_live_ops_;
// dependecy_count_[i] contains the number of dependencies that the i-th op // (*dependecy_count_)[i] contains the number of dependencies that the i-th op
// need to wait // need to wait
std::vector<size_t> dependecy_count_; std::shared_ptr<std::vector<size_t>> dependecy_count_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_; std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_; std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
...@@ -200,8 +155,6 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -200,8 +155,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
int64_t sync_op_num_{-1}; int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_; std::vector<size_t> trace_execute_order_;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
std::vector<HookFunc> hookfuncs_; std::vector<HookFunc> hookfuncs_;
/// ======================== /// /// ======================== ///
...@@ -215,16 +168,21 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -215,16 +168,21 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void BuildInstructionDependences(); void BuildInstructionDependences();
void LoopRunImpl();
void TraceRunImpl(); void TraceRunImpl();
void TraceRunInstructionList( void TraceRunInstructionList(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instr); const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);
void LoopRunInstructionList( void MultiThreadRunImpl();
void MultiThreadRunInstructionList(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instr); const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);
void RunInstructionBaseAsync(size_t instr_id);
void RunNextInstructions(InstructionBase* instr,
SchedulingQueue* reserved_next_ops);
void RunInstructionBase(InstructionBase* instr_node); void RunInstructionBase(InstructionBase* instr_node);
void RecordMemcpyD2H(InstructionBase* instr_node); void RecordMemcpyD2H(InstructionBase* instr_node);
...@@ -237,6 +195,12 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -237,6 +195,12 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void SolvePersisableVarNames(); void SolvePersisableVarNames();
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
std::unique_ptr<::ir::Program> ir_program_{nullptr}; std::unique_ptr<::ir::Program> ir_program_{nullptr};
......
...@@ -222,11 +222,6 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names, ...@@ -222,11 +222,6 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
} }
} }
FetchList ProgramInterpreter::BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch) {
return {};
}
void ProgramInterpreter::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) { void ProgramInterpreter::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog; copy_program_ = prog;
} }
...@@ -348,6 +343,18 @@ const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer() ...@@ -348,6 +343,18 @@ const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer()
return stream_analyzer_; return stream_analyzer_;
} }
const interpreter::NewIrDependencyBuilder&
ProgramInterpreter::GetNewIrDependencyBuilder() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in ProgramInterpreter."));
}
const interpreter::NewIrStreamAnalyzer&
ProgramInterpreter::GetNewIrStreamAnalyzer() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in ProgramInterpreter."));
}
bool ProgramInterpreter::IsSharedResultsBuild() const { bool ProgramInterpreter::IsSharedResultsBuild() const {
return is_shared_results_build_; return is_shared_results_build_;
} }
......
...@@ -48,10 +48,6 @@ class ProgramInterpreter : public InterpreterBaseImpl { ...@@ -48,10 +48,6 @@ class ProgramInterpreter : public InterpreterBaseImpl {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names, paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override; bool need_fetch = true) override;
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override; void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override; void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
...@@ -63,6 +59,12 @@ class ProgramInterpreter : public InterpreterBaseImpl { ...@@ -63,6 +59,12 @@ class ProgramInterpreter : public InterpreterBaseImpl {
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override; const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
bool IsSharedResultsBuild() const override; bool IsSharedResultsBuild() const override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override; void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
......
...@@ -1283,25 +1283,13 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_api, ...@@ -1283,25 +1283,13 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
/** /**
* Using new IR in executor FLAG * Using new IR in executor FLAG
* Name: enable_new_ir_in_executor_beta_run * Name: enable_new_ir_in_executor_trace_run
* Since Version: 2.6.0
* Value Range: bool, default=true
* Example:
* Note: If Ture, executor will use new IR and run in beta version.
*/
PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_beta_run,
true,
"Enable new IR in executor");
/**
* Using new IR in executor FLAG
* Name: enable_new_ir_in_executor_loop_run
* Since Version: 2.6.0 * Since Version: 2.6.0
* Value Range: bool, default=false * Value Range: bool, default=false
* Example: * Example:
* Note: If Ture, executor will use new IR and run in beta version by for loop * Note: If Ture, executor will use new IR and run in beta version by for trace
* version. * version.
*/ */
PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_loop_run, PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run,
false, false,
"Enable new IR in executor"); "Enable new IR in executor");
...@@ -77,7 +77,7 @@ TEST(StandaloneExecutor, run) { ...@@ -77,7 +77,7 @@ TEST(StandaloneExecutor, run) {
std::string out_name = os.str() + "_inner_var_2"; std::string out_name = os.str() + "_inner_var_2";
test_core.SetSkipGcVars({out_name}); test_core.SetSkipGcVars({out_name});
test_core.BetaRun({}); test_core.Run({});
auto out_tensor = auto out_tensor =
test_core.local_scope() == nullptr test_core.local_scope() == nullptr
...@@ -118,7 +118,7 @@ TEST(StandaloneExecutor, run_inplace_sqrt) { ...@@ -118,7 +118,7 @@ TEST(StandaloneExecutor, run_inplace_sqrt) {
std::string out_name = os.str() + "_inner_var_0"; std::string out_name = os.str() + "_inner_var_0";
test_core.SetSkipGcVars({out_name}); test_core.SetSkipGcVars({out_name});
test_core.BetaRun({}); test_core.Run({});
auto out_tensor = auto out_tensor =
test_core.local_scope() == nullptr test_core.local_scope() == nullptr
......
...@@ -76,7 +76,7 @@ TEST(VJP, TanhBackwardTest) { ...@@ -76,7 +76,7 @@ TEST(VJP, TanhBackwardTest) {
std::string prefix_str = os.str(); std::string prefix_str = os.str();
test_core.SetSkipGcVars( test_core.SetSkipGcVars(
{prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"}); {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"});
test_core.BetaRun({}); test_core.Run({});
auto out_tensor = auto out_tensor =
test_core.local_scope() == nullptr test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>() ? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>()
...@@ -130,7 +130,7 @@ TEST(VJP, Tanh_BackwardTest) { ...@@ -130,7 +130,7 @@ TEST(VJP, Tanh_BackwardTest) {
std::string prefix_str = os.str(); std::string prefix_str = os.str();
test_core.SetSkipGcVars( test_core.SetSkipGcVars(
{prefix_str + "_inner_var_0", prefix_str + "_inner_var_2"}); {prefix_str + "_inner_var_0", prefix_str + "_inner_var_2"});
test_core.BetaRun({}); test_core.Run({});
auto out_tensor = auto out_tensor =
test_core.local_scope() == nullptr test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>() ? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
...@@ -184,7 +184,7 @@ TEST(VJP, MeanBackwardTest) { ...@@ -184,7 +184,7 @@ TEST(VJP, MeanBackwardTest) {
std::string prefix_str = os.str(); std::string prefix_str = os.str();
test_core.SetSkipGcVars( test_core.SetSkipGcVars(
{prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"}); {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"});
test_core.BetaRun({}); test_core.Run({});
auto out_tensor = auto out_tensor =
test_core.local_scope() == nullptr test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>() ? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册