未验证 提交 06e9233f 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine code (#56089)

* refine code

* refine code

* refine code
上级 4788971d
......@@ -94,19 +94,8 @@ class InterpreterBaseImpl {
virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
virtual const interpreter::DependencyBuilder& GetDependencyBuilder()
const = 0;
virtual std::shared_ptr<std::vector<size_t>> GetDependencyCount() const = 0;
virtual const interpreter::StreamAnalyzer& GetStreamAnalyzer() const = 0;
virtual const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const = 0;
virtual const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const = 0;
virtual bool IsSharedResultsBuild() const = 0;
};
......
......@@ -201,26 +201,20 @@ void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
}
void NewIRInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
if (is_shared_results_build_ || !src.IsSharedResultsBuild()) {
const NewIRInterpreter& impl = dynamic_cast<const NewIRInterpreter&>(src);
if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) {
return;
}
// share op dependency
ir_dependency_builder_.ShareDependencyFrom(src.GetNewIrDependencyBuilder());
dependecy_count_ = src.GetDependencyCount();
ir_dependency_builder_.ShareDependencyFrom(impl.GetNewIrDependencyBuilder());
dependecy_count_ = impl.GetDependencyCount();
// share event analysis
ir_stream_analyzer_.ShareEventInfoFrom(src.GetNewIrStreamAnalyzer());
ir_stream_analyzer_.ShareEventInfoFrom(impl.GetNewIrStreamAnalyzer());
is_shared_results_build_ = true;
VLOG(8) << "Share Build Results from InterpreterCore(" << &src
VLOG(8) << "Share Build Results from InterpreterCore(" << &impl
<< ") to InterpreterCore(" << this << ")";
}
// op dependences
const interpreter::DependencyBuilder& NewIRInterpreter::GetDependencyBuilder()
const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in NewIRInterpreter."));
}
const interpreter::NewIrDependencyBuilder&
NewIRInterpreter::GetNewIrDependencyBuilder() const {
return ir_dependency_builder_;
......@@ -231,11 +225,6 @@ std::shared_ptr<std::vector<size_t>> NewIRInterpreter::GetDependencyCount()
return dependecy_count_;
}
const interpreter::StreamAnalyzer& NewIRInterpreter::GetStreamAnalyzer() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetStreamAnalyzer is not implemented in NewIRInterpreter."));
}
const interpreter::NewIrStreamAnalyzer&
NewIRInterpreter::GetNewIrStreamAnalyzer() const {
return ir_stream_analyzer_;
......
......@@ -53,13 +53,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
// op dependences
const interpreter::DependencyBuilder& GetDependencyBuilder() const override;
std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
bool IsSharedResultsBuild() const override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
......@@ -195,11 +190,9 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void SolvePersisableVarNames();
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder() const;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer() const;
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
......
......@@ -287,16 +287,17 @@ void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
}
void ProgramInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
if (is_shared_results_build_ || !src.IsSharedResultsBuild()) {
const ProgramInterpreter& impl = dynamic_cast<const ProgramInterpreter&>(src);
if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) {
return;
}
// share op dependency
dependency_builder_.ShareDependencyFrom(src.GetDependencyBuilder());
dependecy_count_ = src.GetDependencyCount();
dependency_builder_.ShareDependencyFrom(impl.GetDependencyBuilder());
dependecy_count_ = impl.GetDependencyCount();
// share event analysis
stream_analyzer_.ShareEventInfoFrom(src.GetStreamAnalyzer());
stream_analyzer_.ShareEventInfoFrom(impl.GetStreamAnalyzer());
is_shared_results_build_ = true;
VLOG(8) << "Share Build Results from InterpreterCore(" << &src
VLOG(8) << "Share Build Results from InterpreterCore(" << &impl
<< ") to InterpreterCore(" << this << ")";
}
......@@ -343,18 +344,6 @@ const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer()
return stream_analyzer_;
}
const interpreter::NewIrDependencyBuilder&
ProgramInterpreter::GetNewIrDependencyBuilder() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in ProgramInterpreter."));
}
const interpreter::NewIrStreamAnalyzer&
ProgramInterpreter::GetNewIrStreamAnalyzer() const {
PADDLE_THROW(platform::errors::Unimplemented(
"GetDependencyBuilder is not implemented in ProgramInterpreter."));
}
bool ProgramInterpreter::IsSharedResultsBuild() const {
return is_shared_results_build_;
}
......
......@@ -53,17 +53,11 @@ class ProgramInterpreter : public InterpreterBaseImpl {
void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
// op dependences
const interpreter::DependencyBuilder& GetDependencyBuilder() const override;
const interpreter::DependencyBuilder& GetDependencyBuilder() const;
std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const override;
const interpreter::NewIrDependencyBuilder& GetNewIrDependencyBuilder()
const override;
const interpreter::NewIrStreamAnalyzer& GetNewIrStreamAnalyzer()
const override;
const interpreter::StreamAnalyzer& GetStreamAnalyzer() const;
bool IsSharedResultsBuild() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册