diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 89b5c6847f15b3f2a270fe1e7db9e590549e8982..caaf418076bdd43a2d989c0ac318dbba85fa313c 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope
            lod_tensor ddim memory dynload_cuda)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
 
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 128a5344fbb8c64c36ade24475bd0d99bdb3e0f5..bea9489bbd353048db384fcb2e6baf2ee71c5b77 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/scope.h"
 
 #ifdef PADDLE_WITH_CUDA
@@ -34,26 +35,46 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs)
+    platform::NCCLContextMap *nccl_ctxs, bool distributed)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
+      distributed_(distributed),
       nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes)
+    const std::vector<Scope *> &local_scopes, bool distributed)
     : loss_var_name_(loss_var_name),
       places_(places),
-      local_scopes_(local_scopes) {
+      local_scopes_(local_scopes),
+      distributed_(distributed) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
 }
 
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
+                                                const platform::Place &p,
+                                                const size_t &i) const {
+  auto *op_handle = result->ops_.back().get();
+
+  auto var_names = op->InputArgumentNames();
+
+  for (auto &each_var_name : var_names) {
+    VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
+    op_handle->AddInput(var);
+  }
+  var_names = op->OutputArgumentNames();
+
+  for (auto &each_var_name : var_names) {
+    CreateOpOutput(result, op_handle, each_var_name, p, i);
+  }
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
@@ -72,6 +93,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
     }
 
+    // append send op if program is distributed trainer main program.
+    // always use the first device
+    if (is_forwarding && distributed_ && op->Type() == "send") {
+      auto &p = places_[0];
+      auto *s = local_scopes_[0];
+      size_t i = 0;
+      result.ops_.emplace_back(new SendOpHandle(*op, s, p));
+      CreateOpHandleIOs(&result, op, p, i);
+      continue;
+    }
+
     for (size_t i = 0; i < places_.size(); ++i) {
       auto &p = places_[i];
       auto *s = local_scopes_[i];
@@ -81,18 +113,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(p));
 
-      auto var_names = op->InputArgumentNames();
+      CreateOpHandleIOs(&result, op, p, i);
+      // auto var_names = op->InputArgumentNames();
 
-      for (auto &each_var_name : var_names) {
-        VarHandle *var =
-            CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
-        op_handle->AddInput(var);
-      }
-      var_names = op->OutputArgumentNames();
+      // for (auto &each_var_name : var_names) {
+      //   VarHandle *var =
+      //       CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
+      //   op_handle->AddInput(var);
+      // }
+      auto var_names = op->OutputArgumentNames();
 
-      for (auto &each_var_name : var_names) {
-        CreateOpOutput(&result, op_handle, each_var_name, p, i);
-      }
+      // for (auto &each_var_name : var_names) {
+      //   CreateOpOutput(&result, op_handle, each_var_name, p, i);
+      // }
 
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index d3c8e582cf2cdf26198822e4bd2602883622df21..004d6d50ab8e21888341072782cd430f3d41c1b8 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 
 namespace paddle {
@@ -31,21 +34,28 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs);
+                          platform::NCCLContextMap *nccl_ctxs,
+                          bool distributed = false);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes);
+                          const std::vector<Scope *> &local_scopes,
+                          bool distributed = false);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
 
+ private:
+  void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
+                         const size_t &i) const;
+
  private:
   std::string loss_var_name_;
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
+  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
index bd2a0a9c2981bafee361c687058807bf78996a68..ae5637b804525c8753dd25024ecddd1d08f2d747 100644
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ b/paddle/fluid/framework/details/send_op_handle.cc
@@ -18,61 +18,24 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-SendOpHandle::SendOpHandle(const std::vector<Scope *> &local_scopes,
-                           const std::vector<platform::Place> &places,
-                           const platform::NCCLContextMap &ctxs)
-    : local_scopes_(local_scopes), places_(places) {}
+SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
+                           const Scope *local_scope,
+                           const platform::Place &place)
+    : op_(framework::OpRegistry::CreateOp(op_desc)),
+      local_scope_(local_scope),
+      place_(place) {}
 
 void SendOpHandle::RunImpl() {
-  if (inputs_.size() == 1) {
-    return;  // No need to all reduce when GPU count = 1;
-  } else {
-    // Wait input done
-    for (auto *in : inputs_) {
-      auto &p = static_cast<VarHandle *>(in)->place_;
-      in->generated_op_->Wait(dev_ctxes_[p]);
-    }
-
-    auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
-    int dtype = -1;
-    size_t numel = 0;
-
-    std::vector<std::function<void()>> all_reduce_calls;
-
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto &p = places_[i];
-      auto *s = local_scopes_[i];
-      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-
-      auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
-      void *buffer = const_cast<void *>(lod_tensor.data<void>());
-
-      if (dtype == -1) {
-        dtype = platform::ToNCCLDataType(lod_tensor.type());
-      }
-
-      if (numel == 0) {
-        numel = static_cast<size_t>(lod_tensor.numel());
-      }
-
-      auto &nccl_ctx = nccl_ctxs_.at(dev_id);
-      auto stream = nccl_ctx.stream();
-      auto comm = nccl_ctx.comm_;
-      all_reduce_calls.emplace_back([=] {
-        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-            comm, stream));
-      });
-    }
-
-    platform::NCCLGroupGuard guard;
-    for (auto &call : all_reduce_calls) {
-      call();
-    }
+  // Wait input done
+  for (auto *in : inputs_) {
+    auto &p = static_cast<VarHandle *>(in)->place_;
+    in->generated_op_->Wait(dev_ctxes_[p]);
   }
+
+  op_->Run(*local_scope_, place_);
 }
 
-std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+std::string SendOpHandle::Name() const { return "send"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/send_op_handle.h b/paddle/fluid/framework/details/send_op_handle.h
index 515f1a10a8d90cfea82f2525519a62bb789fe419..e7857c1f234fc4617462b8b065cfc4ea68e8c3aa 100644
--- a/paddle/fluid/framework/details/send_op_handle.h
+++ b/paddle/fluid/framework/details/send_op_handle.h
@@ -19,6 +19,8 @@
 
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 
@@ -27,19 +29,18 @@ namespace framework {
 namespace details {
 
 struct SendOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-  const platform::NCCLContextMap &nccl_ctxs_;
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
 
-  SendOpHandle(const std::vector<Scope *> &local_scopes,
-               const std::vector<platform::Place> &places,
-               const platform::NCCLContextMap &ctxs);
+  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
+               const platform::Place& place);
 
   std::string Name() const override;
 
   // Delay and buffer nccl_all_reduce together can significantly increase
   // performance. Disable this feature by returning false.
-  bool IsMultiDeviceTransfer() override { return true; };
+  bool IsMultiDeviceTransfer() override { return false; };
 
  protected:
   void RunImpl() override;
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 39d4017861f4d2ac2e8e85c3d70440a43e6cdc71..8bd9161fcb2c38fb71e4493afec2095c1b9833dd 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -951,6 +951,13 @@ class Block(object):
         if var.type == core.VarDesc.VarType.STEP_SCOPES:
             ret_var = self.create_var(
                 name=var.name, persistable=var.persistable, type=var.type)
+        elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            ret_var = self.create_var(
+                name=var.name,
+                shape=var.shape,
+                dtype=var.dtype,
+                type=var.type,
+                persistable=True)
         else:
             ret_var = self.create_var(
                 name=var.name,
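
The snippet below is an illustrative sketch, not part of the patch: it shows how a caller might pass the new distributed flag (default false) when constructing MultiDevSSAGraphBuilder and building the SSA graph for a distributed trainer. The BuildTrainerGraph helper and the way the NCCL context map is obtained are assumptions made for the example; only the builder constructor and Build() signatures come from the changes above.

// Sketch only: threading the new `distributed` flag through graph building.
// BuildTrainerGraph is a hypothetical helper; nccl_ctxs would normally be
// owned by the executor that manages the devices.
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"

namespace paddle {
namespace framework {
namespace details {

std::unique_ptr<SSAGraph> BuildTrainerGraph(
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &params,
    const std::vector<Scope *> &local_scopes,
    const ProgramDesc &main_program, bool is_distributed_trainer) {
#ifdef PADDLE_WITH_CUDA
  platform::NCCLContextMap *nccl_ctxs = nullptr;  // placeholder for example
  MultiDevSSAGraphBuilder builder(places, loss_var_name, params, local_scopes,
                                  nccl_ctxs, is_distributed_trainer);
#else
  MultiDevSSAGraphBuilder builder(places, loss_var_name, params, local_scopes,
                                  is_distributed_trainer);
#endif
  // With is_distributed_trainer == true, each "send" op in the main program
  // is lowered to a single SendOpHandle pinned to places[0] and
  // local_scopes[0], instead of being replicated across all devices.
  return builder.Build(main_program);
}

}  // namespace details
}  // namespace framework
}  // namespace paddle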