// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/framework/new_executor/interpreter_base_impl.h" namespace paddle { namespace framework { /// /// \brief Derived Class to interpret the instructions transformed /// from legacy ProgramDesc. /// class ProgramInterpreter : public InterpreterBaseImpl { using ExecutionConfig = interpreter::ExecutionConfig; using InstructionSchedulingPriorityLess = std::function; using SchedulingQueue = std::priority_queue, InstructionSchedulingPriorityLess>; public: ProgramInterpreter( const platform::Place& place, const BlockDesc& block, Scope* scope, const ExecutionConfig& execution_config = ExecutionConfig()); ~ProgramInterpreter(); paddle::framework::FetchList Run( const std::vector& feed_names, const std::vector& feed_tensors) override; paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true) override; void ShareWorkQueueFrom(InterpreterBaseImpl* src) override; void SetCopyProgram(std::shared_ptr prog) override; void SetSkipGcVars(const std::set& skip_gc_vars) override; const std::set& JitInputVars() const override; void SetJitInputVars(const std::set& jit_input_vars) override; const VariableScope* GetVariableScope() const override; void reset_scope(Scope* new_scope) override; const Scope* local_scope() const override; const platform::Place& GetPlace() const override { return place_; } void SetOutputHooks(const std::vector& hookfuncs) override { hookfuncs_ = hookfuncs; } private: // build graph void Convert(std::vector* op_func_nodes); void BuildOperatorDependences(); void BuildAndCacheInstructionCtx(Instruction* instr_node); void BuildSkipShareLoDInfo(); void UpdateSyncOpNum(); void AnalyseExecuteOrderForTrace(); // inplace void BuildInplace(); bool BuildInplaceCheckVarIsOnlyInput( const std::vector>& input_var2op, size_t var_index); void SetFeedVarsInplaceSkip(const std::vector& feed_names); // cuda graph void CheckCUDAGraphBeforeRun(const std::vector& feed_names); void PrepareForCUDAGraphCapture(); // execution void RunImpl(); void ExecuteInstructionList(const std::vector& vec_instr); void RunInstructionAsync(size_t instr_id); void RunInstruction(const Instruction& instr_node); void RunNextInstructions(const Instruction& instr_id, SchedulingQueue* reserved_next_ops); void RunOperator(const Instruction& instr_node); // Trace void TraceInstructionList(const std::vector& vec_instr); // only used when program contains no feed op void Prepare(const std::vector& feed_names, const std::vector& feed_tensors, bool prepare_feed); void RecordMemcpyD2H(const Instruction& instr_node); // gc void RecordStreamForGC(const Instruction& instr); void CheckGC(const Instruction& instr); void ClearLoDTensorArrayInLocalScope(); // workqueue std::shared_ptr GetWorkQueue(); // scope bool HasLocalScope() const; // For log and debug std::string GetDepsString() const; bool is_build_{false}; bool static_build_{false}; const platform::Place place_; const BlockDesc& block_; // not owned interpreter::DependencyBuilder dependency_builder_; interpreter::StreamAnalyzer stream_analyzer_; // NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will // copy a new program and block, the copy_program_ here is used to // hold the program, otherwise block_ maybe not valid after the // new program is deleted. std::shared_ptr copy_program_{nullptr}; // from variable scope std::vector var_list_; std::map name2id_; std::vector vec_meta_info_; std::vector vec_instruction_; // deconstruct before OpFuncNode std::atomic unfinished_op_number_{0}; ExecutionConfig execution_config_; VariableScope var_scope_; Scope* local_scope_{nullptr}; // not owned EventsWaiter main_thread_blocker_; std::shared_ptr async_work_queue_; details::ExceptionHolder exception_holder_; std::shared_ptr exception_notifier_{nullptr}; std::shared_ptr completion_notifier_{nullptr}; std::unique_ptr gc_; // last_live_ops_[i] contains the id of operators that last access the i-th // var std::map> last_live_ops_; // dependecy_count_[i] contains the number of dependencies that the i-th op // need to wait std::vector dependecy_count_; std::vector> deps_; std::vector> refs_; // used for Trace int64_t sync_op_num_{-1}; std::vector trace_execute_order_; InstructionSchedulingPriorityLess instruction_scheduling_priority_less; std::vector hookfuncs_; }; } // namespace framework } // namespace paddle