diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index d3f69ee9d84ac21013936f357fbfe7846853157a..c425c71160a8fa3830a5fbdae1baaed850710877 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -89,8 +89,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
            framework_proto backward glog lod_rank_table feed_fetch_method)
-cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-           backward glog lod_rank_table simple_threadpool multi_devices_graph_builder fetch_op_handle)
+cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor)
 
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 4432bc0245e9c48fb3726b2ea88420121eb87986..f13ac276fca01d15c2e9057f0577da00355d5787 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -10,3 +10,6 @@ cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
            nccl_all_reduce_op_handle scale_loss_grad_op_handle)
+cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph)
+cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
+           simple_threadpool device_context)
diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8da6ca889b89999e0f6f974503cea476c9de97f3
--- /dev/null
+++ b/paddle/fluid/framework/details/ssa_graph_executor.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/ssa_graph_executor.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
+    : graph_(std::move(graph)) {}
+
+SSAGraphExecutor::~SSAGraphExecutor() {}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..3b818b1a45b56351e34f9e52ec22b6d02a0c1591
--- /dev/null
+++ b/paddle/fluid/framework/details/ssa_graph_executor.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/fluid/framework/details/ssa_graph.h"
+#include "paddle/fluid/framework/feed_fetch_type.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class SSAGraphExecutor {
+  DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
+
+ public:
+  // Steal graph inside
+  explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
+
+  virtual ~SSAGraphExecutor();
+
+  virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
+
+ protected:
+  std::unique_ptr<SSAGraph> graph_;
+};
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..86e880ed72e5c2217d77e74769c4dab468f17e69
--- /dev/null
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -0,0 +1,192 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
+
+#include "paddle/fluid/framework/details/fetch_op_handle.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
+    size_t num_threads, bool use_event,
+    const std::vector<Scope *> &local_scopes,
+    const std::vector<platform::Place> &places,
+    std::unique_ptr<SSAGraph> &&graph)
+    : SSAGraphExecutor(std::move(graph)),
+      pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
+      local_scopes_(local_scopes),
+      places_(places),
+      fetch_ctxs_(places),
+      use_event_(use_event) {}
+
+FeedFetchList ThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  std::unordered_map<OpHandleBase *, size_t> pending_ops;
+  std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
+  std::unordered_set<OpHandleBase *> ready_ops;
+
+  auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
+    pending_vars[&var] = var.generated_op_ == nullptr;
+  };
+
+  auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
+    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
+  };
+
+  // Step 1. Transform the SSAGraph into pending_ops & pending_vars
+  for (auto &var_map : graph_->vars_) {
+    for (auto &name_pair : var_map) {
+      for (auto &version_pair : name_pair.second) {
+        InsertPendingVar(version_pair.second);
+      }
+    }
+  }
+  for (auto &var : graph_->dep_vars_) {
+    InsertPendingVar(*var);
+  }
+
+  for (auto &op : graph_->ops_) {
+    if (op->inputs_.empty()) {  // Special case: the op has no input.
+      ready_ops.insert(op.get());
+    } else {
+      InsertPendingOp(*op);
+    }
+  }
+
+  // Step 2. Insert FetchOps
+  std::vector<FetchOpHandle> fetch_ops;
+  std::vector<DummyVarHandle> dummy_vars;
+  FeedFetchList fetch_data(fetch_tensors.size());
+
+  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
+
+  for (auto &fetch_var_name : fetch_tensors) {
+    for (auto &var_map : graph_->vars_) {
+      auto it = var_map.find(fetch_var_name);
+      if (it != var_map.end()) {
+        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
+      }
+    }
+  }
+
+  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
+    auto &var_name = fetch_tensors[i];
+    auto &vars = fetched_vars[var_name];
+    fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
+    details::FetchOpHandle *op = &fetch_ops.back();
+
+    // FIXME: Use new device context
+    for (auto &p : places_) {
+      op->dev_ctx_[p] = fetch_ctxs_.Get(p);
+    }
+
+    for (auto *var : vars) {
+      op->AddInput(var);
+    }
+
+    dummy_vars.emplace_back();
+    auto *var = &dummy_vars.back();
+    var->generated_op_ = nullptr;
+    op->AddOutput(var);
+    InsertPendingVar(*var);
+    InsertPendingOp(*op);
+  }
+
+  auto run_all_ready_ops = [&] {
+    for (auto *op : ready_ops) {
+      RunOp(pending_vars, op);
+    }
+    ready_ops.clear();
+  };
+
+  // Step 3. Execution
+  while (!pending_vars.empty()) {
+    // 1. Run all ready ops.
+    run_all_ready_ops();
+
+    // 2. Find a ready variable.
+    VarHandleBase *ready_var = nullptr;
+    for (auto &pair : pending_vars) {
+      if (pair.second.load(std::memory_order_acquire)) {
+        ready_var = pair.first;
+        break;
+      }
+    }
+
+    // If no variable is ready yet
+    if (ready_var == nullptr) {
+      // FIXME: use a condition variable instead of busy waiting.
+      // If an exception occurred, rethrow it.
+      if (exception_) {
+        throw *exception_;
+      }
+      // Keep waiting for ready variables.
+      continue;
+    }
+
+    // 3. Remove the dependency on ready_var and
+    //    find the ops that become ready after ready_var.
+    pending_vars.erase(ready_var);
+    for (auto *op : ready_var->pending_ops_) {
+      auto &deps = pending_ops[op];
+      --deps;
+      if (deps == 0) {
+        ready_ops.insert(op);
+      }
+    }
+    // Loop until all vars are ready.
+  }
+
+  // Wait for FetchOps.
+  for (auto &fetch_op : fetch_ops) {
+    fetch_op.WaitAndMergeCPUTensors();
+  }
+
+  return fetch_data;
+}
+
+void ThreadedSSAGraphExecutor::RunOp(
+    std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
+    details::OpHandleBase *op) {
+  std::vector<std::atomic<bool> *> *ready_buffer =
+      new std::vector<std::atomic<bool> *>();
+  for (auto *var : op->outputs_) {
+    ready_buffer->emplace_back(&pending_vars[var]);
+  }
+
+  auto op_run = [ready_buffer, op, this] {
+    try {
+      VLOG(10) << op->DebugString();
+      op->Run(use_event_);
+      for (auto *ready : *ready_buffer) {
+        ready->store(true, std::memory_order_release);
+      }
+      delete ready_buffer;
+    } catch (platform::EnforceNotMet ex) {
+      exception_.reset(new platform::EnforceNotMet(ex));
+    } catch (...) {
+      LOG(FATAL) << "Unknown exception caught";
+    }
+  };
+  if (pool_) {
+    pool_->enqueue(op_run);
+  } else {
+    op_run();
+  }
+}
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..5b099c18c92a4a2bff288c8d560cc7ff15015d75
--- /dev/null
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "ThreadPool.h"  // ThreadPool from third party
+#include "paddle/fluid/framework/details/ssa_graph_executor.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace details {
+
+class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
+ public:
+  ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
+                           const std::vector<Scope *> &local_scopes,
+                           const std::vector<platform::Place> &places,
+                           std::unique_ptr<SSAGraph> &&graph);
+
+  // Run an SSAGraph on a thread pool,
+  // using a topological-sort algorithm.
+  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
+
+  ~ThreadedSSAGraphExecutor() {}
+
+ private:
+  void RunOp(
+      std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
+      details::OpHandleBase *op);
+
+ private:
+  std::unique_ptr<::ThreadPool> pool_;
+  std::vector<Scope *> local_scopes_;
+  std::vector<platform::Place> places_;
+  platform::DeviceContextPool fetch_ctxs_;
+  const bool use_event_;
+  std::unique_ptr<platform::EnforceNotMet> exception_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 88070a06a255733198ebada586cc84d78cafe626..78963fd5684e5083d03690c1a42d6cdd0ed3f6f2 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -13,221 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/parallel_executor.h"
+
 #include "ThreadPool.h"
-#include "lod_tensor.h"
-#include "op_registry.h"
-#include "paddle/fluid/framework/details/fetch_op_handle.h"
-#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include "paddle/fluid/framework/details/ssa_graph.h"
+
 #include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
+
 namespace paddle {
 namespace framework {
 
-using details::DummyVarHandle;
-using details::FetchOpHandle;
-using details::OpHandleBase;
-using details::SSAGraph;
-using details::VarHandleBase;
-
-class SSAGraphExecutor {
-  DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
-
- public:
-  // Steal graph inside
-  explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
-      : graph_(std::move(graph)) {}
-
-  virtual ~SSAGraphExecutor() {}
-
-  virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
-
- protected:
-  std::unique_ptr<SSAGraph> graph_;
-};
-
-class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
- public:
-  ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
-                           const std::vector<Scope *> &local_scopes,
-                           const std::vector<platform::Place> &places,
-                           std::unique_ptr<SSAGraph> &&graph)
-      : SSAGraphExecutor(std::move(graph)),
-        pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
-        local_scopes_(local_scopes),
-        places_(places),
-        fetch_ctxs_(places),
-        use_event_(use_event) {}
-
-  // Run an SSAGraph on a thread pool,
-  // using a topological-sort algorithm.
-  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override {
-    std::unordered_map<OpHandleBase *, size_t> pending_ops;
-    std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
-    std::unordered_set<OpHandleBase *> ready_ops;
-
-    auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
-      pending_vars[&var] = var.generated_op_ == nullptr;
-    };
-
-    auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
-      pending_ops.insert({&op_instance, op_instance.inputs_.size()});
-    };
-
-    // Step 1. Transform the SSAGraph into pending_ops & pending_vars
-    for (auto &var_map : graph_->vars_) {
-      for (auto &name_pair : var_map) {
-        for (auto &version_pair : name_pair.second) {
-          InsertPendingVar(version_pair.second);
-        }
-      }
-    }
-    for (auto &var : graph_->dep_vars_) {
-      InsertPendingVar(*var);
-    }
-
-    for (auto &op : graph_->ops_) {
-      if (op->inputs_.empty()) {  // Special case: the op has no input.
-        ready_ops.insert(op.get());
-      } else {
-        InsertPendingOp(*op);
-      }
-    }
-
-    // Step 2. Insert FetchOps
-    std::vector<FetchOpHandle> fetch_ops;
-    std::vector<DummyVarHandle> dummy_vars;
-    FeedFetchList fetch_data(fetch_tensors.size());
-
-    std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
-
-    for (auto &fetch_var_name : fetch_tensors) {
-      for (auto &var_map : graph_->vars_) {
-        auto it = var_map.find(fetch_var_name);
-        if (it != var_map.end()) {
-          fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
-        }
-      }
-    }
-
-    for (size_t i = 0; i < fetch_tensors.size(); ++i) {
-      auto &var_name = fetch_tensors[i];
-      auto &vars = fetched_vars[var_name];
-      fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
-      details::FetchOpHandle *op = &fetch_ops.back();
-
-      // FIXME: Use new device context
-      for (auto &p : places_) {
-        op->dev_ctx_[p] = fetch_ctxs_.Get(p);
-      }
-
-      for (auto *var : vars) {
-        op->AddInput(var);
-      }
-
-      dummy_vars.emplace_back();
-      auto *var = &dummy_vars.back();
-      var->generated_op_ = nullptr;
-      op->AddOutput(var);
-      InsertPendingVar(*var);
-      InsertPendingOp(*op);
-    }
-
-    auto run_all_ready_ops = [&] {
-      for (auto *op : ready_ops) {
-        RunOp(pending_vars, op);
-      }
-      ready_ops.clear();
-    };
-
-    // Step 3. Execution
-    while (!pending_vars.empty()) {
-      // 1. Run all ready ops.
-      run_all_ready_ops();
-
-      // 2. Find a ready variable.
-      VarHandleBase *ready_var = nullptr;
-      for (auto &pair : pending_vars) {
-        if (pair.second.load(std::memory_order_acquire)) {
-          ready_var = pair.first;
-          break;
-        }
-      }
-
-      // If no variable is ready yet
-      if (ready_var == nullptr) {
-        // FIXME: use a condition variable instead of busy waiting.
-        // If an exception occurred, rethrow it.
-        if (exception_) {
-          throw *exception_;
-        }
-        // Keep waiting for ready variables.
-        continue;
-      }
-
-      // 3. Remove the dependency on ready_var and
-      //    find the ops that become ready after ready_var.
-      pending_vars.erase(ready_var);
-      for (auto *op : ready_var->pending_ops_) {
-        auto &deps = pending_ops[op];
-        --deps;
-        if (deps == 0) {
-          ready_ops.insert(op);
-        }
-      }
-      // Loop until all vars are ready.
-    }
-
-    // Wait for FetchOps.
-    for (auto &fetch_op : fetch_ops) {
-      fetch_op.WaitAndMergeCPUTensors();
-    }
-
-    return fetch_data;
-  }
-
-  ~ThreadedSSAGraphExecutor() {}
-
- private:
-  void RunOp(
-      std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
-      details::OpHandleBase *op) {
-    std::vector<std::atomic<bool> *> *ready_buffer =
-        new std::vector<std::atomic<bool> *>();
-    for (auto *var : op->outputs_) {
-      ready_buffer->emplace_back(&pending_vars[var]);
-    }
-
-    auto op_run = [ready_buffer, op, this] {
-      try {
-        VLOG(10) << op->DebugString();
-        op->Run(use_event_);
-        for (auto *ready : *ready_buffer) {
-          ready->store(true, std::memory_order_release);
-        }
-        delete ready_buffer;
-      } catch (platform::EnforceNotMet ex) {
-        exception_.reset(new platform::EnforceNotMet(ex));
-      } catch (...) {
-        LOG(FATAL) << "Unknown exception caught";
-      }
-    };
-    if (pool_) {
-      pool_->enqueue(op_run);
-    } else {
-      op_run();
-    }
-  }
-
- private:
-  std::unique_ptr<::ThreadPool> pool_;
-  std::vector<Scope *> local_scopes_;
-  std::vector<platform::Place> places_;
-  platform::DeviceContextPool fetch_ctxs_;
-  const bool use_event_;
-  std::unique_ptr<platform::EnforceNotMet> exception_;
-};
-
 class ParallelExecutorPrivate {
  public:
   explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
@@ -239,8 +35,7 @@ class ParallelExecutorPrivate {
 
   Scope *global_scope_;
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
-
-  std::unique_ptr<SSAGraphExecutor> executor_;
+  std::unique_ptr<details::SSAGraphExecutor> executor_;
 };
 
 ParallelExecutor::ParallelExecutor(
@@ -274,7 +69,7 @@ ParallelExecutor::ParallelExecutor(
                                              member_->nccl_ctxs_.get());
   auto graph = builder.Build(main_program);
 
-  member_->executor_.reset(new ThreadedSSAGraphExecutor(
+  member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       num_threads, true, member_->local_scopes_, places, std::move(graph)));
 
   // Step 3. Create vars in each scope;
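
Note for reviewers: the Step 3 loop above is a dependency-counting topological execution. Every variable starts pending, an op runs once all of its input vars are ready, and a finished op marks its output vars ready, which in turn unblocks the ops that consume them. Below is a minimal standalone sketch of that strategy; Var, Op, and RunGraph are illustrative stand-ins (not Paddle types), it is single-threaded, assumes an acyclic graph, and omits the thread pool, the atomic ready flags, and the fetch-op machinery.

// Minimal sketch of the dependency-counting schedule used by
// ThreadedSSAGraphExecutor::Run. Var/Op/RunGraph are illustrative only.
#include <cstdio>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Var;
struct Op {
  const char *name;
  std::vector<Var *> inputs;
  std::vector<Var *> outputs;
  void Run() const { std::printf("run %s\n", name); }
};
struct Var {
  Op *generated_op = nullptr;     // producer, like VarHandleBase::generated_op_
  std::vector<Op *> pending_ops;  // consumers, like VarHandleBase::pending_ops_
};

// Run each op once all of its input vars are ready. The real executor keeps
// one std::atomic<bool> ready flag per var so pool threads can publish
// completion; here a plain worklist suffices.
void RunGraph(const std::vector<Op *> &ops, const std::vector<Var *> &vars) {
  std::unordered_map<Op *, size_t> pending_ops;  // op -> #unready inputs
  std::unordered_set<Var *> pending_vars(vars.begin(), vars.end());
  std::unordered_set<Op *> ready_ops;
  std::vector<Var *> ready_vars;

  for (Op *op : ops) {
    if (op->inputs.empty()) {
      ready_ops.insert(op);  // source op, runnable immediately
    } else {
      pending_ops[op] = op->inputs.size();
    }
  }
  for (Var *v : vars) {
    if (v->generated_op == nullptr) ready_vars.push_back(v);  // no producer
  }

  while (!pending_vars.empty()) {
    // 1. Run all currently-ready ops; their outputs become ready vars.
    for (Op *op : ready_ops) {
      op->Run();
      for (Var *out : op->outputs) ready_vars.push_back(out);
    }
    ready_ops.clear();

    // 2. Retire ready vars and decrement each consumer's dependency count.
    for (Var *v : ready_vars) {
      pending_vars.erase(v);
      for (Op *consumer : v->pending_ops) {
        if (--pending_ops[consumer] == 0) ready_ops.insert(consumer);
      }
    }
    ready_vars.clear();
  }
}

int main() {
  Var a, b;
  Op producer{"producer", {}, {&a}};    // no inputs -> runs first
  Op consumer{"consumer", {&a}, {&b}};  // unblocked once `a` is ready
  a.generated_op = &producer;
  a.pending_ops = {&consumer};
  b.generated_op = &consumer;
  RunGraph({&producer, &consumer}, {&a, &b});  // prints: run producer, run consumer
}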
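On the "FIXME: use a condition variable instead of busy waiting": one way to address it, sketched below under the assumption that worker threads push each output var handle as it becomes ready, is a small blocking queue that the Step 3 loop pops from instead of scanning pending_vars. BlockingQueue here is an illustrative name, not part of this PR.

// Hypothetical condition-variable handoff for the busy-wait FIXME.
#include <condition_variable>
#include <deque>
#include <mutex>

// Unbounded blocking queue: Push() from pool threads as vars become ready,
// Pop() in the scheduling loop instead of polling ready flags.
template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(v));
    }
    cv_.notify_one();  // wake the scheduling loop
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T v = std::move(queue_.front());
    queue_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};

With such a queue, RunOp would Push each output var handle after op->Run() instead of flipping atomic flags, and the Step 3 loop would block in Pop() until a var arrives; the exception_ handoff would then need a sentinel entry or a flag check after each Pop rather than the current polling branch.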