From 5ff1ef36ee58af535366599ebfb79515788d682f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 2 May 2018 20:28:39 +0800 Subject: [PATCH] update sparse parameter --- paddle/fluid/framework/details/CMakeLists.txt | 4 +- .../framework/details/broadcast_op_handle.cc | 108 ++++++++++--- .../framework/details/broadcast_op_handle.h | 23 ++- .../details/broadcast_op_handle_test.cc | 36 ++++- .../framework/details/gather_op_handle.cc | 53 ++++--- .../details/multi_devices_graph_builder.cc | 143 ++++++++++++++++-- .../details/multi_devices_graph_builder.h | 21 ++- .../framework/details/reduce_op_handle.cc | 14 +- .../framework/details/reduce_op_handle.h | 2 +- .../framework/details/ssa_graph_builder.cc | 11 ++ .../framework/details/ssa_graph_builder.h | 4 + paddle/fluid/framework/details/var_handle.h | 2 + paddle/fluid/framework/parallel_executor.cc | 10 +- paddle/fluid/framework/parallel_executor.h | 3 +- paddle/fluid/pybind/pybind.cc | 5 +- python/paddle/fluid/parallel_executor.py | 19 ++- .../tests/unittests/test_parallel_executor.py | 101 ++++++++++--- 17 files changed, 453 insertions(+), 106 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 96c181f98..9de44beaf 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -15,12 +15,14 @@ if(WITH_GPU) dynload_cuda) set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) + nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) + else() set(multi_devices_graph_builder_deps) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) + cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) endif() -cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 33e02ab65..4f4157902 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -19,11 +19,9 @@ namespace paddle { namespace framework { namespace details { -BroadcastOpHandle::BroadcastOpHandle(const std::vector &local_scopes, - const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} void BroadcastOpHandle::RunImpl() { + if (places_.size() == 1) return; // the input and output may have dummy var. 
VarHandle *in_var_handle; @@ -55,27 +53,95 @@ void BroadcastOpHandle::RunImpl() { Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); - for (auto *out : out_var_handles) { - if (*out == *in_var_handle) { - continue; + if (platform::is_cpu_place(in_tensor.place())) { + for (auto *out : out_var_handles) { + if (*out == *in_var_handle) { + continue; + } + + auto &out_p = out->place_; + auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_); + PADDLE_ENFORCE_NOT_NULL(out_var); + PADDLE_ENFORCE_EQ(out_p.which(), in_tensor.place().which(), + "Places must be all on CPU or all on CUDA."); + + VariableVisitor::ShareDimsAndLoD(*in_var, out_var); + VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, + in_tensor.type()); + + auto dev_ctx = dev_ctxes_.at(out_p); + RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { + paddle::framework::TensorCopy( + in_tensor, out_p, *dev_ctx, + &VariableVisitor::GetMutableTensor(out_var)); + }); + } + } else { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE(platform::is_gpu_place(in_tensor.place())); + VarHandle *out_handle; + int root = boost::get(in_tensor.place()).device; + std::vector> broadcast_calls; + + for (size_t j = 0; j < out_var_handles.size(); ++j) { + VarHandle *out_var_handle = out_var_handles[j]; + Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) + ->FindVar(out_var_handle->name_); + + if (*out_var_handle != *in_var_handle) { + PADDLE_ENFORCE_NOT_NULL(out_var); + PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), + in_tensor.place().which(), + "Places must be all on CPU or all on CUDA."); + VariableVisitor::ShareDimsAndLoD(*in_var, out_var); + VariableVisitor::GetMutableTensor(out_var).mutable_data( + out_var_handle->place_, in_tensor.type()); + } + + auto out_p = out_var_handle->place_; + int dev_id = boost::get(out_p).device; + + auto &nccl_ctx = nccl_ctxs_->at(dev_id); + auto stream = nccl_ctx.stream(); + auto comm = nccl_ctx.comm_; + + void *send_recv_buffer = nullptr; + if (root == dev_id) { + send_recv_buffer = const_cast(in_tensor.data()); + out_handle = out_var_handle; + } else { + send_recv_buffer = + VariableVisitor::GetMutableTensor(out_var).mutable_data( + out_var_handle->place_); + } + + int type = platform::ToNCCLDataType(in_tensor.type()); + broadcast_calls.emplace_back([=] { + PADDLE_ENFORCE(platform::dynload::ncclBcast( + send_recv_buffer, in_tensor.numel(), + static_cast(type), root, comm, stream)); + }); } - auto &out_p = out->place_; - auto *out_var = var_scopes.at(out->scope_idx_)->FindVar(out->name_); - PADDLE_ENFORCE_NOT_NULL(out_var); - PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(), - "Places must be all on CPU or all on CUDA."); - - VariableVisitor::ShareDimsAndLoD(*in_var, out_var); - VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p, - in_tensor.type()); - - auto dev_ctx = dev_ctxes_.at(out_p); - RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] { - paddle::framework::TensorCopy( - in_tensor, out_p, *(dev_ctx), - &VariableVisitor::GetMutableTensor(out_var)); + this->RunAndRecordEvent([&] { + { + platform::NCCLGroupGuard guard; + for (auto &call : broadcast_calls) { + call(); + } + } + if (*out_handle != *in_var_handle) { + auto out_var = var_scopes.at(in_var_handle->scope_idx_) + ->FindVar(out_var_handles[0]->name_); + paddle::framework::TensorCopy( + in_tensor, in_var_handle->place_, + *(dev_ctxes_.at(in_var_handle->place_)), + &VariableVisitor::GetMutableTensor(out_var)); + } }); +#else + PADDLE_THROW("CUDA is not support."); 
+#endif } } diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 92420f10a..984a95008 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -24,14 +24,32 @@ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + namespace paddle { namespace framework { namespace details { struct BroadcastOpHandle : public OpHandleBase { public: +#ifdef PADDLE_WITH_CUDA + BroadcastOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *nccl_ctxs) + : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + if (nccl_ctxs_) { + for (auto &p_ctx : nccl_ctxs_->contexts_) { + dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); + } + } + } +#else BroadcastOpHandle(const std::vector &local_scopes, - const std::vector &places); + const std::vector &places) + : local_scopes_(local_scopes), places_(places) {} +#endif std::string Name() const override; @@ -44,6 +62,9 @@ struct BroadcastOpHandle : public OpHandleBase { private: const std::vector &local_scopes_; const std::vector &places_; +#ifdef PADDLE_WITH_CUDA + const platform::NCCLContextMap *nccl_ctxs_; +#endif }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 8f1b6d161..c6e923ef7 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -35,15 +35,25 @@ struct TestBroadcastOpHandle { std::unique_ptr op_handle_; std::vector> vars_; std::vector gpu_list_; + bool use_gpu_; +#ifdef PADDLE_WITH_CUDA + std::unique_ptr nccl_ctxs_; +#endif void WaitAll() { for (size_t j = 0; j < ctxs_.size(); ++j) { ctxs_[j]->Wait(); } +#ifdef PADDLE_WITH_CUDA + if (nccl_ctxs_) { + nccl_ctxs_->WaitAll(); + } +#endif } void InitCtxOnGpu(bool use_gpu) { - if (use_gpu) { + use_gpu_ = use_gpu; + if (use_gpu_) { #ifdef PADDLE_WITH_CUDA int count = p::GetCUDADeviceCount(); if (count <= 1) { @@ -57,6 +67,7 @@ struct TestBroadcastOpHandle { gpu_list_.push_back(p); ctxs_.emplace_back(new p::CUDADeviceContext(p)); } + nccl_ctxs_.reset(new platform::NCCLContextMap(gpu_list_)); #else PADDLE_THROW("CUDA is not support."); #endif @@ -67,6 +78,9 @@ struct TestBroadcastOpHandle { gpu_list_.push_back(p); ctxs_.emplace_back(new p::CPUDeviceContext(p)); } +#ifdef PADDLE_WITH_CUDA + nccl_ctxs_.reset(nullptr); +#endif } } @@ -82,7 +96,21 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); + if (use_gpu_) { +#ifdef PADDLE_WITH_CUDA + op_handle_.reset( + new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); +#else + PADDLE_THROW("CUDA is not support."); +#endif + } else { +#ifdef PADDLE_WITH_CUDA + op_handle_.reset( + new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); +#else + op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); +#endif + } auto* in_var_handle = new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]); @@ -97,7 +125,9 @@ struct TestBroadcastOpHandle { op_handle_->AddInput(dummy_var_handle); for (size_t j = 0; j < gpu_list_.size(); ++j) { - 
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); + if (!use_gpu_) { + op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); + } VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 3ed772391..43145f44c 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -25,6 +25,7 @@ GatherOpHandle::GatherOpHandle(const std::vector &local_scopes, : local_scopes_(local_scopes), places_(places) {} void GatherOpHandle::RunImpl() { + if (places_.size() == 1) return; // the input and output may have dummy var. auto in_var_handles = DynamicCast(inputs_); @@ -53,55 +54,53 @@ void GatherOpHandle::RunImpl() { PADDLE_ENFORCE(pre_in_var->IsType(), "Currently, gather_op only can gather SelectedRows."); - auto pre_place = in_0_handle->place_; - PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(), - "The place of input and output should be the same."); - // Wait input done, this Wait is asynchronous operation WaitInputVarGenerated(in_var_handles); std::vector out_rows; std::vector in_tensors; - std::vector in_places; - auto &pre_in = pre_in_var->Get(); + auto &pre_in_value = pre_in_var->Get(); // gather the inputs for (auto *in_handle : in_var_handles) { - auto in_p = in_handle->place_; - in_places.push_back(in_p); - PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), - "Places must be all on CPU or all on CUDA."); auto *in_var = var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); - auto &in_sr = in_var->Get(); + PADDLE_ENFORCE_NOT_NULL(in_var); + + auto &in_sr_value = in_var->Get(); - PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(), + PADDLE_ENFORCE_EQ(in_sr_value.place().which(), pre_in_value.place().which(), + "Places must be all on CPU or all on GPU."); + PADDLE_ENFORCE_EQ(in_sr_value.value().type(), pre_in_value.value().type(), "The type of input is not consistent."); - PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(), + PADDLE_ENFORCE_EQ(in_sr_value.height(), pre_in_value.height(), "The height of inputs is not consistent."); - PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), + PADDLE_ENFORCE_EQ(in_sr_value.GetCompleteDims(), + pre_in_value.GetCompleteDims(), "The dims of inputs is not consistent."); - auto &in_sr_rows = in_sr.rows(); + auto &in_sr_rows = in_sr_value.rows(); out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end()); - - in_tensors.emplace_back(in_sr.value()); + in_tensors.emplace_back(in_sr_value.value()); } // write the output auto &out_place = out_var_handle->place_; - auto out_scope_idx = out_var_handle->scope_idx_; - auto out_var = var_scopes.at(out_scope_idx)->FindVar(out_var_handle->name_); - - auto out = out_var->GetMutable(); - out->set_height(pre_in.height()); - out->set_rows(out_rows); + PADDLE_ENFORCE_EQ(out_place.which(), pre_in_value.place().which(), + "Places must be all on CPU or all on GPU."); + auto out_var = + var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); + PADDLE_ENFORCE_NOT_NULL(out_var); + auto out_value = out_var->GetMutable(); + out_value->set_height(pre_in_value.height()); + out_value->set_rows(out_rows); size_t rows = out_rows.size(); - DDim out_dim = pre_in.GetCompleteDims(); + DDim out_dim = pre_in_value.GetCompleteDims(); out_dim[0] = static_cast(rows); - 
out->mutable_value()->Resize(out_dim); - out->mutable_value()->mutable_data(out_place, pre_in.value().type()); - Tensor *out_tensor = out->mutable_value(); + out_value->mutable_value()->Resize(out_dim); + out_value->mutable_value()->mutable_data(out_place, + pre_in_value.value().type()); + Tensor *out_tensor = out_value->mutable_value(); // copy auto dev_ctx = dev_ctxes_[out_place]; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index daba9bf2d..0b4a51807 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -11,9 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include +#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/scope.h" @@ -34,21 +36,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, - const std::vector &local_scopes, bool use_default_grad_scale, - platform::NCCLContextMap *nccl_ctxs) + const std::vector &local_scopes, + platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale, + bool use_nccl_allreduce) : loss_var_name_(loss_var_name), places_(places), local_scopes_(local_scopes), - nccl_ctxs_(nccl_ctxs) { + nccl_ctxs_(nccl_ctxs), + use_nccl_allreduce_(use_nccl_allreduce) { #else + MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, - const std::vector &local_scopes, bool use_default_grad_scale) + const std::vector &local_scopes, bool use_default_grad_scale, + bool use_nccl_allreduce) : loss_var_name_(loss_var_name), places_(places), - local_scopes_(local_scopes) { + local_scopes_(local_scopes), + use_nccl_allreduce_(use_nccl_allreduce) { #endif for (auto &p : params) { grad_names_.insert(GradVarName(p)); @@ -114,6 +121,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); + size_t cur_device_id = 0; + + std::vector> var_name_on_devices; + std::vector> bcast_var_name_set; + + var_name_on_devices.resize(places_.size()); + bcast_var_name_set.resize(places_.size()); + // Find "send" op first for split is in front of send. OpDesc *send_op = GetSendOpDesc(program); @@ -132,19 +147,44 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } is_forwarding = false; } else { - CreateComputationalOps(&result, *op, places_.size()); - if (!is_forwarding) { + int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); + if (op_dev_id == -1) { // var on all device + CreateComputationalOps(&result, *op, places_.size()); + } else { + CreateComputationalOp(&result, *op, op_dev_id); + for (auto &var_name : op->OutputArgumentNames()) { + var_name_on_devices[op_dev_id].emplace(var_name); + } + } + + if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. 
for (auto &og : op->OutputArgumentNames()) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { - InsertNCCLAllReduceOp(&result, og); + if (use_nccl_allreduce_) { + InsertNCCLAllReduceOp(&result, og); + } else { + CreateReduceOp(&result, cur_device_id, og); + var_name_on_devices[cur_device_id].emplace(og); + bcast_var_name_set[cur_device_id].emplace( + og.substr(0, og.size() - strlen(kGradVarSuffix))); + cur_device_id = (cur_device_id + 1) % places_.size(); + } } } } } + // Insert BCast Ops + for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(&result, bcast_name, dev_id); + } + } + /* Dependency graph has been constructed. However, there are still data harzaeds need to be handled. @@ -165,6 +205,60 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( return std::unique_ptr(graph); } +int MultiDevSSAGraphBuilder::GetOpDeviceID( + const std::vector> &var_name_on_devices, + const OpDesc &op) const { + if (use_nccl_allreduce_) { + return -1; + } + + int var_dev_id = -1; + for (auto &var_name : op.InputArgumentNames()) { + if (var_dev_id != -1) break; + for (size_t i = 0; i < var_name_on_devices.size(); ++i) { + if (var_name_on_devices[i].count(var_name)) { + var_dev_id = static_cast(i); + break; + } + } + } + return var_dev_id; +} + +void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, + const std::string &p_name, + size_t dev_id) const { +#ifdef PADDLE_WITH_CUDA + auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_); +#else + auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); +#endif + + result->ops_.emplace_back(op_handle); + auto *in = result->vars_.at(dev_id).at(p_name).back().get(); + op_handle->AddInput(in); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &vars = result->vars_.at(dev_id).at(p_name); + auto &p = places_[i]; + auto *out_var = new VarHandle(vars.size(), i, p_name, p); + vars.emplace_back(out_var); + op_handle->AddOutput(out_var); +#ifndef PADDLE_WITH_CUDA + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); +#endif + } +} + +void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, + const OpDesc &op, + int dev_id) const { + result->ops_.emplace_back( + new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); + CreateOpHandleIOs(result, op, dev_id); +} + OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( const ProgramDesc &program) const { for (auto *op : program.Block(0).AllOps()) { @@ -174,7 +268,6 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc( } return nullptr; } - void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA @@ -247,6 +340,35 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, } } +VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp( + SSAGraph *result, int dst_dev_id, const std::string &og) const { +#ifdef PADDLE_WITH_CUDA + result->ops_.emplace_back( + new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); +#else + result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_)); +#endif + auto *op_handle = result->ops_.back().get(); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &vars = result->vars_[i][og]; +#ifndef PADDLE_WITH_CUDA + auto &p = places_[i]; + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); +#endif + PADDLE_ENFORCE(!vars.empty()); + auto &prev_grad
= vars.back(); + op_handle->AddInput(prev_grad.get()); + } + auto &vars = result->vars_[dst_dev_id][og]; + auto var = + new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]); + vars.emplace_back(var); + op_handle->AddOutput(var); + return var; +} + void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, const OpDesc &op) const { auto &p = places_[0]; @@ -263,6 +385,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { return op.OutputArgumentNames().size() == 1 && op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); } + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index bad47458e..824349430 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -13,8 +13,8 @@ // limitations under the License. #pragma once - #include +#include #include #include "paddle/fluid/framework/details/ssa_graph_builder.h" @@ -27,6 +27,7 @@ class NCCLContextMap; namespace framework { class Scope; namespace details { + class MultiDevSSAGraphBuilder : public SSAGraphBuilder { public: #ifdef PADDLE_WITH_CUDA @@ -34,14 +35,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, - bool skip_scale_loss, - platform::NCCLContextMap *nccl_ctxs); + platform::NCCLContextMap *nccl_ctxs, + bool use_default_grad_scale, bool use_nccl_allreduce); #else MultiDevSSAGraphBuilder(const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, - bool use_default_grad_scale); + bool use_default_grad_scale, bool use_nccl_allreduce); #endif std::unique_ptr Build(const ProgramDesc &program) const override; @@ -59,6 +60,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { #ifdef PADDLE_WITH_CUDA platform::NCCLContextMap *nccl_ctxs_; #endif + bool use_nccl_allreduce_; bool use_default_grad_scale_; bool IsScaleLossOp(const OpDesc &op) const; @@ -74,6 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { size_t num_places) const; void CreateScaleLossGradOp(SSAGraph *result) const; + VarHandle *CreateReduceOp(SSAGraph *result, int dst_dev_id, + const std::string &og) const; + void CreateComputationalOp(SSAGraph *result, const OpDesc &op, + int dev_id) const; bool IsParameterGradientOnce( const std::string &og, @@ -81,6 +87,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; + void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, + size_t dev_id) const; + + int GetOpDeviceID( + const std::vector> &var_name_on_devices, + const OpDesc &op) const; + /** * Get send op in the global block of program. * nullptr if not found. diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 409e8f72b..f06cb024c 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -22,6 +22,7 @@ namespace framework { namespace details { void ReduceOpHandle::RunImpl() { + if (places_.size() == 1) return; // the input and output may have dummy var. 
auto in_var_handles = DynamicCast(inputs_); @@ -52,19 +53,18 @@ void ReduceOpHandle::RunImpl() { // Wait input done, this Wait is asynchronous operation WaitInputVarGenerated(in_var_handles); auto pre_place = in_0_handle->place_; - std::vector in_places; + std::vector in_places; // used to get dev_ctx auto pre_in_tensor = VariableVisitor::GetMutableTensor(pre_in_var); for (auto *in_handle : in_var_handles) { - auto in_p = in_handle->place_; - PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), - "Places must be all on CPU or all on CUDA."); - in_places.emplace_back(in_p); + in_places.emplace_back(in_handle->place_); auto in_var = var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); PADDLE_ENFORCE_NOT_NULL(in_var); auto in_tensor = VariableVisitor::GetMutableTensor(in_var); + PADDLE_ENFORCE_EQ(pre_in_tensor.place().which(), in_tensor.place().which(), + "Places must be all on CPU or all on GPU."); PADDLE_ENFORCE_EQ(in_tensor.type(), pre_in_tensor.type(), "The type of input is not consistent."); } @@ -84,11 +84,11 @@ void ReduceOpHandle::RunImpl() { std::vector lod_tensors = GetInputValues(in_var_handles, var_scopes); - if (paddle::platform::is_cpu_place(pre_place)) { + if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) { ReduceLoDTensor func(lod_tensors, out_var->GetMutable()); VisitDataType(ToDataType(lod_tensors[0]->type()), func); - } else if (paddle::platform::is_gpu_place(pre_place)) { + } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) { #ifdef PADDLE_WITH_CUDA auto pre_in = pre_in_var->Get(); VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var); diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 9746b3bdb..59731d348 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -55,7 +55,7 @@ struct ReduceOpHandle : public OpHandleBase { std::string Name() const override; - bool IsMultiDeviceTransfer() override { return false; }; + bool IsMultiDeviceTransfer() override { return true; }; protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 6a5675275..153874471 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -47,6 +47,17 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { } } +VarHandle *SSAGraphBuilder::GetLatestVarHandle(SSAGraph *graph, + const std::string &each_var_name, + size_t place_offset) { + auto &var_holders = graph->vars_[place_offset]; + auto &var_holder = var_holders[each_var_name]; + if (var_holder.empty()) { + return nullptr; + } + return var_holder.rbegin()->get(); +} + VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( SSAGraph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 64e5d9308..dafd4e8d6 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -48,6 +48,10 @@ class SSAGraphBuilder { const platform::Place &place, size_t place_offset); + static VarHandle *GetLatestVarHandle(SSAGraph *graph, + const std::string &each_var_name, + size_t place_offset); + // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph 
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 9f7fd69e6..99e5eb2b4 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -66,6 +66,8 @@ struct VarHandle : public VarHandleBase { return o.generated_op_ == generated_op_ && o.name_ == name_ && o.scope_idx_ == scope_idx_; } + + bool operator!=(const VarHandle& o) const { return !this->operator==(o); } }; // Dummy Variable. It is used to represent dependencies between operators diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 4712efeff..f45936182 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor( const std::unordered_set &bcast_vars, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, const std::vector &local_scopes, bool allow_op_delay, - bool use_default_grad_scale) + bool use_default_grad_scale, bool use_nccl_allreduce) : member_(new ParallelExecutorPrivate(places)) { member_->global_scope_ = scope; @@ -93,11 +93,11 @@ ParallelExecutor::ParallelExecutor( #ifdef PADDLE_WITH_CUDA details::MultiDevSSAGraphBuilder builder( member_->places_, loss_var_name, params, member_->local_scopes_, - use_default_grad_scale, member_->nccl_ctxs_.get()); + member_->nccl_ctxs_.get(), use_default_grad_scale, use_nccl_allreduce); #else - details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, - params, member_->local_scopes_, - use_default_grad_scale); + details::MultiDevSSAGraphBuilder builder( + member_->places_, loss_var_name, params, member_->local_scopes_, + use_default_grad_scale, use_nccl_allreduce); #endif auto graph = builder.Build(main_program); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index ecd107d81..b2e8ddd05 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -40,7 +40,8 @@ class ParallelExecutor { const ProgramDesc& main_program, const std::string& loss_var_name, Scope* scope, const std::vector& local_scopes, - bool allow_op_delay, bool use_default_grad_scale); + bool allow_op_delay, bool use_default_grad_scale, + bool use_nccl_allreduce); ~ParallelExecutor(); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c925686f8..4b4de6f20 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -502,11 +502,12 @@ All parameter, weight, gradient are variables in Paddle. const std::unordered_set &bcast_vars, const ProgramDesc &main_program, const std::string &loss_var_name, Scope *scope, std::vector &local_scopes, - bool allow_op_delay, bool use_default_grad_scale) { + bool allow_op_delay, bool use_default_grad_scale, + bool use_nccl_allreduce) { new (&self) ParallelExecutor( num_threads, use_event, places, params, bcast_vars, main_program, loss_var_name, scope, local_scopes, - allow_op_delay, use_default_grad_scale); + allow_op_delay, use_default_grad_scale, use_nccl_allreduce); }) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) // NOTE: even we return a vec* to Python use reference policy. 
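[Note, not part of the patch] When use_nccl_allreduce is False, the builder changes above place each parameter gradient on a device in round-robin order (CreateReduceOp), record which device now owns the updated parameter, and finally emit one broadcast per owned parameter (CreateBroadcastOp). The standalone Python sketch below mirrors only that placement logic; the function and variable names are illustrative, and "@GRAD" is assumed to be the gradient-variable suffix (kGradVarSuffix).

# Schematic sketch of the round-robin gradient placement used when
# use_nccl_allreduce is False. Illustrative code, not part of the patch.
def place_reduce_and_broadcast(grad_names, num_devices, grad_suffix="@GRAD"):
    """Return (grad -> device that reduces it, per-device params to broadcast)."""
    reduce_dev = {}
    bcast_params = [set() for _ in range(num_devices)]
    cur_device_id = 0
    for grad in grad_names:
        reduce_dev[grad] = cur_device_id
        # Strip the gradient suffix to recover the parameter name, mirroring
        # og.substr(0, og.size() - strlen(kGradVarSuffix)) in the builder.
        bcast_params[cur_device_id].add(grad[:-len(grad_suffix)])
        cur_device_id = (cur_device_id + 1) % num_devices
    return reduce_dev, bcast_params

if __name__ == "__main__":
    grads = ["fc_0.w_0@GRAD", "fc_0.b_0@GRAD", "emb.w_0@GRAD"]
    reduce_dev, bcast_params = place_reduce_and_broadcast(grads, num_devices=2)
    print(reduce_dev)    # {'fc_0.w_0@GRAD': 0, 'fc_0.b_0@GRAD': 1, 'emb.w_0@GRAD': 0} (order may vary)
    print(bcast_params)  # [set(['fc_0.w_0', 'emb.w_0']), set(['fc_0.b_0'])] (set order may vary)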
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index f4128dcbe..46c18c689 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -30,7 +30,8 @@ class ParallelExecutor(object): num_threads=None, allow_op_delay=False, share_vars_from=None, - use_default_grad_scale=True): + use_default_grad_scale=True, + use_nccl_allreduce=True): """ ParallelExecutor can run program in parallel. @@ -43,9 +44,17 @@ class ParallelExecutor(object): training. allow_op_delay(bool, default False): Whether to delay and buffer some operators together for scheduling or not, which may - improve performance in some cases, defalut False. + improve performance in some cases, default False. share_vars_from(ParallelExecutor, default None): If provied, it will share variables from the specified ParallelExecutor. + use_nccl_allreduce(bool, default True): Whether to use nccl allReduce + to aggregate gradients across devices. If set True, gradients are + exchanged with nccl allReduce, which does not support updating + sparse parameters. If set False, each parameter gradient is reduced + onto one device by reduce_op, the parameter is updated there, and + the updated parameter is then broadcast to the other devices by + broadcast_op; this distributes the updates evenly across devices + and supports updating sparse parameters. Default True. use_default_grad_scale(bool, default True): If set True, a default scale value equal to `1./device_count` would be multiplied to gradients of each device and scaled gradients would be @@ -93,7 +102,7 @@ class ParallelExecutor(object): if use_cuda: # Experiments on se-resnext shows that too many threads hurt # performance. Worth tunning for other models in the future. - num_threads = len(self._places) + num_threads = len(self._places) * 2 else: num_threads = min( len(self._places) * 2, multiprocessing.cpu_count()) @@ -129,7 +138,9 @@ class ParallelExecutor(object): scope, local_scopes, allow_op_delay, - use_default_grad_scale) + use_default_grad_scale, + use_nccl_allreduce) + self.scope = scope def run(self, fetch_list, feed=None, feed_dict=None): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py index c783a1424..8dc14b88b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py @@ -205,7 +205,8 @@ class TestParallelExecutorBase(unittest.TestCase): allow_op_delay=False, feed_dict=None, seed=None, - use_parallel_executor=True): + use_parallel_executor=True, + use_nccl_allreduce=True): def run_executor(exe, feed, fetch_list, program=None): if isinstance(exe, fluid.ParallelExecutor): res = exe.run(fetch_list=fetch_list, feed=feed) @@ -234,7 +235,10 @@ class TestParallelExecutorBase(unittest.TestCase): if use_parallel_executor: exe = fluid.ParallelExecutor( - True, loss_name=loss.name, allow_op_delay=allow_op_delay) + True, + loss_name=loss.name, + allow_op_delay=allow_op_delay, + use_nccl_allreduce=use_nccl_allreduce) else: exe = fluid.Executor(place=place) @@ -280,17 +284,25 @@ class TestMNIST(TestParallelExecutorBase): fluid.recordio_writer.convert_reader_to_recordio_file( './mnist.recordio', reader, feeder) - def test_simple_fc(self): + def check_simple_fc_convergence(self, use_nccl_allreduce=True): self.check_network_convergence(simple_fc_net) self.check_network_convergence(simple_fc_net, allow_op_delay=True) img = numpy.zeros(shape=[32, 784],
dtype='float32') label = numpy.ones(shape=[32, 1], dtype='int64') self.check_network_convergence( - simple_fc_net, feed_dict={"image": img, - "label": label}) + simple_fc_net, + feed_dict={"image": img, + "label": label}, + use_nccl_allreduce=use_nccl_allreduce) + + def test_simple_fc_with_nccl_allreduce(self): + self.check_simple_fc_convergence(True) - def test_simple_fc_parallel_accuracy(self): + def test_simple_fc_with_reduce_op(self): + self.check_simple_fc_convergence(False) + + def check_simple_fc_parallel_accuracy(self, use_nccl_allreduce=True): img = numpy.zeros(shape=[32, 784], dtype='float32') label = numpy.ones(shape=[32, 1], dtype='int64') single_first_loss, single_last_loss = self.check_network_convergence( @@ -304,20 +316,35 @@ class TestMNIST(TestParallelExecutorBase): seed=1000, feed_dict={"image": img, "label": label}, - use_parallel_executor=True) + use_parallel_executor=True, + use_nccl_allreduce=use_nccl_allreduce) for p_f in parallel_first_loss: self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6) for p_l in parallel_last_loss: self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6) - def test_batchnorm_fc(self): + def test_simple_fc_parallel_accuracy_with_nccl_allreduce(self): + self.check_simple_fc_parallel_accuracy(True) + + def test_simple_fc_parallel_accuracy_with_reduce_op(self): + self.check_simple_fc_parallel_accuracy(False) + + def check_batchnorm_fc_convergence(self, use_nccl_allreduce): self.check_network_convergence(fc_with_batchnorm) img = numpy.zeros(shape=[32, 784], dtype='float32') label = numpy.ones(shape=[32, 1], dtype='int64') self.check_network_convergence( - fc_with_batchnorm, feed_dict={"image": img, - "label": label}) + fc_with_batchnorm, + feed_dict={"image": img, + "label": label}, + use_nccl_allreduce=use_nccl_allreduce) + + def test_batchnorm_fc_with_nccl_allreduce(self): + self.check_batchnorm_fc_convergence(True) + + def test_batchnorm_fc_with_reduce_op(self): + self.check_batchnorm_fc_convergence(False) class TestResnet(TestParallelExecutorBase): @@ -339,14 +366,21 @@ class TestResnet(TestParallelExecutorBase): # fluid.recordio_writer.convert_reader_to_recordio_file( # "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress) - def test_resnet(self): + def check_resnet_convergence(self, use_nccl_allreduce): import functools batch_size = 2 self.check_network_convergence( functools.partial( SE_ResNeXt50Small, batch_size=batch_size), iter=20, - batch_size=batch_size) + batch_size=batch_size, + use_nccl_allreduce=use_nccl_allreduce) + + def test_resnet_with_nccl_allreduce(self): + self.check_resnet_convergence(True) + + def test_resnet_with_reduce_op(self): + self.check_resnet_convergence(False) class ModelHyperParams(object): @@ -510,7 +544,7 @@ class TestTransformer(TestParallelExecutorBase): class ParallelExecutorTestingDuringTraining(unittest.TestCase): - def test_parallel_testing(self): + def check_network_convergence(self, use_nccl_allreduce): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -531,12 +565,16 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): feed_dict = {'image': image, 'label': label} train_exe = fluid.ParallelExecutor( - use_cuda=True, loss_name=loss.name, main_program=main) + use_cuda=True, + loss_name=loss.name, + main_program=main, + use_nccl_allreduce=use_nccl_allreduce) test_exe = fluid.ParallelExecutor( use_cuda=True, main_program=test_program, - share_vars_from=train_exe) + 
share_vars_from=train_exe, + use_nccl_allreduce=use_nccl_allreduce) for i in xrange(5): test_loss, = test_exe.run([loss.name], feed=feed_dict) @@ -550,6 +588,12 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): "Train loss: " + str(train_loss) + "\n Test loss:" + str(test_loss)) + def test_parallel_testing_with_nccl_allreduce(self): + self.check_network_convergence(use_nccl_allreduce=True) + + def test_parallel_testing_with_reduce_op(self): + self.check_network_convergence(use_nccl_allreduce=False) + import paddle.dataset.conll05 as conll05 import paddle.fluid as fluid @@ -568,21 +612,26 @@ embedding_name = 'emb' def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, - **ignored): + is_sparse, use_nccl_allreduce, **ignored): # 8 features predicate_embedding = fluid.layers.embedding( input=predicate, + is_sparse=is_sparse, size=[pred_dict_len, word_dim], dtype='float32', param_attr='vemb') mark_embedding = fluid.layers.embedding( - input=mark, size=[mark_dict_len, mark_dim], dtype='float32') + input=mark, + is_sparse=is_sparse, + size=[mark_dict_len, mark_dim], + dtype='float32') word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2] emb_layers = [ fluid.layers.embedding( size=[word_dict_len, word_dim], + is_sparse=is_sparse, input=x, param_attr=fluid.ParamAttr( name=embedding_name, trainable=False)) for x in word_input @@ -632,7 +681,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, class TestCRFModel(unittest.TestCase): - def test_all(self): + def check_network_convergence(self, is_sparse, use_nccl_allreduce): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -652,6 +701,7 @@ class TestCRFModel(unittest.TestCase): name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1) mark = fluid.layers.data( name='mark_data', shape=[1], dtype='int64', lod_level=1) + feature_out = db_lstm(**locals()) target = fluid.layers.data( name='target', shape=[1], dtype='int64', lod_level=1) @@ -679,7 +729,10 @@ class TestCRFModel(unittest.TestCase): exe = fluid.Executor(place) exe.run(startup) - pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) + pe = fluid.ParallelExecutor( + use_cuda=True, + loss_name=avg_cost.name, + use_nccl_allreduce=use_nccl_allreduce) feeder = fluid.DataFeeder( feed_list=[ @@ -694,3 +747,13 @@ class TestCRFModel(unittest.TestCase): print map(numpy.array, pe.run(feed=feeder.feed(cur_batch), fetch_list=[avg_cost.name]))[0] + + def test_update_sparse_parameter(self): + self.check_network_convergence(is_sparse=True, use_nccl_allreduce=False) + + def test_update_dense_parameter_with_nccl_allreduce(self): + self.check_network_convergence(is_sparse=False, use_nccl_allreduce=True) + + def test_update_dense_parameter_with_reduce_op(self): + self.check_network_convergence( + is_sparse=False, use_nccl_allreduce=False) -- GitLab
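[Usage note, not part of the patch] The example below shows how the new use_nccl_allreduce flag is intended to be used from Python. It is a minimal sketch assuming a CUDA build of this branch with at least one GPU; the toy embedding network, dictionary size, and batch size are illustrative, while the ParallelExecutor arguments follow the unit tests above. With use_nccl_allreduce=False, each gradient is reduced onto one device, the parameter is updated there, and the updated parameter is broadcast back, which is what allows the sparse (SelectedRows) gradients produced by is_sparse=True embeddings to be handled.

import numpy
import paddle.fluid as fluid

# Build a tiny network with a sparse embedding (illustrative sizes).
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    word = fluid.layers.data(name='word', shape=[1], dtype='int64')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    emb = fluid.layers.embedding(input=word, size=[10000, 32], is_sparse=True)
    prediction = fluid.layers.fc(input=emb, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))
    fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

place = fluid.CUDAPlace(0)
fluid.Executor(place).run(startup)

# use_nccl_allreduce=False selects the reduce_op/broadcast_op path added by
# this patch, which supports SelectedRows (sparse) gradients.
pe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name=loss.name,
    main_program=main,
    use_nccl_allreduce=False)

word_data = numpy.random.randint(0, 10000, size=(64, 1)).astype('int64')
label_data = numpy.random.randint(0, 10, size=(64, 1)).astype('int64')
loss_val, = pe.run([loss.name],
                   feed={'word': word_data, 'label': label_data})
print(loss_val)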