“1b7f60100818d402ba50924bf12a65bee4e5d96d”上不存在“...LowPrecision/ResNet50_Slim/resnet50_web_service.py”
未验证 提交 f38e126e 编写于 作者: A Aurelius84 提交者: GitHub

[NewExe]Polish InterpreterCore with PImpl and Derived ProgramInterpreter and...

[NewExe]Polish InterpreterCore with PImpl and Derived ProgramInterpreter and NewIRInterpreter (#54651)

* [NewExe]Polish InterpreterCore with PImpl

fix code style

add std::move

* fix conflict

* fix typo

* fix typo
上级 52e2a557
...@@ -2,8 +2,9 @@ add_subdirectory(garbage_collector) ...@@ -2,8 +2,9 @@ add_subdirectory(garbage_collector)
add_subdirectory(interpreter) add_subdirectory(interpreter)
add_subdirectory(workqueue) add_subdirectory(workqueue)
set(STANDALONE_EXECUTOR_SRCS feed_fetch_utils.cc interpretercore.cc set(STANDALONE_EXECUTOR_SRCS
new_executor_defs.cc standalone_executor.cc) feed_fetch_utils.cc interpretercore.cc new_executor_defs.cc
standalone_executor.cc program_interpreter.cc new_ir_interpreter.cc)
set(STANDALONE_EXECUTOR_DEPS set(STANDALONE_EXECUTOR_DEPS
interpreter interpreter
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/flags.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpreter/dependency_builder.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/backends/device_manager.h"
DECLARE_bool(new_executor_serial_run);
DECLARE_bool(new_executor_static_build);
DECLARE_bool(new_executor_use_inplace);
DECLARE_bool(new_executor_use_local_scope);
PHI_DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
DECLARE_uint64(executor_log_deps_every_microseconds);
PHI_DECLARE_bool(new_executor_use_cuda_graph);
PHI_DECLARE_bool(enable_new_ir_in_executor);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PHI_DECLARE_bool(sync_nccl_allreduce);
#endif
constexpr const char* kExceptionCaught = "ExceptionCaught";
constexpr const char* kTaskCompletion = "TaskCompletion";
namespace paddle {
namespace framework {
using HookFunc = std::function<void(OperatorBase*, Scope*)>;
/// @brief InterpreterBaseImpl is a abstract Base Class and define necessary
/// interface with virtual keywords for Derived class.
/// TODO(Aurelius84): Clean unnecessary interface to keep cohesion.
class InterpreterBaseImpl {
public:
virtual ~InterpreterBaseImpl() = default;
virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) = 0;
virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;
virtual void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) = 0;
virtual void SetSkipGcVars(const std::set<std::string>& skip_gc_vars) = 0;
virtual const std::set<std::string>& JitInputVars() const = 0;
virtual void SetJitInputVars(const std::set<std::string>& jit_input_vars) = 0;
virtual const VariableScope* GetVariableScope() const = 0;
virtual void reset_scope(Scope* new_scope) = 0;
virtual const platform::Place& GetPlace() const = 0;
virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
};
inline void SetDeviceId(const platform::Place& place) {
// TODO(zhiqiu): reduce the cost
if (platform::is_gpu_place(place)) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CUDA support.",
place));
#else
auto dev_id = place.device;
platform::SetDeviceId(dev_id);
#endif
} else if (platform::is_xpu_place(place)) {
#ifndef PADDLE_WITH_XPU
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with XPU support.",
place));
#else
auto dev_id = place.device;
platform::SetXPUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
phi::DeviceManager::SetDevice(place);
#endif
}
}
} // namespace framework
} // namespace paddle
...@@ -12,56 +12,41 @@ ...@@ -12,56 +12,41 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpreter/dependency_builder.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/value.h"
DECLARE_bool(new_executor_use_local_scope); DECLARE_bool(new_executor_use_local_scope);
namespace ir {
class Program;
} // namespace ir
namespace ir {
class Program;
} // namespace ir
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InterpreterBaseImpl;
class InterpreterCore { class InterpreterCore {
using ExecutionConfig = interpreter::ExecutionConfig; using ExecutionConfig = interpreter::ExecutionConfig;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>; using HookFunc = std::function<void(OperatorBase*, Scope*)>;
using SchedulingQueue =
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;
public: public:
InterpreterCore(const platform::Place& place, InterpreterCore(const platform::Place& place,
const BlockDesc& block, const BlockDesc& block,
Scope* scope, Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig()); const ExecutionConfig& execution_config = ExecutionConfig());
// This constructor is for New IR.
InterpreterCore(const platform::Place& place, InterpreterCore(const platform::Place& place,
const BlockDesc& block,
Scope* scope,
std::unique_ptr<::ir::Program> ir_prog, std::unique_ptr<::ir::Program> ir_prog,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig()); const ExecutionConfig& execution_config = ExecutionConfig());
~InterpreterCore(); ~InterpreterCore();
const InterpreterBaseImpl* Impl() const { return impl_.get(); }
paddle::framework::FetchList Run( paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors); const std::vector<phi::DenseTensor>& feed_tensors);
...@@ -83,127 +68,14 @@ class InterpreterCore { ...@@ -83,127 +68,14 @@ class InterpreterCore {
void reset_scope(Scope* new_scope); void reset_scope(Scope* new_scope);
const platform::Place& GetPlace() const { return place_; } const platform::Place& GetPlace() const;
using HookFunc = std::function<void(OperatorBase*, Scope*)>; void SetOutputHooks(const std::vector<HookFunc>& hookfuncs);
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) {
hookfuncs_ = hookfuncs;
}
private: private:
DISABLE_COPY_AND_ASSIGN(InterpreterCore); DISABLE_COPY_AND_ASSIGN(InterpreterCore);
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
// inplace
void BuildInplace();
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// cuda graph
void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names);
void PrepareForCUDAGraphCapture();
// execution
void RunImpl();
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
SchedulingQueue* reserved_next_ops);
void RunOperator(const Instruction& instr_node);
// Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
void RecordMemcpyD2H(const Instruction& instr_node);
// gc
void RecordStreamForGC(const Instruction& instr);
void CheckGC(const Instruction& instr);
void ClearLoDTensorArrayInLocalScope();
// workqueue
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
// scope
bool HasLocalScope() const;
// For log and debug
std::string GetDepsString() const;
private:
bool is_build_{false};
bool static_build_{false};
const platform::Place place_;
const BlockDesc& block_; // not owned
interpreter::DependencyBuilder dependency_builder_;
interpreter::StreamAnalyzer stream_analyzer_;
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
// copy a new program and block, the copy_program_ here is used to
// hold the program, otherwise block_ maybe not valid after the
// new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
// from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::atomic<size_t> unfinished_op_number_{0};
ExecutionConfig execution_config_;
VariableScope var_scope_;
Scope* local_scope_{nullptr}; // not owned
EventsWaiter main_thread_blocker_;
std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
// last_live_ops_[i] contains the id of operators that last access the i-th
// var
std::map<size_t, std::set<size_t>> last_live_ops_;
// dependecy_count_[i] contains the number of dependencies that the i-th op
// need to wait
std::vector<size_t> dependecy_count_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
// used for Trace
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
std::vector<HookFunc> hookfuncs_;
// The next only for new IR
std::unique_ptr<::ir::Program> ir_program_{nullptr};
std::unordered_map<::ir::Value, std::string> value_2_var_name_map_; std::unique_ptr<InterpreterBaseImpl> impl_;
}; };
} // namespace framework } // namespace framework
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
namespace ir {
class Program;
} // namespace ir
namespace paddle {
namespace framework {
class NewIRInterpreter : public InterpreterBaseImpl {
using ExecutionConfig = interpreter::ExecutionConfig;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;
public:
NewIRInterpreter(const platform::Place& place,
std::unique_ptr<::ir::Program> ir_prog,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig());
~NewIRInterpreter();
paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) override;
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
void SetSkipGcVars(const std::set<std::string>& skip_gc_vars) override;
const std::set<std::string>& JitInputVars() const override;
void SetJitInputVars(const std::set<std::string>& jit_input_vars) override;
const VariableScope* GetVariableScope() const override;
void reset_scope(Scope* new_scope) override;
const platform::Place& GetPlace() const override { return place_; }
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
hookfuncs_ = hookfuncs;
}
private:
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
// inplace
void BuildInplace();
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// cuda graph
void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names);
void PrepareForCUDAGraphCapture();
// execution
void RunImpl();
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
SchedulingQueue* reserved_next_ops);
void RunOperator(const Instruction& instr_node);
// Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
void RecordMemcpyD2H(const Instruction& instr_node);
// gc
void RecordStreamForGC(const Instruction& instr);
void CheckGC(const Instruction& instr);
void ClearLoDTensorArrayInLocalScope();
// workqueue
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
// scope
bool HasLocalScope() const;
// For log and debug
std::string GetDepsString() const;
bool is_build_{false};
bool static_build_{false};
const platform::Place place_;
interpreter::DependencyBuilder dependency_builder_;
interpreter::StreamAnalyzer stream_analyzer_;
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
// copy a new program and block, the copy_program_ here is used to
// hold the program, otherwise block_ maybe not valid after the
// new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
// from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::atomic<size_t> unfinished_op_number_{0};
ExecutionConfig execution_config_;
VariableScope var_scope_;
Scope* local_scope_{nullptr}; // not owned
EventsWaiter main_thread_blocker_;
std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
// last_live_ops_[i] contains the id of operators that last access the i-th
// var
std::map<size_t, std::set<size_t>> last_live_ops_;
// dependecy_count_[i] contains the number of dependencies that the i-th op
// need to wait
std::vector<size_t> dependecy_count_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
// used for Trace
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
std::vector<HookFunc> hookfuncs_;
std::unique_ptr<::ir::Program> ir_program_{nullptr};
std::unordered_map<::ir::Value, std::string> value_2_var_name_map_;
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
namespace paddle {
namespace framework {
///
/// \brief Derived Class to interpret the instructions transformed
/// from legacy ProgramDesc.
///
class ProgramInterpreter : public InterpreterBaseImpl {
using ExecutionConfig = interpreter::ExecutionConfig;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;
public:
ProgramInterpreter(
const platform::Place& place,
const BlockDesc& block,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig());
~ProgramInterpreter();
paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors) override;
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
void SetSkipGcVars(const std::set<std::string>& skip_gc_vars) override;
const std::set<std::string>& JitInputVars() const override;
void SetJitInputVars(const std::set<std::string>& jit_input_vars) override;
const VariableScope* GetVariableScope() const override;
void reset_scope(Scope* new_scope) override;
const platform::Place& GetPlace() const override { return place_; }
void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
hookfuncs_ = hookfuncs;
}
private:
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildOperatorDependences();
void BuildAndCacheInstructionCtx(Instruction* instr_node);
void BuildSkipShareLoDInfo();
void UpdateSyncOpNum();
void AnalyseExecuteOrderForTrace();
// inplace
void BuildInplace();
bool BuildInplaceCheckVarIsOnlyInput(
const std::vector<std::vector<size_t>>& input_var2op, size_t var_index);
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
// cuda graph
void CheckCUDAGraphBeforeRun(const std::vector<std::string>& feed_names);
void PrepareForCUDAGraphCapture();
// execution
void RunImpl();
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr);
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
SchedulingQueue* reserved_next_ops);
void RunOperator(const Instruction& instr_node);
// Trace
void TraceInstructionList(const std::vector<Instruction>& vec_instr);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool prepare_feed);
void RecordMemcpyD2H(const Instruction& instr_node);
// gc
void RecordStreamForGC(const Instruction& instr);
void CheckGC(const Instruction& instr);
void ClearLoDTensorArrayInLocalScope();
// workqueue
std::shared_ptr<interpreter::AsyncWorkQueue> GetWorkQueue();
// scope
bool HasLocalScope() const;
// For log and debug
std::string GetDepsString() const;
bool is_build_{false};
bool static_build_{false};
const platform::Place place_;
const BlockDesc& block_; // not owned
interpreter::DependencyBuilder dependency_builder_;
interpreter::StreamAnalyzer stream_analyzer_;
// NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
// copy a new program and block, the copy_program_ here is used to
// hold the program, otherwise block_ maybe not valid after the
// new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
// from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::atomic<size_t> unfinished_op_number_{0};
ExecutionConfig execution_config_;
VariableScope var_scope_;
Scope* local_scope_{nullptr}; // not owned
EventsWaiter main_thread_blocker_;
std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
std::shared_ptr<EventsWaiter::EventNotifier> completion_notifier_{nullptr};
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
// last_live_ops_[i] contains the id of operators that last access the i-th
// var
std::map<size_t, std::set<size_t>> last_live_ops_;
// dependecy_count_[i] contains the number of dependencies that the i-th op
// need to wait
std::vector<size_t> dependecy_count_;
std::vector<std::shared_ptr<interpreter::OpDepInfo>> deps_;
std::vector<std::shared_ptr<interpreter::VarRefInfo>> refs_;
// used for Trace
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
std::vector<HookFunc> hookfuncs_;
};
} // namespace framework
} // namespace paddle
...@@ -71,12 +71,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -71,12 +71,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
auto base_progrm = paddle::TranslateLegacyProgramToProgram(*program); auto base_progrm = paddle::TranslateLegacyProgramToProgram(*program);
auto kernel_program = auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(base_progrm.get()); paddle::dialect::PdOpLowerToKernelPass(base_progrm.get());
interpretercores_.emplace_back( interpretercores_.emplace_back(std::make_unique<InterpreterCore>(
std::make_unique<InterpreterCore>(place_, place_, std::move(kernel_program), scope_, execution_config));
program->Block(0),
scope_,
std::move(kernel_program),
execution_config));
} else { } else {
interpretercores_.emplace_back( interpretercores_.emplace_back(
std::make_unique<InterpreterCore>(place_, std::make_unique<InterpreterCore>(place_,
......
...@@ -69,8 +69,7 @@ TEST(StandaloneExecutor, run) { ...@@ -69,8 +69,7 @@ TEST(StandaloneExecutor, run) {
Scope scope; Scope scope;
ProgramDesc prog_desc; ProgramDesc prog_desc;
InterpreterCore test_core( InterpreterCore test_core(place, std::move(kernel_program), &scope);
place, prog_desc.Block(0), &scope, std::move(kernel_program));
test_core.Run({}); test_core.Run({});
...@@ -139,8 +138,7 @@ TEST(StandaloneExecutor, run_2) { ...@@ -139,8 +138,7 @@ TEST(StandaloneExecutor, run_2) {
ProgramDesc prog_desc; ProgramDesc prog_desc;
InterpreterCore test_core( InterpreterCore test_core(place, std::move(kernel_program), &scope);
place, prog_desc.Block(0), &scope, std::move(kernel_program));
test_core.Run({}); test_core.Run({});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册