diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 4c5384af613752251fd8336874d86474ffb9515b..a86542efec2118bad62fe7cb49620d5004373fac 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -186,10 +186,12 @@ std::unique_ptr BuildStrategy::Apply( #else const bool use_cuda) const { #endif + VLOG(3) << "apply all passes"; // Create a default one if not finalized by user. CreatePassesFromStrategy(false); for (std::shared_ptr &pass : pass_builder_->AllPasses()) { + VLOG(3) << "apply " << pass->Type(); if (IsMultiDevPass(pass->Type())) { pass->Erase(kPlaces); pass->SetNotOwned>(kPlaces, &places); diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 33ccee6aa0a94b8fd8308214d6144ae832d40bab..823697495edf32c9f2d6339f44f84be551f6474c 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,6 +19,7 @@ namespace paddle { namespace framework { namespace ir { std::unique_ptr Pass::Apply(std::unique_ptr graph) const { + VLOG(3) << "apply pass -> " << Type(); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 87f0f307d30bc90a43a698c3766b16c975f0635e..b690829216920e013d25b7003b9ef2067bbac1f7 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -77,6 +77,8 @@ Scope& Scope::NewScope() const { return *child; } +Scope* Scope::NewTmpScope() const { return new Scope(this); } + Variable* Scope::Var(const std::string& name) { SCOPE_VARS_WRITER_LOCK return VarInternal(name); diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index f0915d2eee072b0bcd53f37dad5ef9d801c87172..0e9b8edeb3e5f1afdfef83520e5d602b9b7bad43 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -55,6 +55,8 @@ class Scope { /// Mark it to const because that new kid scope cannot change parent scope. Scope& NewScope() const; + Scope* NewTmpScope() const; + /// Create a variable with given name if it doesn't exist. /// Caller doesn't own the returned Variable. Variable* Var(const std::string& name); diff --git a/paddle/fluid/framework/variable_helper.cc b/paddle/fluid/framework/variable_helper.cc index fc4525549caeebb06dea766ccb123b5ebc6d5b13..d59f3ea7dcc2fc7e63355c3d063873f42e5310ca 100644 --- a/paddle/fluid/framework/variable_helper.cc +++ b/paddle/fluid/framework/variable_helper.cc @@ -27,7 +27,7 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -void InitializeVariable(Variable* var, proto::VarType::Type var_type) { +void InitializeVariable(Variable *var, proto::VarType::Type var_type) { if (var_type == proto::VarType::LOD_TENSOR) { var->GetMutable(); } else if (var_type == proto::VarType::SELECTED_ROWS) { @@ -37,7 +37,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { } else if (var_type == proto::VarType::FETCH_LIST) { var->GetMutable(); } else if (var_type == proto::VarType::STEP_SCOPES) { - var->GetMutable>(); + var->GetMutable>(); } else if (var_type == proto::VarType::LOD_RANK_TABLE) { var->GetMutable(); } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { @@ -56,5 +56,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { var_type); } } + +void CopyVariable(const Variable &src_var, Variable *dst_var) { + // only support cpu now + auto cpu_place = platform::CPUPlace(); + + if (src_var.IsType()) { + auto *tmp_grad_tensor = dst_var->GetMutable(); + auto &src_tensor = src_var.Get(); + tmp_grad_tensor->set_lod(src_tensor.lod()); + framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor); + } else if (src_var.IsType()) { + auto &src_slr = src_var.Get(); + auto *tmp_grad_slr = dst_var->GetMutable(); + tmp_grad_slr->set_rows(src_slr.rows()); + tmp_grad_slr->set_height(src_slr.height()); + auto &src_t = src_slr.value(); + auto *dst_t = tmp_grad_slr->mutable_value(); + framework::TensorCopy(src_t, cpu_place, dst_t); + } else { + PADDLE_THROW("unknown var type to copy"); + } +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/variable_helper.h b/paddle/fluid/framework/variable_helper.h index 0e0c72c3621dce0a6b372f9a9110a63fbc0a1d71..f8e90d53967524c37e68f7dc647fb597322aa3c4 100644 --- a/paddle/fluid/framework/variable_helper.h +++ b/paddle/fluid/framework/variable_helper.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/variable.h" namespace paddle { namespace framework { -void InitializeVariable(Variable *var, proto::VarType::Type var_type); +void InitializeVariable(Variable* var, proto::VarType::Type var_type); +void CopyVariable(const Variable& src_var, Variable* dst_var); } } diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index fc28fe818dc0bd2a8607118c015b6b5fd168fb43..1301467fa74cb98c8d2cd87aca1448af673bb39b 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(WITH_GRPC) else() set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc) - set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib) @@ -50,8 +50,11 @@ endif() cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL) -cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler) +cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) +cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) +cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) +cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool) if(WITH_GPU) cc_test(collective_server_test SRCS collective_server_test.cc DEPS sendrecvop_rpc executor ${RPC_DEPS} diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc new file mode 100644 index 0000000000000000000000000000000000000000..a88b76447488c4e04bbbf0419120e903ddeba1f0 --- /dev/null +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/distributed/communicator.h" + +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/distributed/parameter_recv.h" +#include "paddle/fluid/operators/distributed/parameter_send.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace paddle { +namespace operators { +namespace distributed { + +static inline void MergeVars(const std::string &var_name, + const std::vector> &vars, + Scope *scope) { + PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); + auto cpu_place = platform::CPUPlace(); + auto &var0 = vars[0]; + auto *out_var = scope->Var(var_name); + if (var0->IsType()) { + auto *out_t = out_var->GetMutable(); + auto *out_ptr = out_t->mutable_data( + var0->Get().dims(), cpu_place); + auto numel = out_t->numel(); + for (auto i = 0; i < numel; ++i) { + out_ptr[i] = 0; + for (auto &var : vars) { + auto &var_t = var->Get(); + PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims"); + out_ptr[i] += var_t.data()[i]; + } + } + } else if (var0->IsType()) { + auto *out_slr = out_var->GetMutable(); + out_slr->mutable_rows()->clear(); + out_slr->mutable_value()->mutable_data({{}}, cpu_place); + std::vector inputs; + inputs.reserve(vars.size()); + for (auto &var : vars) { + inputs.push_back(&var->Get()); + } + math::scatter::MergeAdd + merge_add; + auto dev_ctx = paddle::platform::CPUDeviceContext(); + merge_add(dev_ctx, inputs, out_slr, false); + } else { + PADDLE_THROW("unsupported var type!"); + } +} + +void Communicator::SendThread() { + while (running_) { + std::vector> task_futures; + task_futures.reserve(send_varname_to_ctx_.size()); + for (auto &iter : send_varname_to_queue_) { + auto send_task = [this, &iter] { + auto &var_name = iter.first; + VLOG(3) << "merge var " << var_name << " and send"; + auto &var_queue = iter.second; + std::vector> vars; + // TODO(qiao): need to be configurable + const size_t max_merge_var_num = 20; + size_t merged_var_num = 0; + while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) { + vars.push_back(var_queue->Pop()); + merged_var_num++; + } + MergeVars(var_name, vars, send_scope_.get()); + auto send_functor = distributed::ParameterSend(); + auto &ctx = send_varname_to_ctx_.at(var_name); + send_functor(ctx, *send_scope_, true); + }; + task_futures.emplace_back( + send_threadpool_->enqueue(std::move(send_task))); + } + for (auto &task_f : task_futures) { + task_f.wait(); + } + } +} + +void Communicator::RecvThread() { + while (running_) { + // parallel run recv graph + std::vector> task_futures; + task_futures.reserve(recv_varname_to_ctx_.size()); + for (auto &iter : recv_varname_to_ctx_) { + auto recv_task = [this, &iter] { + auto &var_name = iter.first; + VLOG(3) << "recv var " << var_name; + auto recv_functor = distributed::ParameterRecv(); + recv_functor(iter.second, *recv_scope_); + }; + task_futures.emplace_back( + recv_threadpool_->enqueue(std::move(recv_task))); + } + for (auto &task : task_futures) { + task.wait(); + } + } +} + +void Communicator::Send(const std::string &var_name, + const framework::Scope &scope) { + // push var into send queue by var_name + auto *grad_var = scope.FindVar(var_name); + PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited"); + auto tmp_grad_var = std::make_shared(); + framework::CopyVariable(*grad_var, tmp_grad_var.get()); + send_varname_to_queue_[var_name]->Push(tmp_grad_var); +} 
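// --- Editor's note, not part of the patch: a hedged sketch of how a trainer
// side could drive the Communicator defined above. The RpcContext map value
// types are inferred from the member declarations in communicator.h (the
// flattened diff drops template arguments), and the variable names, endpoints
// and section sizes below are placeholders for illustration only.
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"

void WireUpCommunicator(paddle::framework::Scope *global_scope) {
  using paddle::operators::distributed::Communicator;
  using paddle::operators::distributed::RpcContext;

  // One context per gradient to push: original var name, its split pieces,
  // one endpoint per piece, and the per-piece heights (height_sections).
  std::unordered_map<std::string, RpcContext> send_ctxs;
  send_ctxs.emplace("w@GRAD",
                    RpcContext("w@GRAD", {"w@GRAD.block0", "w@GRAD.block1"},
                               {"127.0.0.1:6174", "127.0.0.1:6175"},
                               {64, 64}));

  // One context per parameter to pull back; the recv path ignores
  // height_sections (recv_op.cc passes {}), so they stay empty here.
  std::unordered_map<std::string, RpcContext> recv_ctxs;
  recv_ctxs.emplace("w", RpcContext("w", {"w.block0", "w.block1"},
                                    {"127.0.0.1:6174", "127.0.0.1:6175"},
                                    {}));

  Communicator comm(send_ctxs, recv_ctxs, global_scope);
  comm.Start();  // spawns SendThread and RecvThread

  // After each backward pass the trainer enqueues a copy of the gradient
  // (assuming "w@GRAD" exists and is initialized in *global_scope);
  // SendThread later merges up to max_merge_var_num queued copies with
  // MergeVars and ships them via ParameterSend.
  comm.Send("w@GRAD", *global_scope);
}  // ~Communicator flips running_ to false and joins both threads.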
+ +void Communicator::Start() { + running_ = true; + // start send and recv thread + send_thread_.reset( + new std::thread(std::bind(&Communicator::SendThread, this))); + recv_thread_.reset( + new std::thread(std::bind(&Communicator::RecvThread, this))); +} + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h new file mode 100644 index 0000000000000000000000000000000000000000..ffdfa38b12fc6240d31aeaa9d95fe498805548b9 --- /dev/null +++ b/paddle/fluid/operators/distributed/communicator.h @@ -0,0 +1,143 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/distributed/rpc_common.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { +namespace distributed { + +using Scope = framework::Scope; +using Variable = framework::Variable; + +template +class BlockingQueue { + public: + explicit BlockingQueue(size_t capacity) : capacity_(capacity) { + PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0."); + } + + bool Push(const T& elem) { + std::unique_lock lock(mutex_); + send_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + PADDLE_ENFORCE_LT(queue_.size(), capacity_); + queue_.push_back(elem); + recv_cv_.notify_one(); + return true; + } + + bool Push(T&& elem) { + std::unique_lock lock(mutex_); + send_cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + PADDLE_ENFORCE_LT(queue_.size(), capacity_); + queue_.emplace_back(std::move(elem)); + recv_cv_.notify_one(); + return true; + } + + T Pop() { + std::unique_lock lock(mutex_); + recv_cv_.wait(lock, [=] { return !queue_.empty(); }); + T rc(std::move(queue_.front())); + queue_.pop_front(); + return rc; + } + + size_t Cap() const { + std::lock_guard lock(mutex_); + return capacity_; + } + + size_t Size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + private: + const size_t capacity_; + std::deque queue_; + + mutable std::mutex mutex_; + std::condition_variable recv_cv_; + std::condition_variable send_cv_; +}; + +class Communicator { + public: + Communicator( + const std::unordered_map& send_varname_to_ctx, + const std::unordered_map& recv_varname_to_ctx, + Scope* recv_scope) + : send_varname_to_ctx_(send_varname_to_ctx), + recv_varname_to_ctx_(recv_varname_to_ctx), + recv_scope_(recv_scope) { + // get all send information from graph, build vars_to_send + send_scope_.reset(new Scope()); + for (auto& iter : send_varname_to_ctx_) { + send_varname_to_queue_[iter.first] = + std::make_shared>>(10); + } + // TODO(qiao): default 5, need to config + send_threadpool_.reset(new ::ThreadPool(5)); + 
recv_threadpool_.reset(new ::ThreadPool(5)); + } + + ~Communicator() { + VLOG(3) << "~Communicator"; + running_ = false; + send_thread_->join(); + recv_thread_->join(); + VLOG(3) << "~Communicator done"; + } + + void Start(); + + // send grad + void Send(const std::string& var_name, const framework::Scope& scope); + + private: + void SendThread(); + void RecvThread(); + + bool running_ = false; + std::unordered_map>>> + send_varname_to_queue_; + std::unordered_map send_varname_to_ctx_; + std::unordered_map recv_varname_to_ctx_; + std::unique_ptr send_thread_; + std::unique_ptr recv_thread_; + Scope* recv_scope_; // should be global scope + std::unique_ptr send_scope_; // an independent scope + std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; + std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index 4a9c158cb0ab7f2d6fecbba9f957ae6ef153074c..a0ed79201d5d95252b293f4c118f8b75e5cc5b25 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -107,6 +107,9 @@ class RequestSend final : public RequestBase { int trainer_id = request_->GetTrainerId(); framework::Variable* outvar = nullptr; + if (!request_handler_->sync_mode()) { + request_->ReleaseOwnershipOfLocalScope(); + } request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); Finish(reply_, &responder_); } diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index c63d65348880ebb4085d83059d9fead6456216d7..539a038099752d110f0576c1db7ef0bad75e733c 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -37,30 +37,9 @@ using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; -static size_t GetSectionIndex(int64_t id, - const std::vector& abs_sections) { - for (size_t i = 1; i < abs_sections.size(); ++i) { - if (id < abs_sections[i]) { - return i - 1; - } - } - return abs_sections.size() - 1; -} - -static std::vector ToAbsoluteSection( - const std::vector& height_sections) { - std::vector abs_sections; - abs_sections.resize(height_sections.size()); - abs_sections[0] = 0; - for (size_t i = 1; i < height_sections.size(); ++i) { - abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1]; - } - return abs_sections; -} - static std::vector> SplitIds( const std::vector& ids_vector, - const std::vector& height_section, framework::Scope* scope) { + const std::vector& height_section) { std::set all_ids; for (auto id : ids_vector) { all_ids.insert(id); @@ -78,7 +57,7 @@ static std::vector> SplitIds( static void SplitIdsIntoMultipleVarsBySection( const std::vector& in_var_names, - const std::vector& height_section, + const std::vector& height_section, const std::vector>& splited_ids, framework::Scope* scope) { PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); @@ -100,7 +79,7 @@ static void SplitIdsIntoMultipleVarsBySection( static void MergeMultipleVarsIntoOneBySection( const std::string& id_name, const std::vector& ids_vector, const std::string& out_name, const std::vector& out_var_names, - const std::vector& height_section, + const std::vector& height_section, const std::vector>& splited_ids, const 
framework::ExecutionContext& context, framework::Scope* scope, platform::DeviceContext* actual_ctx) { @@ -177,10 +156,10 @@ static void MergeMultipleVarsIntoOneBySection( void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& table_names, const std::vector& epmap, - const std::vector& height_sections, + const std::vector& height_sections, const framework::ExecutionContext& context, const framework::Scope& scope) { - auto& local_scope = scope.NewScope(); + framework::Scope* local_scope = scope.NewTmpScope(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& cpu_ctx = *pool.Get(platform::CPUPlace()); @@ -224,22 +203,22 @@ void prefetch(const std::string& id_name, const std::string& out_name, #endif } - auto splited_ids = SplitIds(ids_vector, height_sections, &local_scope); + auto splited_ids = SplitIds(ids_vector, height_sections); SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, - &local_scope); + local_scope); // create output var in local scope for (auto& name : out_var_names) { - local_scope.Var(name)->GetMutable(); + local_scope->Var(name)->GetMutable(); } std::vector rets; for (size_t i = 0; i < in_var_names.size(); i++) { - if (NeedSend(local_scope, in_var_names[i])) { + if (NeedSend(*local_scope, in_var_names[i])) { VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i] << " to get " << out_var_names[i] << " back"; rets.push_back(rpc_client->AsyncPrefetchVar( - epmap[i], cpu_ctx, local_scope, in_var_names[i], out_var_names[i], + epmap[i], cpu_ctx, *local_scope, in_var_names[i], out_var_names[i], table_names[i])); } else { VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; @@ -252,8 +231,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, out_var_names, height_sections, splited_ids, - context, &local_scope, &actual_ctx); - scope.DeleteScope(&local_scope); + context, local_scope, &actual_ctx); + delete local_scope; } }; // namespace distributed diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h index 2f850a0332256d458e79ed9da361c86eb8a2f780..0429ec4415dca19ff620cd7af5a8c0a935e17e2f 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ b/paddle/fluid/operators/distributed/parameter_prefetch.h @@ -26,7 +26,7 @@ namespace distributed { void prefetch(const std::string& id_name, const std::string& out_name, const std::vector& table_names, const std::vector& epmap, - const std::vector& height_sections, + const std::vector& height_sections, const framework::ExecutionContext& context, const framework::Scope& scope); @@ -35,7 +35,7 @@ void prefetch_with_reconstruct(const std::string& id_name, const std::string& out_name, const std::vector& table_names, const std::vector& epmap, - const std::vector& height_sections, + const std::vector& height_sections, const framework::ExecutionContext& context, const framework::Scope& scope, framework::LoDTensor* original) { diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc new file mode 100644 index 0000000000000000000000000000000000000000..fecc76955de5c024d624e1cfba438901a886d1b8 --- /dev/null +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/operators/distributed/parameter_recv.h" + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor.h" + +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/rpc_client.h" +#include "paddle/fluid/operators/distributed/variable_response.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" +#include "paddle/fluid/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { +namespace distributed { + +using LoDTensor = framework::LoDTensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +template +void ParameterRecv::operator()(const RpcContext &rpc_ctx, + const framework::Scope &scope) { + framework::Scope *local_scope = scope.NewTmpScope(); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &cpu_ctx = *pool.Get(platform::CPUPlace()); + + distributed::RPCClient *rpc_client = + distributed::RPCClient::GetInstance(0); + + auto *recv_var = scope.FindVar(rpc_ctx.var_name); + + std::vector recved_tensors; + + // recv all vars to local scope + if (recv_var->IsType()) { + std::vector rets; + for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { + auto &recv_var_name = rpc_ctx.splited_var_names[i]; + framework::Tensor *t = + local_scope->Var(recv_var_name)->GetMutable(); + recved_tensors.push_back(t); + VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; + rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, + *local_scope, recv_var_name, + recv_var_name)); + } + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } + } else { + PADDLE_THROW("unsupported var type to recv!"); + } + + // concat recved tensor into one var + { + size_t output_offset = 0; + framework::Tensor *recv_tensor = + recv_var->GetMutable(); + auto dev_ctx = paddle::platform::CPUDeviceContext(); + for (auto *in : recved_tensors) { + auto in_stride = framework::stride_numel(in->dims()); + auto out_stride = framework::stride_numel(recv_tensor->dims()); + StridedNumelCopyWithAxis( + dev_ctx, 0, recv_tensor->data() + output_offset, out_stride, + in->data(), in_stride, in_stride[0]); + output_offset += in_stride[0]; + } + } + + delete local_scope; +} + +template struct ParameterRecv; + +}; // namespace distributed +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_recv.h b/paddle/fluid/operators/distributed/parameter_recv.h new file mode 100644 index 0000000000000000000000000000000000000000..e955fca7250ecc88f3b1a08611f380da50df788d --- /dev/null +++ b/paddle/fluid/operators/distributed/parameter_recv.h @@ -0,0 +1,34 @@ +// Copyright (c) 2018 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/distributed/rpc_common.h" + +namespace paddle { +namespace operators { +namespace distributed { + +template +struct ParameterRecv { + void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope); +}; + +}; // namespace distributed +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc new file mode 100644 index 0000000000000000000000000000000000000000..3fe3be193a3d01c07b0a7f1751a0d8cd526fd6d5 --- /dev/null +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "paddle/fluid/operators/distributed/parameter_send.h" + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor.h" + +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/rpc_client.h" +#include "paddle/fluid/operators/distributed/variable_response.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" + +namespace paddle { +namespace operators { +namespace distributed { + +using LoDTensor = framework::LoDTensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; +using DDim = framework::DDim; + +template +void ParameterSend::operator()(const RpcContext &rpc_ctx, + const framework::Scope &scope, bool sync) { + framework::Scope *local_scope = scope.NewTmpScope(); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &cpu_ctx = *pool.Get(platform::CPUPlace()); + + distributed::RPCClient *rpc_client = + distributed::RPCClient::GetInstance(0); + + auto *send_var = scope.FindVar(rpc_ctx.var_name); + size_t out_num = rpc_ctx.splited_var_names.size(); + if (send_var->IsType()) { + if (out_num > 1) { + auto &send_tensor = send_var->Get(); + auto &send_tensor_dims = send_tensor.dims(); + std::vector outs_dims; + outs_dims.reserve(out_num); + + // infer output shape + PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num, + "tensor split sections size" + "should be equal to output size."); + for (size_t i = 0; i < out_num; ++i) { + auto dim = send_tensor_dims; + dim[0] = rpc_ctx.height_sections[i]; + outs_dims.push_back(dim); + } + + // create output var in local scope + size_t row_offset = 0; + for (auto i = 0; i < out_num; ++i) { + framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i]) + ->GetMutable(); + *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]); + row_offset += outs_dims[i][0]; + } + } + } else if (send_var->IsType()) { + auto &send_slr = send_var->Get(); + auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); + + auto send_rows = send_slr.rows(); + std::vector> outs_rows_idx; + std::vector> outs_dense_idx; + + outs_rows_idx.resize(out_num); + outs_dense_idx.resize(out_num); + + auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0]; + auto src = send_slr.value().data(); + + // create output var in local scope + std::vector outs; + for (auto &name : rpc_ctx.splited_var_names) { + auto *out = local_scope->Var(name)->GetMutable(); + outs.push_back(out); + } + + // split rows index into output sparse vars + for (size_t i = 0; i < send_rows.size(); ++i) { + int out_idx = FindOutIdx(send_rows[i], abs_sections); + outs_rows_idx[out_idx].push_back(send_rows[i]); + outs_dense_idx[out_idx].push_back(i); + } + auto place = platform::CPUPlace(); + + for (size_t i = 0; i < outs_rows_idx.size(); ++i) { + auto rows_idx = outs_rows_idx[i]; + outs[i]->set_height(rpc_ctx.height_sections[i]); + auto dims = send_slr.GetCompleteDims(); + dims[0] = rows_idx.size(); + outs[i]->mutable_value()->mutable_data(dims, send_slr.place()); + outs[i]->mutable_rows()->clear(); + if (rows_idx.size() > 0) { + for (auto idx : rows_idx) { + outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); + } + auto dst = outs[i]->mutable_value()->mutable_data(place); + for (size_t j = 0; j < rows_idx.size(); j++) { + if (platform::is_cpu_place(place)) { + 
memory::Copy( + platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(), + src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel); + } else { + PADDLE_THROW("do not support GPU now"); + /* + #ifdef PADDLE_WITH_CUDA + auto stream = ctx.cuda_device_context().stream(); + memory::Copy(platform::CUDAPlace(), dst + j * row_numel, + platform::CUDAPlace(), + src + outs_dense_idx[i][j] * row_numel, + sizeof(T) * row_numel, stream); + #else + PADDLE_THROW("Paddle is not compiled with GPU"); + #endif + */ + } + } + } + PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(), + "rows should has the same size with tensor dim 0"); + } + + } else { + PADDLE_THROW("unsupported var type to send!"); + } + + std::vector rets; + for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { + auto &send_var_name = rpc_ctx.splited_var_names[i]; + auto &endpoint = rpc_ctx.epmap[i]; + if (NeedSend(*local_scope, send_var_name)) { + VLOG(3) << "sending " << send_var_name << " to " << endpoint; + rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope, + send_var_name)); + } else { + VLOG(3) << "don't send non-initialized variable: " + << rpc_ctx.splited_var_names[i]; + } + } + + // note!! only support sync send now + if (true || sync) { + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } + } + + delete local_scope; +} + +template struct ParameterSend; + +}; // namespace distributed +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/parameter_send.h b/paddle/fluid/operators/distributed/parameter_send.h new file mode 100644 index 0000000000000000000000000000000000000000..9077f4a4fb9fd9d7152e8be72519f16b1999e93d --- /dev/null +++ b/paddle/fluid/operators/distributed/parameter_send.h @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/distributed/rpc_common.h" + +namespace paddle { +namespace operators { +namespace distributed { + +template +struct ParameterSend { + void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope, + bool sync); +}; + +}; // namespace distributed +}; // namespace operators +}; // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 991158ac72007efc1233f852caed4f90f35fe1cd..e777d515ce9da7a434394bb0db031c10bed3b10d 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -71,13 +71,15 @@ class VarHandle { VarHandle(const std::string ep, const std::string& method, const std::string& name, const platform::DeviceContext* p_ctx = nullptr, - const framework::Scope* p_scope = nullptr) + const framework::Scope* p_scope = nullptr, + bool delete_local_scope = false) : status_(kDefaultState) { ep_ = ep; ctx_ = p_ctx; scope_ = p_scope; name_ = name; method_ = method; + delete_local_scope_ = delete_local_scope; } virtual ~VarHandle() {} @@ -99,6 +101,7 @@ class VarHandle { std::unique_lock lk(sync_mutex_); status_ = ok ? kFinishState : kErrorState; } + if (delete_local_scope_ && scope_) delete scope_; VLOG(7) << "VarHandle finish:" << ok; wait_cond_.notify_all(); } @@ -125,6 +128,7 @@ class VarHandle { std::string name_; // RPC method name. std::string method_; + bool delete_local_scope_; protected: std::mutex sync_mutex_; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index a1c5c0777402b808eed6306862fd6dd41b529dbd..e5318f98ca9844d7642eb688522380c4bcf347aa 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -59,9 +59,11 @@ bool RequestSendHandler::Handle(const std::string& varname, "async mode should not recv BATCH_BARRIER_MESSAGE or " "COMPLETE_MESSAGE"); } + try { executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), scope); + delete scope; } catch (std::exception& e) { LOG(ERROR) << "async: run sub program error " << e.what(); return false; diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h new file mode 100644 index 0000000000000000000000000000000000000000..39eb2d078c832503a50a2b1812da4b80ba3ce71b --- /dev/null +++ b/paddle/fluid/operators/distributed/rpc_common.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include + +namespace paddle { +namespace operators { +namespace distributed { + +struct RpcContext { + RpcContext(const std::string& name, const std::vector& names, + const std::vector& emap, + const std::vector& sections) + : var_name(name), + splited_var_names(names), + epmap(emap), + height_sections(sections) {} + + RpcContext(const RpcContext& ctx) { + var_name = ctx.var_name; + splited_var_names = ctx.splited_var_names; + epmap = ctx.epmap; + height_sections = ctx.height_sections; + } + + std::string var_name; + std::vector splited_var_names; + std::vector epmap; + std::vector height_sections; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h index 294cae5f44a4701c064c3669af7b4138f68659e6..3ecb6960690cbd4679af9ad310f27a878ddb7a23 100644 --- a/paddle/fluid/operators/distributed/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -60,14 +60,12 @@ class VariableResponse { bool create_scope = false) : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { if (create_scope) { - local_scope_ = &scope->NewScope(); + local_scope_ = scope->NewTmpScope(); } } virtual ~VariableResponse() { - if (create_scope_) { - scope_->DeleteScope(local_scope_); - } + if (local_scope_) delete local_scope_; } int Parse(Source* source, const sendrecv::VariableMessage& meta) { @@ -86,6 +84,12 @@ class VariableResponse { inline std::string Varname() const { return meta_.varname(); } inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string TableName() const { return meta_.table_name(); } + inline void ReleaseOwnershipOfLocalScope() { + PADDLE_ENFORCE(create_scope_, + "only when create_scope_ is true can you release the " + "ownership of local scope"); + local_scope_ = nullptr; + } // should call parse first. framework::Variable* GetVar() { diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index a8bb597cbd59290df1347c164d37104c6ac431e9..3bcfc532e8665cf8226c85c2b10cf049feae748e 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 120c65f29699bf2745b09ea312d1de069c8173c5..680b484d413207595369dc665a1c64f115e4ae4b 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -20,6 +20,8 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/parameter_recv.h" +#include "paddle/fluid/operators/distributed/rpc_common.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -48,32 +50,41 @@ class RecvOp : public framework::OperatorBase { distributed::RPCClient::GetInstance( Attr("trainer_id")); - if (with_barrier) { - std::vector rets; - for (size_t i = 0; i < outs.size(); i++) { - std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; - VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " - << varname << " and with AsyncGetVar"; - rets.push_back( - rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i])); - } - if (sync_mode) { + std::vector recv_varnames = + Attr>("recv_varnames"); + + if (recv_varnames.size() > 0) { + auto recv_functor = distributed::ParameterRecv(); + auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); + recv_functor(rpc_ctx, scope); + } else { + if (with_barrier) { + std::vector rets; + for (size_t i = 0; i < outs.size(); i++) { + std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; + VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " + << varname << " and with AsyncGetVar"; + rets.push_back( + rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i])); + } + if (sync_mode) { + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } + } + } else { + std::vector rets; + for (size_t i = 0; i < outs.size(); i++) { + std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; + VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " + << varname << " and with AsyncGetVarNoBarrier"; + rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope, + varname, outs[i])); + } for (size_t i = 0; i < rets.size(); i++) { PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); } } - } else { - std::vector rets; - for (size_t i = 0; i < outs.size(); i++) { - std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; - VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " - << varname << " and with AsyncGetVarNoBarrier"; - rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope, - varname, outs[i])); - } - for (size_t i = 0; i < rets.size(); i++) { - PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); - } } } }; @@ -110,6 +121,11 @@ This operator can get variables from server side. "for example: we need var named 'moment_1@127.0.0.1:1001', " "and it real name on parameter server is 'moment_1'. ") .SetDefault({}); + AddAttr>( + "recv_varnames", + "(vector) " + "the splited parameter varnames to be recved from pserver") + .SetDefault(std::vector{}); } }; diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index e2c2147ab5e9a76498a0fd9e1f18b75eed32e91e..8b09cf86d7d5f99944906ede5f801e4840eeca42 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -20,6 +20,8 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/parameter_send.h" +#include "paddle/fluid/operators/distributed/rpc_common.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/platform/profiler.h" @@ -37,30 +39,43 @@ class SendOp : public framework::OperatorBase { const platform::Place& place) const override { auto ins = Inputs("X"); - std::vector epmap = Attr>("epmap"); + auto epmap = Attr>("epmap"); int sync_send = Attr("sync_mode"); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); - - distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); - - std::vector rets; - for (size_t i = 0; i < ins.size(); i++) { - if (NeedSend(scope, ins[i])) { - VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); - } else { - VLOG(3) << "don't send no-initialied variable: " << ins[i]; + auto send_varnames = Attr>("send_varnames"); + auto height_sections = Attr>("sections"); + + if (send_varnames.size() > 0) { + PADDLE_ENFORCE_EQ(ins.size(), 1, ""); + auto send_functor = distributed::ParameterSend(); + auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, + height_sections); + send_functor(rpc_ctx, scope, static_cast(sync_send)); + } else { + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance( + Attr("trainer_id")); + + std::vector rets; + for (size_t i = 0; i < ins.size(); i++) { + if (NeedSend(scope, ins[i])) { + VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; + rets.push_back( + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); + } else { + VLOG(3) << "don't send no-initialied variable: " << ins[i]; + } } - } - if (sync_send) { - for (size_t i = 0; i < rets.size(); i++) { - VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; - PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); - VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; + if (sync_send) { + for (size_t i = 0; i < rets.size(); i++) { + VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; + } } } } @@ -88,6 +103,21 @@ This operator will send variables to listen_and_serve op at the parameter server "Server endpoints in the order of input " "variables for mapping") .SetDefault({"127.0.0.1:6164"}); + AddAttr>("sections", + "(vector) " + "the length of each output along the " + "specified axis.") + .SetDefault(std::vector{}); + AddAttr>( + "send_varnames", + "(vector) " + "the splited output varnames to send to pserver") + .SetDefault(std::vector{}); + AddAttr("num", + "(int, default 0)" + "Number of sub-tensors. 
This must evenly divide " + "Input.dims()[axis]") + .SetDefault(0); } }; diff --git a/paddle/fluid/operators/distributed_ops/send_recv_util.h b/paddle/fluid/operators/distributed_ops/send_recv_util.h index dc26c53c64f06ce21856fb5af8f2a5eb3fc75bb7..1e91f0dd51a8fe99042dd73b36ca44a1e64e7bd9 100644 --- a/paddle/fluid/operators/distributed_ops/send_recv_util.h +++ b/paddle/fluid/operators/distributed_ops/send_recv_util.h @@ -13,8 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include +#include + #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" namespace paddle { namespace operators { @@ -42,5 +48,35 @@ inline bool NeedSend(const framework::Scope& scope, return false; } +inline int FindOutIdx(int row, const std::vector& abs_sections) { + for (size_t i = 1; i < abs_sections.size(); ++i) { + if (row < abs_sections[i]) { + return i - 1; + } + } + return abs_sections.size() - 1; +} + +inline std::vector ToAbsoluteSection( + const std::vector& height_sections) { + std::vector abs_sections; + abs_sections.resize(height_sections.size()); + abs_sections[0] = 0; + for (size_t i = 1; i < height_sections.size(); ++i) { + abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1]; + } + return abs_sections; +} + +inline size_t GetSectionIndex(int64_t id, + const std::vector& abs_sections) { + for (size_t i = 1; i < abs_sections.size(); ++i) { + if (id < abs_sections[i]) { + return i - 1; + } + } + return abs_sections.size() - 1; +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 6ca6f0bc04aa696852ed7338dcb4b88a49b2fc81..13820e54aa5cee179d73a8175307d46cc500b876 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -134,9 +134,9 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { // for parameter prefetch AddAttr("remote_prefetch", "").SetDefault(false); AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); - AddAttr>("height_sections", - "Height for each output SelectedRows.") - .SetDefault(std::vector({})); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); AddAttr>( "epmap", "(string vector, default 127.0.0.1:6164)" diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 4d5a84bcafed1ab0739349e1dbc7b5a9f9ad64ec..751091478e286f8252d4b1bd19e1e9ae879dd62c 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -70,7 +70,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server - auto height_sections = ctx.Attr>("height_sections"); + auto height_sections = ctx.Attr>("height_sections"); auto table_names = ctx.Attr>("table_names"); std::vector real_rows = PathToRows(*path); framework::Scope& local_scope = ctx.scope().NewScope(); diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 0029932bc068c7f61ddb41cf3f87c9e1a5cd7749..9f6fbe05facdec80e231c805f00387d1c30f4532 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ 
b/paddle/fluid/operators/lookup_table_op.cc @@ -91,9 +91,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { // for parameter prefetch AddAttr("remote_prefetch", "").SetDefault(false); AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); - AddAttr>("height_sections", - "Height for each output SelectedRows.") - .SetDefault(std::vector({})); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); AddAttr>( "epmap", "(string vector, default 127.0.0.1:6164)" diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 56c6e37ae3c62e1f9af66ef6ed16111dc1e93d9d..524565a439fec16e026f4b6414e78d0aff57ea29 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -50,7 +50,8 @@ class LookupTableKernel : public framework::OpKernel { // for remote prefetch auto epmap = context.Attr>("epmap"); - auto height_sections = context.Attr>("height_sections"); + auto height_sections = + context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); if (!epmap.empty()) { diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index 222d761ef91d8aee4843d717dabba7edf131f8dc..db0ee9bc1695f7b1a55b4d111dc470b462210963 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -95,7 +95,7 @@ struct MergeAdd { enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; -// out = seleted_rows_in / tensor +// out = selected_rows_in / tensor template struct UpdateToTensor { void operator()(const DeviceContext& context, const ScatterOps& op, diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 256da34912560ddf1f7e430e8543efe00e5885bc..8160f45e74b8142b63f98b79e9a852c706cc9af0 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -156,9 +156,9 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { // for parameter prefetch AddAttr("remote_prefetch", "").SetDefault(false); AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); - AddAttr>("height_sections", - "Height for each output SelectedRows.") - .SetDefault(std::vector({})); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); AddAttr>( "epmap", "(string vector, default 127.0.0.1:6164)" diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 3e48b67a570d41482e358ae3941eb1e2b6ab91f8..25b6ed851bc5149dcf6d25edc80544c99dd95d34 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -172,7 +172,8 @@ class NCEKernel : public framework::OpKernel { framework::Scope &local_scope = context.scope().NewScope(); - auto height_sections = context.Attr>("height_sections"); + auto height_sections = + context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); auto *ids = local_scope.Var("Ids@Prefetch"); diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index 1fef2b3d378c96d087118d0136885e7e29aa237c..c29065649e6ee0168ce44038b36dff25c62c082f 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -16,31 +16,12 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" namespace paddle { namespace operators { -static int FindOutIdx(int row, const std::vector& abs_sections) { - for (size_t i = 1; i < abs_sections.size(); ++i) { - if (row < abs_sections[i]) { - return i - 1; - } - } - return abs_sections.size() - 1; -} - -static std::vector ToAbsoluteSection( - const std::vector& height_sections) { - std::vector abs_sections; - abs_sections.resize(height_sections.size()); - abs_sections[0] = 0; - for (size_t i = 1; i < height_sections.size(); ++i) { - abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1]; - } - return abs_sections; -} - template class SplitSelectedRowsOpKernel : public framework::OpKernel { public: diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 0968ace62b6a4e258f7763dbf6fbeda07feb4cd5..98e6923c1115142d3a9c90073898cd03db716398 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -48,6 +48,7 @@ class TestDistRunnerBase(object): # NOTE: import fluid until runtime, or else forking processes will cause error. config = fluid.DistributeTranspilerConfig() config.enable_dc_asgd = dc_asgd + config.runtime_split_send_recv = True t = fluid.DistributeTranspiler(config=config) t.transpile( trainer_id=trainer_id, @@ -87,6 +88,9 @@ class TestDistRunnerBase(object): args.endpoints, args.trainers, args.sync_mode, args.dc_asgd) trainer_prog = t.get_trainer_program() + with open("/tmp/trainer." + str(args.trainer_id) + ".proto", + "w") as f: + f.write(str(trainer_prog)) elif args.update_method == "nccl2": # transpile for nccl2 config = fluid.DistributeTranspilerConfig() @@ -115,6 +119,7 @@ class TestDistRunnerBase(object): strategy.allow_op_delay = False build_stra = fluid.BuildStrategy() + build_stra.debug_graphviz_path = "/tmp/graph-" + str(args.trainer_id) if args.use_reduce: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce @@ -123,8 +128,7 @@ class TestDistRunnerBase(object): if args.batch_merge_repeat > 1: pass_builder = build_stra._finalize_strategy_and_create_passes() - mypass = pass_builder.insert_pass( - len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass") + mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass") mypass.set("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2": diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index eb54068650e8b3f4e64317778e2ad7c7aa7fe1b2..4ddfc084e06b1dc1b018b479304d1e9c1871de19 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -156,6 +156,8 @@ class DistributeTranspilerConfig(object): mode = "pserver" print_log = False wait_port = True + # split the send recv var in runtime + runtime_split_send_recv = False class DistributeTranspiler(object): @@ -398,8 +400,10 @@ class DistributeTranspiler(object): orig_var = program.global_block().vars[splited_grad_varname] index = find_op_by_output_arg( program.global_block(), splited_grad_varname, reverse=True) - self._insert_split_op(program, orig_var, index, splited_vars) - index += 1 + if not self.config.runtime_split_send_recv: + self._insert_split_op(program, orig_var, index, + splited_vars) + index += 1 
else: AssertionError("Can not insert the send op by original " "variable name :", splited_grad_varname) @@ -408,6 +412,17 @@ class DistributeTranspiler(object): name=framework.generate_control_dev_var_name()) self.grad_name_to_send_dummy_out[grad_varname] = dummy_output + if self.config.runtime_split_send_recv: + send_input_vars = [ + program.global_block().vars[splited_grad_varname] + ] + sections = self._get_splited_var_sections(splited_vars) + send_varnames = [var.name for var in splited_vars] + else: + send_input_vars = splited_vars + sections = [] + send_varnames = [] + # get send op_role_var, if not splited, the grad should have .trainer suffix # if splited, grad should be the original grad var name (split_by_ref and send # will be on the same place). ParallelExecutor @@ -415,10 +430,12 @@ class DistributeTranspiler(object): program.global_block()._insert_op( index=index + 1, type="send", - inputs={"X": splited_vars}, + inputs={"X": send_input_vars}, outputs={"Out": dummy_output}, attrs={ "epmap": eplist, + "sections": sections, + "send_varnames": send_varnames, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: [ self.grad_name_to_param_name[grad_varname], @@ -501,13 +518,20 @@ class DistributeTranspiler(object): self._update_remote_sparse_update_op( param_varname, height_sections, eps, table_names) else: + recv_varnames = [] + if self.config.runtime_split_send_recv: + orig_param = program.global_block().vars[param_varname] + recv_varnames = [var.name for var in splited_var] + splited_var = [orig_param] all_recv_outputs.extend(splited_var) + program.global_block().append_op( type="recv", inputs={"X": [recv_dep_in]}, outputs={"Out": splited_var}, attrs={ "epmap": eps, + "recv_varnames": recv_varnames, "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: @@ -532,14 +556,15 @@ class DistributeTranspiler(object): continue orig_param = program.global_block().vars[param_varname] if param_varname not in self.sparse_param_to_height_sections: - program.global_block().append_op( - type="concat", - inputs={"X": splited_var}, - outputs={"Out": [orig_param]}, - attrs={ - "axis": 0, - RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE - }) + if not self.config.runtime_split_send_recv: + program.global_block().append_op( + type="concat", + inputs={"X": splited_var}, + outputs={"Out": [orig_param]}, + attrs={ + "axis": 0, + RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE + }) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) @@ -1376,9 +1401,8 @@ class DistributeTranspiler(object): # create table param and grad var in pserver program # create table optimize block in pserver program table_opt_op = [ - op for op in self.optimize_ops - if 'Param' in op.input_names and op.input("Param")[0] == - self.table_name + op for op in self.optimize_ops if 'Param' in op.input_names and + op.input("Param")[0] == self.table_name ][0] origin_param_var = self.origin_program.global_block().vars[ @@ -1552,11 +1576,17 @@ class DistributeTranspiler(object): lod_level=var.lod_level, persistable=persistable) + @staticmethod + def _get_splited_var_sections(splited_vars): + height_sections = [] + for v in splited_vars: + height_sections.append(v.shape[0]) + return height_sections + def _insert_split_op(self, program, orig_var, index, splited_vars): + height_sections = self._get_splited_var_sections(splited_vars) + if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: - height_sections = [] - for v in splited_vars: - 
height_sections.append(v.shape[0]) sparse_param_name = self.grad_name_to_param_name[orig_var.name] if self._is_input_of_remote_sparse_update_op(sparse_param_name): self.sparse_param_to_height_sections[ @@ -1571,16 +1601,13 @@ class DistributeTranspiler(object): RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE }) elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: - sections = [] - for v in splited_vars: - sections.append(v.shape[0]) program.global_block()._insert_op( index=index + 1, type="split_byref", inputs={"X": orig_var}, outputs={"Out": splited_vars}, attrs={ - "sections": sections, + "sections": height_sections, RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE }) else: @@ -2052,7 +2079,7 @@ class DistributeTranspiler(object): Get optimizer operators, parameters and gradients from origin_program Returns: opt_ops (list): optimize operators. - params_grads (dict): paramter->gradient. + params_grads (dict): parameter->gradient. """ block = self.origin_program.global_block() opt_ops = []
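Editor's note, appended after the patch: this change standardizes height_sections on int64_t and moves the shard-lookup helpers (FindOutIdx, ToAbsoluteSection, GetSectionIndex) into send_recv_util.h so that split_selected_rows_op, parameter_prefetch and the new parameter_send share one copy. As a quick illustration of the row-to-shard mapping those helpers implement, here is a self-contained restatement (an editor's sketch, not code from the patch; the 3/5/7 split and the row ids are made-up values):

#include <cstdint>
#include <iostream>
#include <vector>

// Prefix sums of the per-shard heights: {3, 5, 7} -> {0, 3, 8}.
static std::vector<int64_t> ToAbsoluteSection(
    const std::vector<int64_t>& height_sections) {
  std::vector<int64_t> abs_sections(height_sections.size());
  abs_sections[0] = 0;
  for (size_t i = 1; i < height_sections.size(); ++i) {
    abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
  }
  return abs_sections;
}

// Index of the shard whose absolute row range contains the given global row.
static int FindOutIdx(int64_t row, const std::vector<int64_t>& abs_sections) {
  for (size_t i = 1; i < abs_sections.size(); ++i) {
    if (row < abs_sections[i]) return static_cast<int>(i - 1);
  }
  return static_cast<int>(abs_sections.size()) - 1;
}

int main() {
  // A parameter of height 15 split across three pservers as 3 + 5 + 7 rows.
  std::vector<int64_t> height_sections = {3, 5, 7};
  auto abs_sections = ToAbsoluteSection(height_sections);  // {0, 3, 8}

  for (int64_t row : {0, 4, 9, 14}) {
    int idx = FindOutIdx(row, abs_sections);
    std::cout << "global row " << row << " -> shard " << idx << ", local row "
              << row - abs_sections[idx] << "\n";
  }
  // Prints: 0 -> shard 0/local 0, 4 -> shard 1/local 1,
  //         9 -> shard 2/local 1, 14 -> shard 2/local 6.
  return 0;
}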