diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 8f6c4163d6ee11fbe83f603f6148c2ac6175324d..abd5459f6d47da6d1341284916b419325dc5977c 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -42,3 +42,5 @@ cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_b cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor) #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory # device_context reduce_op_handle ) +cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc + DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..08b0db4270c525c9041d86bc83d51b270a1422ac --- /dev/null +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" +#include +#include +#include "paddle/fluid/framework/details/fetch_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( + const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &places, + std::unique_ptr &&graph) + : strategy_(strategy), + local_scopes_(local_scopes), + places_(places), + graph_(std::move(graph)), + pool_(strategy.num_threads_ + + 1), // add one more thread for generate op_deps + fetch_ctxs_(places) { + auto &ops = graph_->Get("ops"); + + for (auto &op : ops) { + int dep = static_cast(op->NotReadyInputSize()); + op_deps_.emplace(op.get(), dep); + if (dep == 0) { + bootstrap_ops_.emplace_back(op.get()); + } + } + + PrepareAtomicOpDeps(); +} + +FeedFetchList FastThreadedSSAGraphExecutor::Run( + const std::vector &fetch_tensors) { + std::unique_ptr>> + op_deps = atomic_op_deps_.get(); + PrepareAtomicOpDeps(); + + paddle::framework::FeedFetchList fetches; + fetches.resize(fetch_tensors.size()); + std::unordered_map> fetched_vars; + std::vector> fetch_nodes; + std::vector> fetch_ops; + + for (auto &fetch_var_name : fetch_tensors) { + for (auto &var_map : graph_->Get("vars")) { + auto it = var_map.find(fetch_var_name); + if (it != var_map.end()) { + fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); + } + } + } + + for (size_t i = 0; i < fetch_tensors.size(); ++i) { + auto &var_name = fetch_tensors[i]; + auto fetched_var_it = fetched_vars.find(var_name); + PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(), + "Cannot find fetched variable.(Perhaps the main_program " + "is not set to ParallelExecutor)"); + + auto &vars = fetched_var_it->second; + + fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); + auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i, + &local_scopes_); + fetch_ops.emplace_back(op); + + for (auto &p : places_) { + op->SetDeviceContext(p, fetch_ctxs_.Get(p)); + } + + for (auto *var : vars) { + op->AddInput(var); + } + + (*op_deps)[op] = static_cast(op->NotReadyInputSize()); + } + + size_t num_complete = 0; + remaining_ = 0; + BlockingQueue complete_q; + for (auto op : bootstrap_ops_) { + RunOpAsync(op_deps.get(), op, &complete_q); + } + + while (num_complete != op_deps->size()) { + size_t num_comp = complete_q.Pop(); + if (num_comp == -1UL) { + int remaining = remaining_; + for (int i = 0; i < remaining; ++i) { + complete_q.Pop(); + } + LOG(FATAL) << "On exception thrown, not implemented"; + } + num_complete += num_comp; + } + // Wait FetchOps. + if (!fetch_ops.empty()) { + fetch_ops.clear(); + } + return fetches; +} +void FastThreadedSSAGraphExecutor::RunOpAsync( + std::unordered_map> *op_deps, + OpHandleBase *op, BlockingQueue *complete_q) { + ++remaining_; + this->pool_.enqueue([=] { + OpHandleBase *op_to_run = op; + size_t complete = 0; + while (op_to_run != nullptr) { + try { + op_to_run->Run(strategy_.use_cuda_); + ++complete; + } catch (...) { + --remaining_; + complete_q->Push(-1UL); + return; + } + auto &outputs = op_to_run->Outputs(); + op_to_run = nullptr; + for (auto &output : outputs) { + for (auto &pending_op : output->PendingOps()) { + std::atomic &deps = op_deps->at(pending_op); + if (deps.fetch_sub(1) == 1) { // pending_op ready + if (op_to_run == nullptr) { + op_to_run = pending_op; + } else { + this->RunOpAsync(op_deps, pending_op, complete_q); + } + } + } + } + } + --remaining_; + complete_q->Push(complete); + }); +} +void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { + atomic_op_deps_ = pool_.enqueue([&] { + std::unordered_map> *op_deps = + new std::unordered_map>; + for (auto &pair : op_deps_) { + (*op_deps)[pair.first] = pair.second; + } + return std::unique_ptr< + std::unordered_map>>(op_deps); + }); +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..48f741ceccfb46499a35c2e921ab8f71da7bc6b9 --- /dev/null +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -0,0 +1,62 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/details/exception_holder.h" +#include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/ssa_graph_executor.h" + +namespace paddle { +namespace framework { +class Scope; +namespace details { + +class OpHandleBase; +class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { + public: + FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, + const std::vector &local_scopes, + const std::vector &places, + std::unique_ptr &&graph); + FeedFetchList Run(const std::vector &fetch_tensors) override; + + private: + ExecutionStrategy strategy_; + std::vector local_scopes_; + std::vector places_; + std::unique_ptr graph_; + + std::unordered_map op_deps_; + std::vector bootstrap_ops_; + + ::ThreadPool pool_; + platform::DeviceContextPool fetch_ctxs_; + std::atomic remaining_; + + void RunOpAsync(std::unordered_map> *op_deps, + OpHandleBase *op, BlockingQueue *complete_q); + + void PrepareAtomicOpDeps(); + + std::future< + std::unique_ptr>>> + atomic_op_deps_; +}; +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index ee9f9184da65467b82794c99fe3e95b108373753..3812f0abf1b7069525c4420054c61c01c908acfe 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -158,6 +158,16 @@ void OpHandleBase::RunAndRecordEvent(platform::Place p, #endif } +size_t OpHandleBase::NotReadyInputSize() const { + std::unordered_set res; + for (auto *var : inputs_) { + if (var->GeneratedOp() != nullptr) { + res.emplace(var); + } + } + return res.size(); +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 2d7f18942890245249dd0619a40bb43833c9a2ee..9fbefabc841e3f6940860f60d959fee97495e4c9 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -81,6 +81,8 @@ class OpHandleBase { return res.size(); } + size_t NotReadyInputSize() const; + const std::vector &Outputs() const { return outputs_; } size_t NoDummyInputSize() const;