// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "ThreadPool.h" // ThreadPool in thrird party #include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" namespace paddle { namespace framework { class Scope; namespace details { class ThreadedSSAGraphExecutor : public SSAGraphExecutor { public: ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, std::unique_ptr &&graph); // Run a SSAGraph by a thread pool // Use topological sort algorithm FeedFetchList Run(const std::vector &fetch_tensors) override; ~ThreadedSSAGraphExecutor() {} private: void RunOp(BlockingQueue *ready_var_q, details::OpHandleBase *op); private: std::unique_ptr graph_; std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; std::vector places_; platform::DeviceContextPool fetch_ctxs_; std::unique_ptr exception_; std::atomic running_ops_; void InsertPendingOp(std::unordered_map *pending_ops, OpHandleBase *op_instance) const; void InsertPendingVar(std::unordered_set *pending_vars, BlockingQueue *ready_vars, VarHandleBase *var) const; void InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, BlockingQueue *ready_vars, FeedFetchList *fetch_data); private: ExecutionStrategy strategy_; }; } // namespace details } // namespace framework } // namespace paddle