From b123ce88a17ac18dd24ec396d18c1eac7c832442 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 2 Apr 2018 01:10:00 -0700 Subject: [PATCH] Add enable/disable for delayed ops --- .../details/threaded_ssa_graph_executor.cc | 12 ++++++++---- .../details/threaded_ssa_graph_executor.h | 4 +++- paddle/fluid/framework/parallel_executor.cc | 6 +++--- paddle/fluid/framework/parallel_executor.h | 6 ++++-- paddle/fluid/pybind/pybind.cc | 4 ++-- python/paddle/fluid/parallel_executor.py | 9 +++++++-- .../tests/unittests/test_parallel_executor.py | 16 ++++++++++++++-- 7 files changed, 41 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 65fbfb65e1..1f96b9dc62 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -23,14 +23,15 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( size_t num_threads, bool use_event, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph) + std::unique_ptr &&graph, bool allow_op_delay) : SSAGraphExecutor(std::move(graph)), pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr), local_scopes_(local_scopes), places_(places), fetch_ctxs_(places), use_event_(use_event), - running_ops_(0) {} + running_ops_(0), + allow_op_delay_(allow_op_delay) {} void ThreadedSSAGraphExecutor::RunDelayedOps( const std::unordered_set &delayed_ops) { @@ -119,7 +120,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto run_all_ready_ops = [&] { for (auto *op : ready_ops) { - if (op->IsMultiDeviceTransfer()) { + if (op->IsMultiDeviceTransfer() && allow_op_delay_) { delayed_ops.insert(op); delayed_vars.insert(op->outputs_.begin(), op->outputs_.end()); ready_vars.Extend(op->outputs_); @@ -138,7 +139,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } // Step 3. Execution - while (!pending_vars.empty()) { + while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) { // 1. Run All Ready ops run_all_ready_ops(); @@ -181,6 +182,9 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } // Keep loop until all vars are ready. } + PADDLE_ENFORCE(ready_ops.empty()); + PADDLE_ENFORCE(delayed_ops.empty()); + PADDLE_ENFORCE(blocked_by_delayed_ops.empty()); ++computation_count_; auto sync_computation = [&] { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 6193b897e4..79cfc26b46 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -75,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ThreadedSSAGraphExecutor(size_t num_threads, bool use_event, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph); + std::unique_ptr &&graph, + bool allow_op_delay); // Run a SSAGraph by a thread pool // Use topological sort algorithm @@ -97,6 +98,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { const bool use_event_; std::unique_ptr exception_; std::atomic running_ops_; + bool allow_op_delay_; size_t computation_count_{0}; size_t max_async_computation{100}; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 002a6d362f..1788514324 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -48,7 +48,7 @@ ParallelExecutor::ParallelExecutor( const std::vector &places, const std::unordered_set ¶ms, const ProgramDesc &startup_program, const ProgramDesc &main_program, - const std::string &loss_var_name, Scope *scope) + const std::string &loss_var_name, Scope *scope, bool allow_op_delay) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; @@ -83,8 +83,8 @@ ParallelExecutor::ParallelExecutor( auto graph = builder.Build(main_program); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - num_threads, use_event, member_->local_scopes_, places, - std::move(graph))); + num_threads, use_event, member_->local_scopes_, places, std::move(graph), + allow_op_delay)); // Step 3. Create vars in each scope; for (auto *scope : member_->local_scopes_) { diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 503efa2e44..964b476234 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,8 +14,9 @@ limitations under the License. */ #pragma once -#include +#include #include +#include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" @@ -37,7 +38,8 @@ class ParallelExecutor { const std::unordered_set& params, const ProgramDesc& startup_program, const ProgramDesc& main_program, - const std::string& loss_var_name, Scope* scope); + const std::string& loss_var_name, Scope* scope, + bool allow_op_delay); void Run(const std::vector& fetch_tensors, const std::string& fetched_var_name = "fetched_var"); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index e1b1bbec97..b0a3f06a88 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle. const std::unordered_set ¶ms, const ProgramDesc &startup_program, const ProgramDesc &main_program, const std::string &loss_var_name, - Scope *scope) { + Scope *scope, bool allow_op_delay) { new (&self) ParallelExecutor(num_threads, use_event, places, params, startup_program, main_program, - loss_var_name, scope); + loss_var_name, scope, allow_op_delay); }) .def("run", &ParallelExecutor::Run); diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index fec7d6899c..a2c830b3c9 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -21,7 +21,11 @@ __all__ = ['ParallelExecutor'] class ParallelExecutor(object): - def __init__(self, loss_name, use_cuda, num_threads=None): + def __init__(self, + loss_name, + use_cuda, + num_threads=None, + allow_op_delay=False): places = [] if use_cuda: for i in xrange(core.get_cuda_device_count()): @@ -57,7 +61,8 @@ class ParallelExecutor(object): startup.desc, main.desc, loss_name, - scope) + scope, + allow_op_delay) self.scope = scope def run(self, fetch_list): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index 95d0f9da47..60130298af 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -184,7 +184,8 @@ class TestParallelExecutorBase(unittest.TestCase): method, memory_opt=True, iter=10, - batch_size=None): + batch_size=None, + allow_op_delay=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -194,7 +195,10 @@ class TestParallelExecutorBase(unittest.TestCase): if memory_opt: fluid.memory_optimize(main) - exe = fluid.ParallelExecutor(loss_name=loss.name, use_cuda=True) + exe = fluid.ParallelExecutor( + loss_name=loss.name, + use_cuda=True, + allow_op_delay=allow_op_delay) if batch_size is not None: batch_size *= fluid.core.get_cuda_device_count() begin = time.time() @@ -236,9 +240,11 @@ class TestMNIST(TestParallelExecutorBase): def test_simple_fc(self): self.check_network_convergence(simple_fc_net) + self.check_network_convergence(simple_fc_net, allow_op_delay=True) def test_batchnorm_fc(self): self.check_network_convergence(fc_with_batchnorm) + self.check_network_convergence(fc_with_batchnorm, allow_op_delay=True) class TestResnet(TestParallelExecutorBase): @@ -268,6 +274,12 @@ class TestResnet(TestParallelExecutorBase): SE_ResNeXt152, batch_size=batch_size), iter=20, batch_size=batch_size) + self.check_network_convergence( + functools.partial( + SE_ResNeXt152, batch_size=batch_size), + iter=20, + batch_size=batch_size, + allow_op_delay=True) class ModelHyperParams(object): -- GitLab