diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index a167511160d074c13ca1dca36b4f2c5eeea4bb93..e22c7f8a403be1f825f5c92118122a4832fd30de 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -184,7 +184,7 @@ endif()
 target_link_libraries(executor garbage_collector)
 
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
-        threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
+        threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
         graph build_strategy fast_threaded_ssa_graph_executor variable_helper)
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index c1ba6606f1064750a9d7e087ded1ec3634bcc4a5..01c24b0d824d30e0f6f3368453161e9b1a6f9074 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -79,6 +79,8 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
 
 cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
 
+cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
+
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
 cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 9b26fdd545c9a5de19412de30b1ca717204426ac..d3e4573e228705aac7f35a29abac07449b2ea628 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -27,6 +27,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
       graphs_(std::move(graphs)) {
+  VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 
   // set the correct size of thread pool to each device.
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index df0ff772c9d35c88ec5a6112525c56aa92d359b9..f8911cd9ad4bc7edb90daa016bf5e59c973c329f 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -116,7 +116,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   // Convert graph to run on multi-devices.
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass;
-    if (strategy_.is_distribution_) {
+
+    if (strategy_.async_mode_) {
+      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
+    } else if (strategy_.is_distribution_) {
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 15c2e01b6142571883c759efb1e26b609be9adb4..16324839657364ec47c7045b2de775693071d943 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -86,6 +86,7 @@ struct BuildStrategy {
   // num_trainers is 1, so the current fields of build_strategy doesn't tell if
   // it's distributed model.
   bool is_distribution_{false};
+  bool async_mode_{false};
   int num_trainers_{1};
   int trainer_id_{0};
   std::vector<std::string> trainers_endpoints_;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 75f922d2cca6855a67be7284ae407e549a1a1afb..d7a4b5692b3c36359677eabb6a9ba9425256984c 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -975,3 +975,5 @@ REGISTER_MULTI_DEVICES_PASS(
                             paddle::framework::details::AllReduceSSAGraphBuilder);
 REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
                             paddle::framework::details::DistSSAGraphBuilder);
+REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
+                            paddle::framework::details::AsyncSSAGraphBuilder);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 6d4386538ea7d0cc318647c92282af9d598fa699..e91397816c3a983e1a9d211ea9368d78e2bc89f5 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -55,7 +55,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
 
   bool UseGPU() const;
 
-  bool NeedCollectiveOps() const;
+  virtual bool NeedCollectiveOps() const;
 
   bool IsScaleLossOp(ir::Node *node) const;
 
@@ -116,6 +116,20 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
   virtual void InsertPostprocessOps(ir::Graph *result) const {}
 };
 
+class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
+ protected:
+  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                                  const std::string &g_name) const {}
+
+  bool NeedCollectiveOps() const override { return false; }
+
+  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
+    return false;
+  }
+
+  virtual void InsertPostprocessOps(ir::Graph *result) const {}
+};
+
 class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
  protected:
   int GetVarDeviceID(const std::string &varname) const;
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index f61c9e3a91146704faa6c5b1058137bef67d2a3e..4173b39e10dc30dbc703fe136e9bf43684f2c275 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph.h"
 
+#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
@@ -282,10 +283,19 @@ ParallelExecutor::ParallelExecutor(
     graphs.push_back(std::move(graph));
   }
 #else
-  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-      main_program, member_->places_, loss_var_name, member_->local_scopes_,
-      member_->nranks_, member_->use_cuda_);
-  graphs.push_back(std::move(graph));
+  if (build_strategy.async_mode_) {
+    for (size_t i = 0; i < member_->places_.size(); ++i) {
+      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
+          main_program, {member_->places_[i]}, loss_var_name,
+          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_);
+      graphs.push_back(std::move(graph));
+    }
+  } else {
+    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
+        main_program, member_->places_, loss_var_name, member_->local_scopes_,
+        member_->nranks_, member_->use_cuda_);
+    graphs.push_back(std::move(graph));
+  }
 #endif
 
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
@@ -323,23 +333,31 @@ ParallelExecutor::ParallelExecutor(
         "please don't pass loss_var_name.";
     }
   }
-
-  if (build_strategy.enable_parallel_graph_) {
+  if (build_strategy.async_mode_) {
+    VLOG(3) << "use AsyncSSAGraphExecutor";
+    member_->executor_.reset(new details::AsyncSSAGraphExecutor(
+        exec_strategy, member_->local_scopes_, member_->places_,
+        std::move(graphs)));
+  } else if (build_strategy.enable_parallel_graph_) {
+    VLOG(3) << "use ParallelSSAGraphExecutor";
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
         std::move(graphs)));
   } else {
     if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
+      VLOG(3) << "use ThreadedSSAGraphExecutor";
       member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_,
           std::move(graphs[0])));
     } else {
+      VLOG(3) << "use FastThreadedSSAGraphExecutor";
       member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_,
           std::move(graphs[0])));
     }
   }
 
+  VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, std::move(var_infos),
       member_->places_, std::move(member_->executor_)));
@@ -401,14 +419,22 @@ void ParallelExecutor::BCastParamsToDevices(
       auto local_scope = member_->local_scopes_[i];
       auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
 
-      // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
-      if (member_->use_all_reduce_ || member_->use_cuda_ ||
-          var == "@LR_DECAY_COUNTER@") {
+      auto copy_memory = [&] {
         t->Resize(dims);
         t->mutable_data(cpu, main_tensor.type());
         paddle::framework::TensorCopy(main_tensor, cpu, t);
+      };
+
+      auto share_memory = [&] { t->ShareDataWith(main_tensor); };
+
+      // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
+      if (member_->build_strategy_.async_mode_) {
+        share_memory();
+      } else if (member_->use_all_reduce_ || member_->use_cuda_ ||
+                 var == "@LR_DECAY_COUNTER@") {
+        copy_memory();
       } else {
-        t->ShareDataWith(main_tensor);
+        share_memory();
       }
     }
   }
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index f3f4854a9efbcf5ab325e7f6aec81135c018dcd5..88d12c69b77de901a90f7fe6fcc05e3bffed1efc 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1030,6 +1030,9 @@ All parameter, weight, gradient are variables in Paddle.
           "is_distribution",
           [](const BuildStrategy &self) { return self.is_distribution_; },
           [](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
+      .def_property("async_mode",
+                    [](const BuildStrategy &self) { return self.async_mode_; },
+                    [](BuildStrategy &self, bool b) { self.async_mode_ = b; })
       .def_property(
           "memory_early_delete",
           [](const BuildStrategy &self) { return self.memory_early_delete_; },
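
With `async_mode` exposed on `BuildStrategy` through the pybind binding above, the flag can be flipped from the Python side before a `ParallelExecutor` is constructed. The snippet below is a minimal usage sketch, not part of this patch: it assumes the fluid 1.x Python frontend (`fluid.BuildStrategy`, `fluid.ExecutionStrategy`, `fluid.ParallelExecutor`) and a program in which a `loss` variable has already been defined.

    import paddle.fluid as fluid

    # ... network definition elided; `loss` is assumed to be the loss variable ...

    build_strategy = fluid.BuildStrategy()
    build_strategy.async_mode = True  # property added by this patch

    exec_strategy = fluid.ExecutionStrategy()

    # In async mode the C++ side builds one single-place graph per device and
    # runs them through AsyncSSAGraphExecutor, with parameters shared across
    # local scopes; CPU training is the intended target, so use_cuda is False.
    pe = fluid.ParallelExecutor(use_cuda=False,
                                loss_name=loss.name,
                                build_strategy=build_strategy,
                                exec_strategy=exec_strategy)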