From d0342f12aabd51979e86249f5450fb099393c449 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 16 Apr 2018 11:33:15 +0800 Subject: [PATCH] Simplify DelayOps Logic --- .../details/threaded_ssa_graph_executor.cc | 43 +++++++------------ .../tests/unittests/test_parallel_executor.py | 5 ++- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index a371ee10fe0..927078bc64d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -51,8 +51,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // together since we currently cannot overlap computation and memcpy streams. // Should revisit it if overlapping is available. std::unordered_set delayed_ops; - std::unordered_set blocked_by_delayed_ops; - std::unordered_set delayed_vars; auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { pending_vars.insert(&var); @@ -122,24 +120,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( InsertPendingOp(*op); } - auto run_all_ready_ops = [&] { - for (auto *op : ready_ops) { - if (op->IsMultiDeviceTransfer() && allow_op_delay_) { - delayed_ops.insert(op); - delayed_vars.insert(op->outputs_.begin(), op->outputs_.end()); - ready_vars.Extend(op->outputs_); - continue; - } + auto run_all_ops = [&](std::unordered_set &set) { + for (auto *op : set) { running_ops_++; RunOp(&ready_vars, op); } - ready_ops.clear(); + set.clear(); }; // Step 3. Execution - while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) { + while (!pending_vars.empty()) { // 1. Run All Ready ops - run_all_ready_ops(); + // Keep loop until all vars are ready. + // + // NOTE: DelayedOps have a lower priority. It will be scheduled after all + // ready_ops have been performed. + if (ready_ops.empty() && allow_op_delay_) { + run_all_ops(delayed_ops); + } else { + run_all_ops(ready_ops); + } // 2. Find ready variable bool timeout; @@ -160,29 +160,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto &deps = pending_ops[op]; --deps; if (deps == 0) { - if (delayed_vars.find(ready_var) != delayed_vars.end()) { - blocked_by_delayed_ops.insert(op); + if (op->IsMultiDeviceTransfer() && allow_op_delay_) { + delayed_ops.insert(op); } else { ready_ops.insert(op); } } } } - // When there are no other ops to schedule, schedule buffered delayed - // ops and unblock other ops. - if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) { - RunDelayedOps(delayed_ops); - delayed_ops.clear(); - for (auto *op : blocked_by_delayed_ops) { - ready_ops.insert(op); - } - blocked_by_delayed_ops.clear(); - } - // Keep loop until all vars are ready. } PADDLE_ENFORCE(ready_ops.empty()); - PADDLE_ENFORCE(delayed_ops.empty()); - PADDLE_ENFORCE(blocked_by_delayed_ops.empty()); // Wait FetchOps. if (!fetch_ops.empty()) { diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 83d22fd799e..0cd88d61e1d 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -206,18 +206,19 @@ class TestParallelExecutorBase(unittest.TestCase): feed_dict={}): main = fluid.Program() startup = fluid.Program() + startup.random_seed = 1 # Fix random seed with fluid.program_guard(main, startup): loss = method(use_feed=len(feed_dict) > 0) adam = fluid.optimizer.Adam() adam.minimize(loss) if memory_opt: fluid.memory_optimize(main) - place = fluid.CUDAPlace(0) startup_exe = fluid.Executor(place) startup_exe.run(startup) - exe = fluid.ParallelExecutor(True, loss_name=loss.name) + exe = fluid.ParallelExecutor( + True, loss_name=loss.name, allow_op_delay=allow_op_delay) if batch_size is not None: batch_size *= fluid.core.get_cuda_device_count() begin = time.time() -- GitLab