From 843dc3cdbd970aca8f79d6a6d41313bed04eb059 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Mon, 18 Jan 2021 16:32:42 +0800 Subject: [PATCH] [Kunlun]PR3: add xpu executor, multi xpu card train function optimization (#30317) --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 2 + .../bind_threaded_ssa_graph_executor.cc | 316 ++++++++++++++++++ .../bind_threaded_ssa_graph_executor.h | 107 ++++++ .../fluid/framework/details/op_handle_base.cc | 20 -- paddle/fluid/framework/parallel_executor.cc | 22 +- paddle/fluid/platform/device_context.cc | 2 +- 7 files changed, 445 insertions(+), 26 deletions(-) create mode 100644 paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc create mode 100644 paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f96b9475f56..4feffe65f73 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -265,7 +265,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor - graph build_strategy collective_helper + graph build_strategy bind_threaded_ssa_graph_executor collective_helper fast_threaded_ssa_graph_executor variable_helper) cc_library(executor_cache SRCS executor_cache.cc DEPS executor) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 0c9e30fd195..dce256ebc47 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -101,6 +101,8 @@ cc_library(scope_buffered_monitor SRCS scope_buffered_monitor.cc DEPS scope prof cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor scope_buffered_monitor) #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory # device_context reduce_op_handle ) +cc_library(bind_threaded_ssa_graph_executor SRCS bind_threaded_ssa_graph_executor.cc + DEPS fetch_op_handle gflags ssa_graph_executor scope simple_threadpool device_context) cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context) cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle) diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc new file mode 100644 index 00000000000..d334520a93f --- /dev/null +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc @@ -0,0 +1,316 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h" +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/fetch_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/profiler.h" + +#if defined(PADDLE_WITH_XPU) +namespace paddle { +namespace framework { +namespace details { + +static std::atomic exec_op_count_; +static std::atomic error_state; + +BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor( + const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, + const std::vector &places, ir::Graph *graph) + : strategy_(strategy), + local_scopes_(local_scopes), + local_exec_scopes_(local_exec_scopes), + places_(places), + graph_(graph), + prepare_pool_(1), + multi_device_op_pool_(1) { + for (uint32_t i = 0; i < places.size(); i++) { + pool_.emplace_back(std::unique_ptr<::ThreadPool>(new ::ThreadPool(1))); + } + int index = 0; + for (uint32_t i = 0; i < places.size(); i++) { + int id = BOOST_GET_CONST(platform::XPUPlace, places_[i]).device; + if (place_to_index_.find(id) == place_to_index_.end()) { + place_to_index_[id] = index; + index++; + } + } + for (auto &op : ir::FilterByNodeWrapper(*graph_)) { + int dep = static_cast(op->NotReadyInputSize()); + op_deps_.emplace(op, dep); + if (dep == 0) { + bootstrap_ops_.emplace_back(op); + } + } + PADDLE_ENFORCE_GT(op_deps_.size(), 0, + platform::errors::PreconditionNotMet( + "The graph doesn't have operators.")); + PrepareAtomicOpDeps(); +} + +static std::vector get_children(OpHandleBase *op) { + auto &outputs = op->Outputs(); + std::vector ret; + for (auto &output : outputs) { + ret.insert(ret.end(), output->PendingOps().begin(), + output->PendingOps().end()); + } + return ret; +} + +static std::vector get_parents(OpHandleBase *op) { + auto &inputs = op->Inputs(); + std::vector ret; + for (auto &input : inputs) { + if (input->GeneratedOp() != nullptr) { + ret.push_back(input->GeneratedOp()); + } + } + return ret; +} + +FetchResultType BindThreadedSSAGraphExecutor::Run( + const std::vector &fetch_tensors, bool return_merged) { + VLOG(3) << "enter BindThreadedSSAGraphExecutor Run"; + return RunMainStream(fetch_tensors, return_merged); +} + +// use 2 streams to run op. The first stream is main stream and will run +// most op exclude op depending on multi device(e.g., all_reduce, fetch op) +FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( + const std::vector &fetch_tensors, bool return_merged) { + VLOG(3) << "enter MainStream Run"; + std::unique_ptr> + op_deps = atomic_op_deps_.get(); + PrepareAtomicOpDeps(); + + error_state = 0; + paddle::framework::FetchResultType fetches; + if (return_merged) { + fetches = FetchList(fetch_tensors.size()); + } else { + fetches = FetchUnmergedList(fetch_tensors.size()); + } + std::unordered_map> fetched_vars; + std::vector fetch_ops; + std::vector ready_fetch_ops; + auto ready_ops = std::make_shared>(); + exception_.Clear(); + + InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(), + &fetch_ops, &ready_fetch_ops, return_merged); + for (auto cur_op : bootstrap_ops_) { + ready_ops->Push(cur_op); + } + for (auto cur_op : ready_fetch_ops) { + ready_ops->Push(cur_op); + } + + exec_op_count_ = 0; + + platform::XPUPlace cur_place; + std::size_t cur_count = 0; + + while (cur_count < op_deps_.size()) { + cur_count++; + auto cur_op = ready_ops->Pop(); + if (cur_op == nullptr) { + // sleep a while to make sure worker thread quit + sleep(10); + exec_op_count_ = op_deps_.size(); + break; + } + auto dev_ctxes_ = cur_op->DeviceContext(); + if (cur_op->IsMultiDeviceTransfer()) { + RunMultiDeviceOpAsync(cur_op, op_deps.get(), ready_ops); + continue; + } else { + cur_place = + BOOST_GET_CONST(platform::XPUPlace, dev_ctxes_.begin()->first); + int cur_index = place_to_index_[cur_place.device]; + RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index); + } + } + while (exec_op_count_ < op_deps_.size()) { + } + + // Wait FetchOps. + ClearFetchOp(graph_, &fetch_ops); + if (exception_.IsCaught()) { + ExecutionFinal(&fetch_ops); + } + return fetches; +} + +void BindThreadedSSAGraphExecutor::InsertFetchOps( + const std::vector &fetch_tensors, FetchResultType *fetches, + std::unordered_map> *fetched_vars, + std::unordered_map *op_deps, + std::vector *fetch_ops, + std::vector *ready_fetch_ops, bool return_merged) { + std::unordered_set fetch_tensor_set(fetch_tensors.begin(), + fetch_tensors.end()); + for (auto &fetch_var_name : fetch_tensor_set) { + for (auto &var_map : graph_->Get(kGraphVars)) { + auto it = var_map.find(fetch_var_name); + if (it != var_map.end()) { + (*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin()); + } + } + } + + for (size_t i = 0; i < fetch_tensors.size(); ++i) { + auto &var_name = fetch_tensors.at(i); + auto fetched_var_it = fetched_vars->find(var_name); + PADDLE_ENFORCE_NE( + fetched_var_it, fetched_vars->end(), + platform::errors::PreconditionNotMet( + "Cannot find fetched variable(%s) in current computation graph. " + "Possible reasons are:\n" + " 1. The variable to be fetched is not defined in main program.\n" + " 2. The variable to be fetched is not an input or output of any " + "operator.\n" + " 3. Confirm that you have used the fetch `Variable` format " + "instead of the string literal('%s') in `fetch_list` parameter " + "when using `executor.run` method. In other words, the format of " + "`executor.run(fetch_list=[fetch_var])`(fetch_var is a Variable) " + "is recommended.", + var_name, var_name)); + + auto &vars = fetched_var_it->second; + + ir::Node *fetch_node = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_, + &local_exec_scopes_, return_merged); + fetch_ops->emplace_back(op); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + for (auto &p : places_) { + op->SetDeviceContext(p, pool.Get(p)); + } + + for (auto *var : vars) { + op->AddInput(var); + } + + int dep = static_cast(op->NotReadyInputSize()); + (*op_deps)[op].dep_num = dep; + (*op_deps)[op].op = op; + if (dep == 0) { + ready_fetch_ops->emplace_back(op); + } + } +} + +void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( + OpHandleBase *op, + std::unordered_map *op_deps, + std::shared_ptr> ready_ops) { + multi_device_op_pool_.enqueue([=] { + try { + if (error_state == 0 && LIKELY(!strategy_.dry_run_)) { + auto dev_ctxes = op->DeviceContext(); + auto &inputs = op->Inputs(); + for (auto &input : inputs) { + auto dev_ctxes = input->GeneratedOp()->DeviceContext(); + for (auto &item : dev_ctxes) { + ((platform::XPUDeviceContext *)(item.second))->Wait(); + } + } + op->Run(strategy_.use_device_); + auto &outputs = op->Outputs(); + for (auto &output : outputs) { + for (auto &pending_op : output->PendingOps()) { + std::atomic &deps = op_deps->at(pending_op).dep_num; + if (deps.fetch_sub(1) == 1) { + ready_ops->Push(pending_op); + } + } + } + } else if (error_state) { + ready_ops->Push(nullptr); + } + } catch (...) { + error_state = 1; + ready_ops->Push(nullptr); + exception_.Catch(std::current_exception()); + } + exec_op_count_++; + }); +} + +void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( + OpHandleBase *op, + std::unordered_map *op_deps, + std::shared_ptr> ready_ops, int index) { + pool_[index]->enqueue([=] { + try { + if (error_state == 0 && LIKELY(!strategy_.dry_run_)) { + op->Run(strategy_.use_device_); + auto &outputs = op->Outputs(); + for (auto &output : outputs) { + for (auto &pending_op : output->PendingOps()) { + std::atomic &deps = op_deps->at(pending_op).dep_num; + if (deps.fetch_sub(1) == 1) { + ready_ops->Push(pending_op); + } + } + } + } else if (error_state) { + ready_ops->Push(nullptr); + } + } catch (...) { + error_state = 1; + ready_ops->Push(nullptr); + exception_.Catch(std::current_exception()); + } + exec_op_count_++; + }); +} + +void BindThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { + atomic_op_deps_ = prepare_pool_.enqueue([&] { + auto *op_deps = new std::unordered_map; + for (auto &pair : op_deps_) { + (*op_deps)[pair.first].dep_num = pair.second; + (*op_deps)[pair.first].op = pair.first; + } + return std::unique_ptr< + std::unordered_map>(op_deps); + }); +} + +const ir::Graph &BindThreadedSSAGraphExecutor::Graph() const { return *graph_; } + +void BindThreadedSSAGraphExecutor::ExecutionFinal( + std::vector *fetch_ops) { + VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it"; + ClearFetchOp(graph_, fetch_ops); + exception_.ReThrow(); +} + +} // namespace details +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h new file mode 100644 index 00000000000..87c1908944e --- /dev/null +++ b/paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h @@ -0,0 +1,107 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/details/exception_holder.h" +#include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/ssa_graph_executor.h" + +#if defined(PADDLE_WITH_XPU) +namespace paddle { +namespace framework { +class Scope; +namespace details { + +struct RunningItem { + std::atomic dep_num; + OpHandleBase *op; +}; + +class OpHandleBase; +class BindThreadedSSAGraphExecutor : public SSAGraphExecutor { + public: + BindThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, + const std::vector &local_scopes, + const std::vector &local_exec_scopes, + const std::vector &places, + ir::Graph *graph); + // FeedFetchList Run(const std::vector &fetch_tensors) override; + // Run a SSAGraph by a thread pool + // Use topological sort algorithm + FetchResultType Run(const std::vector &fetch_tensors, + bool return_merged) override; + const ir::Graph &Graph() const override; + + private: + FetchResultType RunMainStream(const std::vector &fetch_tensors, + bool return_merged); + + // Note(zcd): the ThreadPool should be placed last so that ThreadPool should + // be destroyed first. + ExecutionStrategy strategy_; + std::vector local_scopes_; + std::vector local_exec_scopes_; + std::vector places_; + ir::Graph *graph_; + + std::unordered_map op_deps_; + std::unordered_map place_to_index_; + std::vector bootstrap_ops_; + + std::unique_ptr stream_op_count_; + + std::future< + std::unique_ptr>> + atomic_op_deps_; + ExceptionHolder exception_; + + std::vector> pool_; + ::ThreadPool prepare_pool_; + ::ThreadPool multi_device_op_pool_; + + void RunOpAsyncMainStream( + OpHandleBase *op, + std::unordered_map *op_deps, + std::shared_ptr> ready_ops, int index); + + void RunMultiDeviceOpAsync( + OpHandleBase *op, + std::unordered_map *op_deps, + std::shared_ptr> ready_ops); + + void PrepareAtomicOpDeps(); + + int get_pool_thread_index(int device_id); + + inline void ExecutionFinal(std::vector *fetch_ops); + + void InsertFetchOps( + const std::vector &fetch_tensors, FetchResultType *fetches, + std::unordered_map> + *fetched_vars, + std::unordered_map *op_deps, + std::vector *fetch_ops, + std::vector *ready_fetch_ops, bool return_merged); +}; +} // namespace details +} // namespace framework +} // namespace paddle + +#endif diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index e2f4f453ccf..eeff0f3d46d 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -215,13 +215,6 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) { #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); -#endif - } else if (platform::is_xpu_place(place)) { -#ifdef PADDLE_WITH_XPU - dev_ctxes_.at(place)->Wait(); -#else - PADDLE_THROW( - platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } // There are nothing to do when the place is CPUPlace. @@ -271,19 +264,6 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); -#endif - } else if (platform::is_xpu_place(in_var_handle->place())) { -#ifdef PADDLE_WITH_XPU - PADDLE_ENFORCE_EQ( - platform::is_same_place(place, in_var_handle->place()), true, - platform::errors::InvalidArgument( - "The place of output(%s) is not consistent with the " - "place of current op(%s).", - in_var_handle->Name(), Name())); - dev_ctxes_.at(place)->Wait(); -#else - PADDLE_THROW( - platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } // There are nothing to do when the place is CPUPlace. diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index bfc3b7c7017..3ddd7cc9182 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/details/async_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/op_handle_base.h" @@ -933,10 +934,23 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, graph)); } else { - VLOG(3) << "use FastThreadedSSAGraphExecutor"; - member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, - member_->places_, graph)); + if (member_->use_device_ == p::kXPU) { +#if defined(PADDLE_WITH_XPU) + VLOG(3) << "use BindThreadedSSAGraphExecutor"; + member_->executor_.reset(new details::BindThreadedSSAGraphExecutor( + exec_strategy, member_->local_scopes_, + member_->local_exec_scopes_, member_->places_, graph)); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't use XPU device since it's not compiled with XPU," + "Please recompile or reinstall Paddle with XPU support.")); +#endif + } else { + VLOG(3) << "use FastThreadedSSAGraphExecutor"; + member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( + exec_strategy, member_->local_scopes_, + member_->local_exec_scopes_, member_->places_, graph)); + } } final_graphs.emplace_back(graph); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index d9e9443e752..4d952ecda0c 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -211,7 +211,7 @@ void XPUDeviceContext::Wait() const { "XPU API return wrong value[%d], please check whether " "Baidu Kunlun Card is properly installed.", ret)); - xpu_wait(); + xpu_wait(context_->xpu_stream); } Place XPUDeviceContext::GetPlace() const { return place_; } -- GitLab