未验证 提交 2386db87 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support multi-thread run && delete unused code of new_ir interpreter (#56148)

* add code

* fix bug

* fix bug

* delete unused code

* refine code

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
上级 982100ab
......@@ -664,6 +664,12 @@ void NewIrDependencyBuilder::BuildDownstreamMap() {
}
}
// Adopts the already-computed dependency analysis results from `src`
// (the op downstream map and the happens-before relation) instead of
// re-running the analysis, then marks this builder as built so a later
// Build() call can be skipped. Used to share build results across
// interpreter instances (e.g. multi-thread run of the same program).
void NewIrDependencyBuilder::ShareDependencyFrom(
const NewIrDependencyBuilder& src) {
std::tie(op_downstream_map_, op_happens_before_) = src.GetDependency();
is_build_ = true;
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -114,6 +114,8 @@ class NewIrDependencyBuilder : public DependencyBuilder {
void BuildDownstreamMap();
void ShareDependencyFrom(const NewIrDependencyBuilder& src);
private:
std::vector<paddle::framework::InstructionBase*> instructions_; // not_owned
};
......
......@@ -684,7 +684,7 @@ platform::DeviceType NewIrStreamAnalyzer::GetWaiterType(
}
}
void NewIrStreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) {
// Adopts the stream/event analysis results from `src` instead of
// re-analyzing, then marks the event info as built. NOTE(review): the
// parameter type was changed from StreamAnalyzer to NewIrStreamAnalyzer
// by this commit so new-IR interpreters only share with their own kind.
void NewIrStreamAnalyzer::ShareEventInfoFrom(const NewIrStreamAnalyzer& src) {
event_info_ = src.GetEventInfo();
is_event_info_build_ = true;
}
......
......@@ -138,7 +138,7 @@ class NewIrStreamAnalyzer {
platform::DeviceType GetWaiterType(
const paddle::framework::InstructionBase* instr) const;
void ShareEventInfoFrom(const StreamAnalyzer& src);
void ShareEventInfoFrom(const NewIrStreamAnalyzer& src);
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
......
......@@ -72,12 +72,6 @@ class InterpreterBaseImpl {
virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
// NOTE(zhangbo): This interface is only used for temporary testing and only
// for testing during the iteration process of the new IR access actuator
// version. It will be deleted in the future.
virtual paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;
virtual void ShareBuildResultsFrom(const InterpreterBaseImpl& src) = 0;
......@@ -107,6 +101,12 @@ class InterpreterBaseImpl {
virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0;
virtual const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const = 0;
virtual const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const = 0;
virtual bool IsSharedResultsBuild() const = 0;
};
......
......@@ -74,11 +74,6 @@ FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
return impl_->Run(feed_names, need_fetch);
}
// Temporary testing entry point that forwards to the impl's BetaRun.
// NOTE(review): this commit deletes BetaRun entirely; callers should use
// Run() instead (see the updated tests in this same change).
FetchList InterpreterCore::BetaRun(const std::vector<std::string>& feed_names,
bool need_fetch) {
return impl_->BetaRun(feed_names, need_fetch);
}
// Shares the async work queue of `src` with this interpreter so both run
// on the same thread pool. NOTE(review): the const_cast suggests Impl()
// returns a const pointer while ShareWorkQueueFrom needs a mutable one —
// confirm Impl()'s signature before changing this.
void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
impl_->ShareWorkQueueFrom(const_cast<InterpreterBaseImpl*>(src->Impl()));
}
......
......@@ -52,9 +52,6 @@ class InterpreterCore {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true);
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true);
void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);
void ShareBuildResultsFrom(std::shared_ptr<InterpreterCore> src);
......
......@@ -49,10 +49,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
......@@ -92,10 +88,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
private:
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map,
......@@ -103,39 +95,13 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void ConstructEventForJitInput();
void CalculateLastLiveOps();
// inplace
void BuildInplace();
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// gc
void ClearLoDTensorArrayInLocalScope();
// cuda graph
void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names);
void PrepareForCUDAGraphCapture();
// execution
void RunImpl();
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
SchedulingQueue* reserved_next_ops);
void RunOperator(const Instruction& instr_node);
// Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
void RecordMemcpyD2H(const Instruction& instr_node);
// gc
void RecordStreamForGC(const Instruction& instr);
void CheckGC(const Instruction& instr);
void ClearLoDTensorArrayInLocalScope();
// workqueue
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
......@@ -150,23 +116,12 @@ class NewIRInterpreter : public InterpreterBaseImpl {
bool is_build_{false};
bool static_build_{false};
const platform::Place place_;
interpreter::DependencyBuilder dependency_builder_;
interpreter::StreamAnalyzer stream_analyzer_;
// Note(sonder): share the op dependency and event analysis procedure.
bool is_shared_results_build_{false};
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
// copy a new program and block, the copy_program_ here is used to
// hold the program, otherwise block_ maybe not valid after the
// new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
const platform::Place place_;
// from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::atomic<size_t> unfinished_op_number_{0};
......@@ -189,9 +144,9 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// var
std::map<size_t, std::set<size_t>> last_live_ops_;
// dependecy_count_[i] contains the number of dependencies that the i-th op
// (*dependecy_count_)[i] contains the number of dependencies that the i-th op
// need to wait
std::vector<size_t> dependecy_count_;
std::shared_ptr<std::vector<size_t>> dependecy_count_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
......@@ -200,8 +155,6 @@ class NewIRInterpreter : public InterpreterBaseImpl {
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
std::vector<HookFunc> hookfuncs_;
/// ======================== ///
......@@ -215,16 +168,21 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void BuildInstructionDependences();
void LoopRunImpl();
void TraceRunImpl();
void TraceRunInstructionList(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);
void LoopRunInstructionList(
void MultiThreadRunImpl();
void MultiThreadRunInstructionList(
const std::vector<std::unique_ptr<InstructionBase>>& vec_instr);
void RunInstructionBaseAsync(size_t instr_id);
void RunNextInstructions(InstructionBase* instr,
SchedulingQueue* reserved_next_ops);
void RunInstructionBase(InstructionBase* instr_node);
void RecordMemcpyD2H(InstructionBase* instr_node);
......@@ -237,6 +195,12 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void SolvePersisableVarNames();
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
std::unique_ptr<::ir::Program> ir_program_{nullptr};
......
......@@ -222,11 +222,6 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
}
}
// Stub: BetaRun was only meaningful for the new-IR interpreter; the
// program interpreter returns an empty fetch list. NOTE(review): this
// commit removes BetaRun from the interface entirely.
FetchList ProgramInterpreter::BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch) {
return {};
}
void ProgramInterpreter::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog;
}
......@@ -348,6 +343,18 @@ const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer()
return stream_analyzer_;
}
// Not supported: the legacy ProgramInterpreter has no new-IR dependency
// builder; only NewIRInterpreter implements this accessor. Always throws
// Unimplemented.
const interpreter::NewIrDependencyBuilder&
ProgramInterpreter::GetNewIrDependencyBuilder() const {
// Fix copy-paste error: the message previously named "GetDependencyBuilder"
// instead of this method, which would mislead anyone debugging the throw.
PADDLE_THROW(platform::errors::Unimplemented(
"GetNewIrDependencyBuilder is not implemented in ProgramInterpreter."));
}
// Not supported: the legacy ProgramInterpreter has no new-IR stream
// analyzer; only NewIRInterpreter implements this accessor. Always throws
// Unimplemented.
const interpreter::NewIrStreamAnalyzer&
ProgramInterpreter::GetNewIrStreamAnalyzer() const {
// Fix copy-paste error: the message previously named "GetDependencyBuilder"
// instead of this method, which would mislead anyone debugging the throw.
PADDLE_THROW(platform::errors::Unimplemented(
"GetNewIrStreamAnalyzer is not implemented in ProgramInterpreter."));
}
// Returns whether this interpreter's build results (dependency/event
// analysis) were shared from another interpreter rather than built here.
bool ProgramInterpreter::IsSharedResultsBuild() const {
return is_shared_results_build_;
}
......
......@@ -48,10 +48,6 @@ class ProgramInterpreter : public InterpreterBaseImpl {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
......@@ -63,6 +59,12 @@ class ProgramInterpreter : public InterpreterBaseImpl {
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
bool IsSharedResultsBuild() const override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
......
......@@ -1283,25 +1283,13 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_api,
/**
* Using new IR in executor FLAG
* Name: enable_new_ir_in_executor_beta_run
* Since Version: 2.6.0
* Value Range: bool, default=true
* Example:
* Note: If True, executor will use new IR and run in beta version.
*/
PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_beta_run,
true,
"Enable new IR in executor");
/**
* Using new IR in executor FLAG
* Name: enable_new_ir_in_executor_loop_run
* Name: enable_new_ir_in_executor_trace_run
* Since Version: 2.6.0
* Value Range: bool, default=false
* Example:
* Note: If True, executor will use new IR and run in beta version by for loop
* Note: If True, executor will use new IR and run in beta version by for trace
* version.
*/
PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_loop_run,
PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run,
false,
"Enable new IR in executor");
......@@ -77,7 +77,7 @@ TEST(StandaloneExecutor, run) {
std::string out_name = os.str() + "_inner_var_2";
test_core.SetSkipGcVars({out_name});
test_core.BetaRun({});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
......@@ -118,7 +118,7 @@ TEST(StandaloneExecutor, run_inplace_sqrt) {
std::string out_name = os.str() + "_inner_var_0";
test_core.SetSkipGcVars({out_name});
test_core.BetaRun({});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
......
......@@ -76,7 +76,7 @@ TEST(VJP, TanhBackwardTest) {
std::string prefix_str = os.str();
test_core.SetSkipGcVars(
{prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"});
test_core.BetaRun({});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>()
......@@ -130,7 +130,7 @@ TEST(VJP, Tanh_BackwardTest) {
std::string prefix_str = os.str();
test_core.SetSkipGcVars(
{prefix_str + "_inner_var_0", prefix_str + "_inner_var_2"});
test_core.BetaRun({});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
......@@ -184,7 +184,7 @@ TEST(VJP, MeanBackwardTest) {
std::string prefix_str = os.str();
test_core.SetSkipGcVars(
{prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"});
test_core.BetaRun({});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_1")->Get<phi::DenseTensor>()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册