未验证 提交 06e9233f 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine code (#56089)

* refine code

* refine code

* refine code
上级 4788971d
...@@ -94,19 +94,8 @@ class InterpreterBaseImpl { ...@@ -94,19 +94,8 @@ class InterpreterBaseImpl {
virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0; virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
virtual const interpreter::DependencyBuilder& GetDependencyBuilder()
const = 0;
virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0; virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0;
virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0;
virtual const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const = 0;
virtual const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const = 0;
virtual bool IsSharedResultsBuild() const = 0; virtual bool IsSharedResultsBuild() const = 0;
}; };
......
...@@ -201,26 +201,20 @@ void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { ...@@ -201,26 +201,20 @@ void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
} }
void NewIRInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { void NewIRInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
if (is_shared_results_build_ || !src.IsSharedResultsBuild()) { const NewIRInterpreter& impl = dynamic_cast<const NewIRInterpreter&>(src);
if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) {
return; return;
} }
// share op dependency // share op dependency
ir_dependency_builder_.ShareDependencyFrom(src.GetNewIrDependencyBuilder()); ir_dependency_builder_.ShareDependencyFrom(impl.GetNewIrDependencyBuilder());
dependecy_count_ = src.GetDependencyCount(); dependecy_count_ = impl.GetDependencyCount();
// share event analysis // share event analysis
ir_stream_analyzer_.ShareEventInfoFrom(src.GetNewIrStreamAnalyzer()); ir_stream_analyzer_.ShareEventInfoFrom(impl.GetNewIrStreamAnalyzer());
is_shared_results_build_ = true; is_shared_results_build_ = true;
VLOG(8) << "Share Build Results from InterpreterCore(" << &src VLOG(8) << "Share Build Results from InterpreterCore(" << &impl
<< ") to InterpreterCore(" << this << ")"; << ") to InterpreterCore(" << this << ")";
} }
// op dependences
const interpreter::DependencyBuilder& NewIRInterpreter::GetDependencyBuilder()
const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in NewIRInterpreter."));
}
const interpreter::NewIrDependencyBuilder& const interpreter::NewIrDependencyBuilder&
NewIRInterpreter::GetNewIrDependencyBuilder() const { NewIRInterpreter::GetNewIrDependencyBuilder() const {
return ir_dependency_builder_; return ir_dependency_builder_;
...@@ -231,11 +225,6 @@ std::shared_ptr<std::vector<size_t>> NewIRInterpreter::GetDependencyCount() ...@@ -231,11 +225,6 @@ std::shared_ptr<std::vector<size_t>> NewIRInterpreter::GetDependencyCount()
return dependecy_count_; return dependecy_count_;
} }
const interpreter::StreamAnalyzer& NewIRInterpreter::GetStreamAnalyzer() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetStreamAnalyzer is not implemented in NewIRInterpreter."));
}
const interpreter::NewIrStreamAnalyzer& const interpreter::NewIrStreamAnalyzer&
NewIRInterpreter::GetNewIrStreamAnalyzer() const { NewIRInterpreter::GetNewIrStreamAnalyzer() const {
return ir_stream_analyzer_; return ir_stream_analyzer_;
......
...@@ -53,13 +53,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -53,13 +53,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override; void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
// op dependences
const interpreter::DependencyBuilder& GetDependencyBuilder() const override;
std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override; std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
bool IsSharedResultsBuild() const override; bool IsSharedResultsBuild() const override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override; void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
...@@ -195,11 +190,9 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -195,11 +190,9 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void SolvePersisableVarNames(); void SolvePersisableVarNames();
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder() const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder() const;
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer() const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer() const;
const override;
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
......
...@@ -287,16 +287,17 @@ void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { ...@@ -287,16 +287,17 @@ void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
} }
void ProgramInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { void ProgramInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
if (is_shared_results_build_ || !src.IsSharedResultsBuild()) { const ProgramInterpreter& impl = dynamic_cast<const ProgramInterpreter&>(src);
if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) {
return; return;
} }
// share op dependency // share op dependency
dependency_builder_.ShareDependencyFrom(src.GetDependencyBuilder()); dependency_builder_.ShareDependencyFrom(impl.GetDependencyBuilder());
dependecy_count_ = src.GetDependencyCount(); dependecy_count_ = impl.GetDependencyCount();
// share event analysis // share event analysis
stream_analyzer_.ShareEventInfoFrom(src.GetStreamAnalyzer()); stream_analyzer_.ShareEventInfoFrom(impl.GetStreamAnalyzer());
is_shared_results_build_ = true; is_shared_results_build_ = true;
VLOG(8) << "Share Build Results from InterpreterCore(" << &src VLOG(8) << "Share Build Results from InterpreterCore(" << &impl
<< ") to InterpreterCore(" << this << ")"; << ") to InterpreterCore(" << this << ")";
} }
...@@ -343,18 +344,6 @@ const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer() ...@@ -343,18 +344,6 @@ const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer()
return stream_analyzer_; return stream_analyzer_;
} }
const interpreter::NewIrDependencyBuilder&
ProgramInterpreter::GetNewIrDependencyBuilder() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in ProgramInterpreter."));
}
const interpreter::NewIrStreamAnalyzer&
ProgramInterpreter::GetNewIrStreamAnalyzer() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in ProgramInterpreter."));
}
bool ProgramInterpreter::IsSharedResultsBuild() const { bool ProgramInterpreter::IsSharedResultsBuild() const {
return is_shared_results_build_; return is_shared_results_build_;
} }
......
...@@ -53,17 +53,11 @@ class ProgramInterpreter : public InterpreterBaseImpl { ...@@ -53,17 +53,11 @@ class ProgramInterpreter : public InterpreterBaseImpl {
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override; void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
// op dependences // op dependences
const interpreter::DependencyBuilder& GetDependencyBuilder() const override; const interpreter::DependencyBuilder& GetDependencyBuilder() const;
std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override; std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override; const interpreter::StreamAnalyzer& GetStreamAnalyzer() const;
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
bool IsSharedResultsBuild() const override; bool IsSharedResultsBuild() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册