diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
index c69f148297aa01c4741afa3d50f11f9fb02b3b6f..941f1c673ab54fb36a9a9c26919415cec6f2fb60 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -43,7 +43,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
       bootstrap_ops_.emplace_back(op);
     }
   }
-
+  PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
   PrepareAtomicOpDeps();
 }

@@ -52,26 +52,85 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
       op_deps = atomic_op_deps_.get();
   PrepareAtomicOpDeps();
+  size_t num_ops = op_deps->size();

   paddle::framework::FeedFetchList fetches;
   fetches.resize(fetch_tensors.size());
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
-  std::vector<FetchOpHandle *> fetch_ops;
+  std::vector<OpHandleBase *> fetch_ops;
   std::vector<OpHandleBase *> ready_fetch_ops;
+  exception_.Clear();
+
+  InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
+                 &fetch_ops, &ready_fetch_ops);
+
+  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
+    // If the num_threads is 1, we can record the order of operator's
+    // execution in the first iteration, and in subsequent iterations,
+    // run the recorded operators directly. This strategy could make the
+    // execution faster.
+    VLOG(3) << "Run the traced ops.";
+    RunTracedOps(traced_ops_);
+    RunTracedOps(fetch_ops);
+    if (exception_.IsCaught()) {
+      ExecutionFinal(&fetch_ops);
+    }
+  } else {
+    traced_ops_.clear();
+    remaining_ = 0;
+    auto complete_q = std::make_shared<BlockingQueue<size_t>>();
+    for (auto op : bootstrap_ops_) {
+      RunOpAsync(op_deps.get(), op, complete_q);
+    }
+    for (auto op : ready_fetch_ops) {
+      RunOpAsync(op_deps.get(), op, complete_q);
+    }
+
+    size_t num_complete = 0;
+    while (num_complete != op_deps->size()) {
+      size_t num_comp = complete_q->Pop();
+      if (num_comp == -1UL) {
+        int remaining = 0;
+        while (true) {
+          remaining = remaining_;
+          if (remaining == 0) {
+            break;
+          }
+          for (int i = 0; i < remaining; ++i) {
+            complete_q->Pop();
+          }
+        }
+        if (exception_.IsCaught()) {
+          ExecutionFinal(&fetch_ops);
+        }
+      }
+      num_complete += num_comp;
+    }
+  }
+  // Wait FetchOps.
+  ClearFetchOp(graph_, &fetch_ops);
+  return fetches;
+}
+
+void FastThreadedSSAGraphExecutor::InsertFetchOps(
+    const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
+    std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
+    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
+    std::vector<OpHandleBase *> *fetch_ops,
+    std::vector<OpHandleBase *> *ready_fetch_ops) {
   for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
+    for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
+        (*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
       }
     }
   }

   for (size_t i = 0; i < fetch_tensors.size(); ++i) {
-    auto &var_name = fetch_tensors[i];
-    auto fetched_var_it = fetched_vars.find(var_name);
-    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
+    auto &var_name = fetch_tensors.at(i);
+    auto fetched_var_it = fetched_vars->find(var_name);
+    PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(),
                    "Cannot find fetched variable(%s).(Perhaps the main_program "
                    "is not set to ParallelExecutor)",
                    var_name);
@@ -80,8 +139,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(

     ir::Node *fetch_node =
         graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
-    auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
-    fetch_ops.emplace_back(op);
+    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
+    fetch_ops->emplace_back(op);

     for (auto &p : places_) {
       op->SetDeviceContext(p, fetch_ctxs_.Get(p));
@@ -94,55 +153,22 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
     int dep = static_cast<int>(op->NotReadyInputSize());
     (*op_deps)[op] = dep;
     if (dep == 0) {
-      ready_fetch_ops.emplace_back(op);
-    }
-  }
-
-  size_t num_complete = 0;
-  remaining_ = 0;
-  auto complete_q = std::make_shared<BlockingQueue<size_t>>();
-  for (auto op : bootstrap_ops_) {
-    RunOpAsync(op_deps.get(), op, complete_q);
-  }
-  for (auto op : ready_fetch_ops) {
-    RunOpAsync(op_deps.get(), op, complete_q);
-  }
-  while (num_complete != op_deps->size()) {
-    size_t num_comp = complete_q->Pop();
-    if (num_comp == -1UL) {
-      int remaining = 0;
-      while (true) {
-        remaining = remaining_;
-        if (remaining == 0) {
-          break;
-        }
-        for (int i = 0; i < remaining; ++i) {
-          complete_q->Pop();
-        }
-      }
-      if (exception_.IsCaught()) {
-        ClearFetchOp(graph_, &fetch_ops);
-        exception_.ReThrow();
-      }
+      ready_fetch_ops->emplace_back(op);
     }
-    num_complete += num_comp;
   }
-  // Wait FetchOps.
-  ClearFetchOp(graph_, &fetch_ops);
-  return fetches;
 }

 bool FastThreadedSSAGraphExecutor::RunOp(
     OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
     size_t *complete) {
-  try {
+  RunOpSync(op);
+  if (LIKELY(!exception_.IsCaught())) {
     if (LIKELY(!strategy_.dry_run_)) {
-      op->Run(strategy_.use_cuda_);
+      RecordOps(op);
     }
     ++(*complete);
     return true;
-  } catch (...) {
-    exception_.Catch(std::current_exception());
+  } else {
     --remaining_;
     complete_q->Push(-1UL);
     return false;
@@ -194,6 +220,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
     complete_q->Push(complete);
   });
 }
+
 void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
   atomic_op_deps_ = prepare_pool_.enqueue([&] {
     auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
@@ -206,6 +233,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
 }

 const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
+
+void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
+  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
+    traced_ops_.emplace_back(op);
+  }
+}
+
+void FastThreadedSSAGraphExecutor::ExecutionFinal(
+    std::vector<OpHandleBase *> *fetch_ops) {
+  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
+  ClearFetchOp(graph_, fetch_ops);
+  exception_.ReThrow();
+}
+
+void FastThreadedSSAGraphExecutor::RunTracedOps(
+    const std::vector<OpHandleBase *> &traced_ops) {
+  for (auto &op : traced_ops) {
+    if (exception_.IsCaught()) {
+      return;
+    }
+    RunOpSync(op);
+  }
+}
+
+void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
+  try {
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
+    }
+    if (LIKELY(!strategy_.dry_run_)) {
+      op->Run(strategy_.use_cuda_);
+    }
+    VLOG(10) << op << " " << op->Name() << " Done ";
+  } catch (...) {
+    exception_.Catch(std::current_exception());
+  }
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
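
The traced-ops fast path added above is the core of this patch: with a single scheduling thread the execution order is deterministic, so the first iteration records it and every later iteration replays it with no dependency bookkeeping at all. The sketch below is a minimal standalone illustration of that trace-and-replay idea; the Op/Run/RunAndTrace names are made up for the example and are not Paddle's real handle classes.

#include <cstdio>
#include <vector>

struct Op {
  const char *name;
  int deps;                     // number of unfinished inputs
  std::vector<Op *> consumers;  // ops unblocked when this op finishes
};

void Run(Op *op) { std::printf("run %s\n", op->name); }

// First pass: dependency-driven scheduling that also records the order.
// It consumes the dep counters, which is fine because later passes never
// look at them again.
void RunAndTrace(std::vector<Op *> ready, std::vector<Op *> *traced) {
  while (!ready.empty()) {
    Op *op = ready.back();
    ready.pop_back();
    Run(op);
    traced->push_back(op);
    for (Op *next : op->consumers) {
      if (--next->deps == 0) ready.push_back(next);
    }
  }
}

// Later passes: replay the recorded order with no bookkeeping, which is
// what RunTracedOps does in the patch when num_threads_ == 1.
void RunTraced(const std::vector<Op *> &traced) {
  for (Op *op : traced) Run(op);
}

int main() {
  Op a{"a", 0, {}}, b{"b", 1, {}}, c{"c", 2, {}};
  a.consumers = {&b, &c};
  b.consumers = {&c};
  std::vector<Op *> traced;
  RunAndTrace({&a}, &traced);  // iteration 1: resolve deps, record order
  RunTraced(traced);           // iteration 2+: run a, b, c directly
}

Note that RecordOps above deliberately filters out FetchOpHandle instances: fetch ops are created and cleared on every Run call, so they are replayed from the freshly built fetch_ops list rather than from the trace.
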
diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h
index 234da5b9254bcdfb4682301c679be67f99cda280..d88e5bbaa97419c6e5229deaa16fbcfa922432d0 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h
@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ::ThreadPool pool_;
   ::ThreadPool prepare_pool_;

+  std::vector<OpHandleBase *> traced_ops_;
+
   bool RunOp(OpHandleBase *op,
              const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
              size_t *complete);
@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
       const std::shared_ptr<BlockingQueue<size_t>> &complete_q);

   void PrepareAtomicOpDeps();
+
+  inline void RecordOps(OpHandleBase *op);
+
+  inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
+
+  inline void RunOpSync(OpHandleBase *op);
+
+  void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
+
+  void InsertFetchOps(
+      const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
+      std::unordered_map<std::string, std::vector<VarHandleBase *>>
+          *fetched_vars,
+      std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
+      std::vector<OpHandleBase *> *fetch_ops,
+      std::vector<OpHandleBase *> *ready_fetch_ops);
 };
 }  // namespace details
 }  // namespace framework
diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc
index af2cbd5c876fdd7c27cd679f7e9412d1b0604ecc..4f1e44ca26cb65468da6eded74653f34dbf00336 100644
--- a/paddle/fluid/framework/details/ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/ssa_graph_executor.cc
@@ -19,10 +19,13 @@ namespace framework {
 namespace details {
 SSAGraphExecutor::~SSAGraphExecutor() {}

-void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
+void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
   if (fetch_ops->empty()) return;

   for (auto& op : *fetch_ops) {
+    PADDLE_ENFORCE_NOT_NULL(
+        dynamic_cast<FetchOpHandle*>(op),
+        "The input ops of ClearFetchOp function should be FetchOpHandle.");
     for (auto& out_var : op->Node()->outputs) {
       graph->RemoveNode(out_var);
     }
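
The change to ClearFetchOp widens the element type from FetchOpHandle* to OpHandleBase* so both executors can pass their fetch-op lists through one signature, while the new PADDLE_ENFORCE_NOT_NULL keeps the old guarantee that only fetch ops are ever cleared. A minimal sketch of that widen-and-check pattern, with toy types and assert standing in for PADDLE_ENFORCE_NOT_NULL:

#include <cassert>
#include <vector>

struct OpHandleBase {
  virtual ~OpHandleBase() = default;
};
struct FetchOpHandle : OpHandleBase {};

// The vector now holds base-class pointers so callers with heterogeneous
// op lists can reuse this function; the dynamic_cast preserves the old
// invariant that every element is really a FetchOpHandle.
void ClearFetchOps(std::vector<OpHandleBase *> *fetch_ops) {
  for (OpHandleBase *op : *fetch_ops) {
    assert(dynamic_cast<FetchOpHandle *>(op) != nullptr);
    delete op;
  }
  fetch_ops->clear();
}

int main() {
  std::vector<OpHandleBase *> ops{new FetchOpHandle, new FetchOpHandle};
  ClearFetchOps(&ops);  // passes the runtime type check
}
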
diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h
index 860eaa25b58e4579ad792ff18618de3b90707e8d..2454ec2b27d9d2060f28b8d6cea0ce49fe347433 100644
--- a/paddle/fluid/framework/details/ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/ssa_graph_executor.h
@@ -38,7 +38,7 @@ class SSAGraphExecutor {
   virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
 };

-void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
+void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops);
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 67246a4dd448b0ce2f115d6438c5fdd6cc39ca6d..ac62f1dd83397a15830eae02c0ba00920a90dcfd 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -53,74 +53,84 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
   std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
   CopyOpDeps();
+  VLOG(10) << "ThreadedSSAGraphExecutor::Run";
   std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
       new BlockingQueue<VarHandleBase *>);
   auto &pending_ops = op_deps->pending_ops_;
   auto &pending_vars = op_deps->pending_vars_;
   auto &ready_ops = op_deps->ready_ops_;
-
-  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
-  // streams from multiple GPUs, it's faster to buffer them and schedule
-  // together since we currently cannot overlap computation and memcpy streams.
-  // Should revisit it if overlapping is available.
-  std::unordered_set<OpHandleBase *> delayed_ops;
+  size_t num_ops = op_deps->num_ops_;

   // Step 2. Insert FetchOps
-  std::vector<FetchOpHandle *> fetch_ops;
+  std::vector<OpHandleBase *> fetch_ops;
   std::unordered_set<VarHandleBase *> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());

   InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
                  &pending_ops, &pending_vars, &fetch_data);

-  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
-    for (auto *op : set) {
-      RunOp(ready_vars, op);
-    }
-    set.clear();
-  };
-
-  // Clean run context
-  run_op_futures_.clear();
   exception_holder_.Clear();
   event.reset(nullptr);
+
   // Step 3. Execution
-  while (!pending_vars.empty()) {
-    // 1. Run All Ready ops
-    // Keep loop until all vars are ready.
-    run_all_ops(ready_ops);
-
-    // 2. Find ready variable
-    bool timeout;
-    auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
-    if (timeout) {
-      if (exception_holder_.IsCaught()) {
-        VLOG(3) << "caught exception " << exception_holder_.Type()
-                << ", rethrow it";
+  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
+    // If the num_threads is 1, we can record the order of operator's
+    // execution in the first iteration, and in subsequent iterations,
+    // run the recorded operators directly. This strategy could make the
+    // execution faster.
+    VLOG(3) << "Run the traced ops.";
+    RunTracedOps(traced_ops_);
+    RunTracedOps(fetch_ops);
+    if (exception_holder_.IsCaught()) {
+      ExecutionFinal(&fetch_ops);
+    }
+  } else {
+    traced_ops_.clear();
+    auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+      for (auto *op : set) {
+        RunOp(ready_vars, op);
+      }
+      set.clear();
+    };
+    // Clean run context
+    run_op_futures_.clear();
+
+    while (!pending_vars.empty()) {
+      // 1. Run All Ready ops
+      // Keep loop until all vars are ready.
+      run_all_ops(ready_ops);
+
+      // 2. Find ready variable
+      bool timeout;
+      auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
+      if (timeout) {
        for (auto &run_op_future : run_op_futures_) {
          run_op_future.wait();
        }
-        ClearFetchOp(graph_, &fetch_ops);
-        exception_holder_.ReThrow();
-      } else {
-        continue;
+        if (exception_holder_.IsCaught()) {
+          ExecutionFinal(&fetch_ops);
+        } else {
+          continue;
+        }
      }
-    }
-    // 3. Remove the dependency of ready_var.
-    // Find the ready_ops after the ready_var.
-    for (auto ready_var : cur_ready_vars) {
-      pending_vars.erase(ready_var);
-      for (auto *op : ready_var->PendingOps()) {
-        auto &deps = pending_ops[op];
-        --deps;
-        if (deps == 0) {
-          ready_ops.insert(op);
+
+      // 3. Remove the dependency of ready_var.
+      // Find the ready_ops after the ready_var.
+      for (auto ready_var : cur_ready_vars) {
+        pending_vars.erase(ready_var);
+        for (auto *op : ready_var->PendingOps()) {
+          auto &deps = pending_ops[op];
+          --deps;
+          if (deps == 0) {
+            ready_ops.insert(op);
+          }
        }
      }
    }
+    PADDLE_ENFORCE(ready_ops.empty());
  }
-  PADDLE_ENFORCE(ready_ops.empty());
+
   // Wait FetchOps.
   ClearFetchOp(graph_, &fetch_ops);
@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(

 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
-    std::vector<FetchOpHandle *> *fetch_ops,
+    std::vector<OpHandleBase *> *fetch_ops,
     std::unordered_set<VarHandleBase *> *fetch_dependencies,
     std::unordered_set<OpHandleBase *> *ready_ops,
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,
@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
       InsertPendingOp(&pending_ops, op);
     }
   }
+  op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
+  PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
+
   for (auto ready_var : ready_vars) {
     pending_vars.erase(ready_var);
     for (auto *op : ready_var->PendingOps()) {
@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
                                    op_deps_->pending_vars_.end());
     op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
                                op_deps_->ready_ops_.end());
+    op_deps->num_ops_ = op_deps_->num_ops_;
     return std::unique_ptr<OpDependentData>(op_deps);
   });
 }
@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
     const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
     details::OpHandleBase *op) {
   auto op_run = [ready_var_q, op, this] {
+    RunOpSync(op);
     try {
-      if (VLOG_IS_ON(10)) {
-        VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
-      }
-      if (LIKELY(!strategy_.dry_run_)) {
-        op->Run(strategy_.use_cuda_);
-      }
-      VLOG(10) << op << " " << op->Name() << " Done ";
       ready_var_q->Extend(op->Outputs());
       VLOG(10) << op << " " << op->Name() << " Signal posted";
     } catch (...) {
       exception_holder_.Catch(std::current_exception());
     }
   };
+
   if (pool_) {
     run_op_futures_.emplace_back(pool_->enqueue(op_run));
   } else {
     op_run();
   }
+
+  RecordOps(op);
+}
+
+void ThreadedSSAGraphExecutor::RunTracedOps(
+    const std::vector<OpHandleBase *> &traced_ops) {
+  for (auto &op : traced_ops) {
+    if (exception_holder_.IsCaught()) {
+      return;
+    }
+    RunOpSync(op);
+  }
+}
+
+void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
+  try {
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
+    }
+    if (LIKELY(!strategy_.dry_run_)) {
+      op->Run(strategy_.use_cuda_);
+    }
+    VLOG(10) << op << " " << op->Name() << " Done ";
+  } catch (...) {
+    exception_holder_.Catch(std::current_exception());
+  }
+}
+
+void ThreadedSSAGraphExecutor::ExecutionFinal(
+    std::vector<OpHandleBase *> *fetch_ops) {
+  VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
+  ClearFetchOp(graph_, fetch_ops);
+  exception_holder_.ReThrow();
+}
+
+void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
+  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
+    traced_ops_.emplace_back(op);
+  }
 }
 }  // namespace details
 }  // namespace framework
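
The RunOp/RunOpSync split above keeps the executor's exception contract: an op that throws is captured into exception_holder_ on whichever thread ran it, and ExecutionFinal rethrows exactly once, after ClearFetchOp, on the calling thread. Below is a self-contained sketch of that holder pattern built on std::exception_ptr; it is simplified, and the real ExceptionHolder also reports a Type() for logging.

#include <exception>
#include <iostream>
#include <stdexcept>

// Worker-side code catches into an std::exception_ptr; the driver checks
// IsCaught(), performs cleanup, and rethrows once on its own thread.
class ExceptionHolder {
 public:
  void Catch(std::exception_ptr e) { e_ = e; }
  bool IsCaught() const { return e_ != nullptr; }
  void ReThrow() { std::rethrow_exception(e_); }

 private:
  std::exception_ptr e_;
};

int main() {
  ExceptionHolder holder;
  try {
    throw std::runtime_error("op failed");  // an op throwing inside Run()
  } catch (...) {
    holder.Catch(std::current_exception());  // what RunOpSync does
  }
  if (holder.IsCaught()) {
    // ClearFetchOp-style cleanup would happen here, then the rethrow.
    try {
      holder.ReThrow();
    } catch (const std::exception &ex) {
      std::cout << "rethrown: " << ex.what() << "\n";
    }
  }
}
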
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index 8c026057b480fbc40b7b8f12d8e6b8e54195a141..6c1fb1c6c0a7b55cee89986c00bf650542520355 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -44,6 +44,7 @@ struct OpDependentData {
   std::unordered_map<OpHandleBase *, size_t> pending_ops_;
   std::unordered_set<VarHandleBase *> pending_vars_;
   std::unordered_set<OpHandleBase *> ready_ops_;
+  size_t num_ops_{0};
 };

 class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::list<std::future<void>> run_op_futures_;
   ::ThreadPool prepare_pool_;
   std::unique_ptr<::ThreadPool> pool_;
+  std::vector<OpHandleBase *> traced_ops_;

   void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
                        OpHandleBase *op_instance) const;
@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                         VarHandleBase *var) const;

   void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
-                      std::vector<FetchOpHandle *> *fetch_ops,
+                      std::vector<OpHandleBase *> *fetch_ops,
                       std::unordered_set<VarHandleBase *> *fetch_dependencies,
                       std::unordered_set<OpHandleBase *> *ready_ops,
                       std::unordered_map<OpHandleBase *, size_t> *pending_ops,
@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                       FeedFetchList *fetch_data);

   void PrepareOpDeps();
+
   void CopyOpDeps();
+
+  inline void RecordOps(OpHandleBase *op);
+
+  inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
+
+  inline void RunOpSync(OpHandleBase *op);
+
+  void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
 };

 }  // namespace details
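
The num_ops_ field added to OpDependentData is what makes the replay guard safe: traced_ops_.size() == num_ops can only hold when the previous iteration traced every operator in the graph, so a run that died with an exception (or a freshly built graph) falls back to normal scheduling. A hypothetical distillation of that check, with illustrative names:

#include <cstddef>
#include <cstdio>
#include <vector>

struct OpHandleBase;  // opaque here; only pointer identity matters

// Replay is only trusted when the previous pass traced the whole graph;
// an incomplete trace means the executor must schedule by dependencies.
bool CanReplay(const std::vector<OpHandleBase *> &traced_ops,
               std::size_t num_ops, int num_threads) {
  return num_threads == 1 && traced_ops.size() == num_ops;
}

int main() {
  std::vector<OpHandleBase *> traced(10, nullptr);
  std::printf("%d\n", CanReplay(traced, 10, 1));  // 1: replay the trace
  std::printf("%d\n", CanReplay(traced, 12, 1));  // 0: incomplete trace
  std::printf("%d\n", CanReplay(traced, 10, 4));  // 0: multi-threaded run
}
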
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
index 645b0188d5f45935ace074ba343de246af476b41..0457e9cefdb391eb3bdb713f8a35bed769b9bce8 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
@@ -45,7 +45,8 @@ class TestFetchAndFeed(unittest.TestCase):
     def parallel_exe(self,
                      use_cuda,
                      run_parallel_exe,
-                     use_experimental_executor=False,
+                     use_faster_executor=False,
+                     num_threads=4,
                      seed=1):
         main_program = fluid.Program()
         startup = fluid.Program()
@@ -72,7 +73,8 @@ class TestFetchAndFeed(unittest.TestCase):
             build_strategy.enable_inplace = False
             build_strategy.memory_optimize = False
             exec_strategy = fluid.ExecutionStrategy()
-            exec_strategy.use_experimental_executor = use_experimental_executor
+            exec_strategy.use_experimental_executor = use_faster_executor
+            exec_strategy.num_threads = num_threads
             train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
                 loss_name=loss.name,
                 build_strategy=build_strategy,
@@ -143,24 +145,25 @@ class TestFetchAndFeed(unittest.TestCase):
                 if batch_id == 2:
                     break

-    def test_fetch_with_threaded_executor(self):
-        if core.is_compiled_with_cuda():
-            self.parallel_exe(
-                use_cuda=True,
-                run_parallel_exe=self.run_parallel_exe_with_fetch)
-        self.parallel_exe(
-            use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
-
-    def test_fetch_with_fast_threaded_executor(self):
+    def check_executor(self, use_faster_executor=False, num_threads=4):
         if core.is_compiled_with_cuda():
             self.parallel_exe(
                 use_cuda=True,
                 run_parallel_exe=self.run_parallel_exe_with_fetch,
-                use_experimental_executor=True)
+                use_faster_executor=use_faster_executor,
+                num_threads=num_threads)
         self.parallel_exe(
             use_cuda=False,
             run_parallel_exe=self.run_parallel_exe_with_fetch,
-            use_experimental_executor=True)
+            use_faster_executor=use_faster_executor,
+            num_threads=num_threads)
+
+    def test_fetch(self):
+        for use_faster_executor in {True, False}:
+            self.check_executor(
+                use_faster_executor=use_faster_executor, num_threads=4)
+            self.check_executor(
+                use_faster_executor=use_faster_executor, num_threads=1)

     def test_feed(self):
         if core.is_compiled_with_cuda():