diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index fac9f16a89bab311c338475aef7c79015ab466be..8e69f8f0b98b3d07bfd57bf3251e32a3b3c607f7 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -100,7 +100,11 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS + threaded_ssa_graph_executor scope_buffered_ssa_graph_executor + graph graph_viz_pass multi_devices_graph_pass + multi_devices_graph_print_pass multi_devices_graph_check_pass + fast_threaded_ssa_graph_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 8f6c4163d6ee11fbe83f603f6148c2ac6175324d..abd5459f6d47da6d1341284916b419325dc5977c 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -42,3 +42,5 @@ cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_b cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor) #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory # device_context reduce_op_handle ) +cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc + DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/execution_strategy.h b/paddle/fluid/framework/details/execution_strategy.h index 716d674fa29bad9321fc20979775c06f26bf4679..5183be878eb49cccc68603c3fdd8023be5578036 100644 --- a/paddle/fluid/framework/details/execution_strategy.h +++ b/paddle/fluid/framework/details/execution_strategy.h @@ -19,10 +19,13 @@ namespace framework { namespace details { struct ExecutionStrategy { + enum ExecutorType { kDefault = 0, kExperimental = 1 }; + size_t num_threads_{0}; bool use_cuda_{true}; bool allow_op_delay_{false}; size_t num_iteration_per_drop_scope_{100}; + ExecutorType type_{kDefault}; }; } // namespace details diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..7606f2bc06b2ecf07c5649eeae1a2d5587a8880c --- /dev/null +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" +#include +#include +#include "paddle/fluid/framework/details/fetch_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( + const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &places, + std::unique_ptr &&graph) + : strategy_(strategy), + local_scopes_(local_scopes), + places_(places), + graph_(std::move(graph)), + pool_(strategy.num_threads_ + + 1), // add one more thread for generate op_deps + fetch_ctxs_(places) { + auto &ops = graph_->Get("ops"); + + for (auto &op : ops) { + int dep = static_cast(op->NotReadyInputSize()); + op_deps_.emplace(op.get(), dep); + if (dep == 0) { + bootstrap_ops_.emplace_back(op.get()); + } + } + + PrepareAtomicOpDeps(); +} + +FeedFetchList FastThreadedSSAGraphExecutor::Run( + const std::vector &fetch_tensors) { + std::unique_ptr>> + op_deps = atomic_op_deps_.get(); + PrepareAtomicOpDeps(); + + paddle::framework::FeedFetchList fetches; + fetches.resize(fetch_tensors.size()); + std::unordered_map> fetched_vars; + std::vector> fetch_nodes; + std::vector> fetch_ops; + + for (auto &fetch_var_name : fetch_tensors) { + for (auto &var_map : graph_->Get("vars")) { + auto it = var_map.find(fetch_var_name); + if (it != var_map.end()) { + fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); + } + } + } + + for (size_t i = 0; i < fetch_tensors.size(); ++i) { + auto &var_name = fetch_tensors[i]; + auto fetched_var_it = fetched_vars.find(var_name); + PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), + "Cannot find fetched variable.(Perhaps the main_program " + "is not set to ParallelExecutor)"); + + auto &vars = fetched_var_it->second; + + fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); + auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i, + &local_scopes_); + fetch_ops.emplace_back(op); + + for (auto &p : places_) { + op->SetDeviceContext(p, fetch_ctxs_.Get(p)); + } + + for (auto *var : vars) { + op->AddInput(var); + } + + (*op_deps)[op] = static_cast(op->NotReadyInputSize()); + } + + size_t num_complete = 0; + remaining_ = 0; + BlockingQueue complete_q; + for (auto op : bootstrap_ops_) { + RunOpAsync(op_deps.get(), op, &complete_q); + } + + while (num_complete != op_deps->size()) { + size_t num_comp = complete_q.Pop(); + if (num_comp == -1UL) { + int remaining = 0; + while (true) { + remaining = remaining_; + if (remaining == 0) { + break; + } + for (int i = 0; i < remaining; ++i) { + complete_q.Pop(); + } + } + exception_.ReThrow(); + } + num_complete += num_comp; + } + // Wait FetchOps. + if (!fetch_ops.empty()) { + fetch_ops.clear(); + } + return fetches; +} +void FastThreadedSSAGraphExecutor::RunOpAsync( + std::unordered_map> *op_deps, + OpHandleBase *op, BlockingQueue *complete_q) { + ++remaining_; + this->pool_.enqueue([=] { + OpHandleBase *op_to_run = op; + size_t complete = 0; + while (op_to_run != nullptr) { + try { + op_to_run->Run(strategy_.use_cuda_); + ++complete; + } catch (...) { + exception_.Catch(std::current_exception()); + --remaining_; + complete_q->Push(-1UL); + return; + } + auto &outputs = op_to_run->Outputs(); + op_to_run = nullptr; + for (auto &output : outputs) { + for (auto &pending_op : output->PendingOps()) { + std::atomic &deps = op_deps->at(pending_op); + if (deps.fetch_sub(1) == 1) { // pending_op ready + if (op_to_run == nullptr) { + op_to_run = pending_op; + } else { + this->RunOpAsync(op_deps, pending_op, complete_q); + } + } + } + } + } + --remaining_; + complete_q->Push(complete); + }); +} +void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { + atomic_op_deps_ = pool_.enqueue([&] { + std::unordered_map> *op_deps = + new std::unordered_map>; + for (auto &pair : op_deps_) { + (*op_deps)[pair.first] = pair.second; + } + return std::unique_ptr< + std::unordered_map>>(op_deps); + }); +} + +const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..dad3a231cba6402f57ba654a9ac5fb520b9c8f04 --- /dev/null +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -0,0 +1,64 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/details/exception_holder.h" +#include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/ssa_graph_executor.h" + +namespace paddle { +namespace framework { +class Scope; +namespace details { + +class OpHandleBase; +class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { + public: + FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, + const std::vector &local_scopes, + const std::vector &places, + std::unique_ptr &&graph); + FeedFetchList Run(const std::vector &fetch_tensors) override; + const ir::Graph &Graph() const override; + + private: + ExecutionStrategy strategy_; + std::vector local_scopes_; + std::vector places_; + std::unique_ptr graph_; + + std::unordered_map op_deps_; + std::vector bootstrap_ops_; + + ::ThreadPool pool_; + platform::DeviceContextPool fetch_ctxs_; + std::atomic remaining_; + + void RunOpAsync(std::unordered_map> *op_deps, + OpHandleBase *op, BlockingQueue *complete_q); + + void PrepareAtomicOpDeps(); + + std::future< + std::unique_ptr>>> + atomic_op_deps_; + ExceptionHolder exception_; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index ee9f9184da65467b82794c99fe3e95b108373753..3812f0abf1b7069525c4420054c61c01c908acfe 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -158,6 +158,16 @@ void OpHandleBase::RunAndRecordEvent(platform::Place p, #endif } +size_t OpHandleBase::NotReadyInputSize() const { + std::unordered_set res; + for (auto *var : inputs_) { + if (var->GeneratedOp() != nullptr) { + res.emplace(var); + } + } + return res.size(); +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 2d7f18942890245249dd0619a40bb43833c9a2ee..9fbefabc841e3f6940860f60d959fee97495e4c9 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -81,6 +81,8 @@ class OpHandleBase { return res.size(); } + size_t NotReadyInputSize() const; + const std::vector &Outputs() const { return outputs_; } size_t NoDummyInputSize() const; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 275cb8c592c3c0b153d31149570cd6596b9e1a7f..81cb24bdda6b87a3d708cf5047dce05d5020a0d5 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif +#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" @@ -193,8 +194,14 @@ ParallelExecutor::ParallelExecutor( member_->local_scopes_, member_->use_cuda_, build_strategy); #endif - member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, std::move(graph))); + if (exec_strategy.type_ == ExecutionStrategy::kDefault) { + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( + exec_strategy, member_->local_scopes_, places, std::move(graph))); + } else { + member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( + exec_strategy, member_->local_scopes_, places, std::move(graph))); + } + member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 6c58478b0dd0941ab4bf4d573a3c813059650ba8..67734659233515ca8110f4212a2b1553fe4e9d24 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -596,8 +596,8 @@ All parameter, weight, gradient are variables in Paddle. // -- python binds for parallel executor. py::class_ pe(m, "ParallelExecutor"); - py::class_(pe, "ExecutionStrategy") - .def(py::init()) + py::class_ exec_strategy(pe, "ExecutionStrategy"); + exec_strategy.def(py::init()) .def_property( "num_threads", [](const ExecutionStrategy &self) { return self.num_threads_; }, @@ -624,6 +624,16 @@ All parameter, weight, gradient are variables in Paddle. [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) { self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope; }); + exec_strategy.def_property( + "use_experimental_executor", + [](const ExecutionStrategy &self) { + return self.type_ == ExecutionStrategy::kExperimental; + }, + [](ExecutionStrategy &self, bool experimental) { + self.type_ = experimental ? ExecutionStrategy::kExperimental + : ExecutionStrategy::kDefault; + }); + py::class_ build_strategy(pe, "BuildStrategy"); py::enum_(build_strategy, "ReduceStrategy") diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 9be53c4609d5027dd4e482ab936a135420a346ae..74e9d5c5f91e53a315c85d428571ce45bacede8a 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -38,7 +38,8 @@ class TestParallelExecutorBase(unittest.TestCase): seed=None, use_parallel_executor=True, use_reduce=False, - optimizer=fluid.optimizer.Adam): + optimizer=fluid.optimizer.Adam, + use_fast_executor=False): def run_executor(exe, feed, fetch_list, program=None): if isinstance(exe, fluid.ParallelExecutor): res = exe.run(fetch_list=fetch_list, feed=feed) @@ -71,6 +72,8 @@ class TestParallelExecutorBase(unittest.TestCase): startup_exe.run(startup) exec_strategy = fluid.ExecutionStrategy() exec_strategy.allow_op_delay = allow_op_delay + if use_fast_executor: + exec_strategy.use_experimental_executor = True build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py index 893acd763fb21fc5040b8a24700eb947a9fe37c6..5b96d641d667eee1aa0c7c6019bf92494f777259 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py @@ -183,7 +183,9 @@ class TestMNIST(TestParallelExecutorBase): use_parallel_executor=True) self.assertAlmostEquals( - np.mean(parallel_first_loss), single_first_loss, delta=1e-6) + np.mean(parallel_first_loss), + single_first_loss, + delta=1e-6, ) self.assertAlmostEquals( np.mean(parallel_last_loss), single_last_loss, delta=1e-6) @@ -191,7 +193,7 @@ class TestMNIST(TestParallelExecutorBase): self.check_simple_fc_parallel_accuracy(True) self.check_simple_fc_parallel_accuracy(False) - def check_batchnorm_fc_convergence(self, use_cuda): + def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor): if use_cuda and not core.is_compiled_with_cuda(): return @@ -203,11 +205,13 @@ class TestMNIST(TestParallelExecutorBase): fc_with_batchnorm, feed_dict={"image": img, "label": label}, - use_cuda=use_cuda) + use_cuda=use_cuda, + use_fast_executor=use_fast_executor) def test_batchnorm_fc(self): - self.check_batchnorm_fc_convergence(True) - self.check_batchnorm_fc_convergence(False) + for use_cuda in (False, True): + for use_fast_executor in (False, True): + self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor) def test_batchnorm_fc_with_new_strategy(self): # FIXME(zcd): close this test temporally.