diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index e9aad5d264d1745662848d1ba313b573d0974cb7..7f63c07b18f7c6147670656dfc567f8f2ae8429a 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -64,9 +64,12 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("epmap")); auto height_section = boost::get>( node->Op()->GetNullableAttr("sections")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(send_var_name, send_varnames, - epmap, height_section); + epmap, height_section, + trainer_id); VLOG(3) << "find and init an send op: " << send_varname_to_ctx[send_var_name]; } else if (node->Name() == "recv") { @@ -75,9 +78,11 @@ void ProcessGraph(std::vector graphs, Scope *scope) { node->Op()->GetNullableAttr("recv_varnames")); auto epmap = boost::get>( node->Op()->GetNullableAttr("epmap")); + auto trainer_id = + boost::get(node->Op()->GetNullableAttr("trainer_id")); recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(recv_var_name, recv_varnames, - epmap, {}); + epmap, {}, trainer_id); nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 972b4f67a8388ce68952fa90aaa224cd45c6d226..f6531ec9edca7b425d28853f542d5e46783ba699 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -9,6 +9,9 @@ else() endif() configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY) +cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool) +cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder) + # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") if(WITH_GRPC) @@ -20,7 +23,7 @@ if(WITH_GRPC) collective_client.cc collective_server.cc ${GRPC_SRCS} PROTO send_recv.proto - DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS}) + DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f3b6b959e30194c10b1a58d6fc3e7a61ad01313 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" + +namespace paddle { +namespace operators { +namespace distributed { + +std::once_flag AsyncSparseParamUpdateRecorder::init_flag_; +std::unique_ptr + AsyncSparseParamUpdateRecorder::recorder_(nullptr); + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h new file mode 100644 index 0000000000000000000000000000000000000000..eadd842c7f6ead56006fd0c34814b1b7bd9b62f4 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h @@ -0,0 +1,183 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include + +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace distributed { + +class ConcurrentSet { + public: + ConcurrentSet() : pool_(new ::ThreadPool(1)) {} + ~ConcurrentSet() {} + + std::future Update(const std::vector& rows) { + auto task = [this, rows] { + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& id : rows) { + sstream << id << ", "; + } + sstream << "]"; + VLOG(3) << "update ids -> " << sstream.str(); + } + for (auto row : rows) { + set_.insert(row); + } + }; + return pool_->enqueue(std::move(task)); + } + + std::future GetAndClear(std::vector* result) { + auto task = [this, &result] { + result->clear(); + for (auto& id : set_) { + result->push_back(id); + } + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& id : *result) { + sstream << id << ", "; + } + sstream << "]"; + VLOG(3) << "result ids size: " << result->size() << " " + << sstream.str(); + } + set_.clear(); + }; + return pool_->enqueue(std::move(task)); + } + + private: + std::unordered_set set_; + std::unique_ptr<::ThreadPool> pool_{nullptr}; +}; + +class AsyncSparseParamUpdateRecorder { + using TrainerToRows = std::vector>; + + public: + AsyncSparseParamUpdateRecorder( + int trainer_num, + const std::unordered_map& grad_to_param) + : trainer_num_(trainer_num), grad_to_param_(grad_to_param) { + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& item : grad_to_param) { + sstream << item.first << ":" << item.second << ", "; + } + sstream << "]"; + VLOG(3) << "trainer_num: " << trainer_num + << " grad_to_param_: " << sstream.str(); + } + for (auto& iter : grad_to_param) { + param_to_grad_[iter.second] = iter.first; + auto& param_name = iter.second; + param_to_updated_rows_[param_name] = TrainerToRows(); + auto& trainer_to_rows = param_to_updated_rows_[param_name]; + for (auto i = 0; i < trainer_num; ++i) { + trainer_to_rows.emplace_back(new ConcurrentSet()); + } + } + } + + ~AsyncSparseParamUpdateRecorder() = default; + + void Update(const std::string& grad_name, + const std::vector& update_rows) { + VLOG(3) << "update grad: " << grad_name + << " row size: " << update_rows.size(); + auto& param_name = grad_to_param_.at(grad_name); + auto& trainer_to_rows = param_to_updated_rows_.at(param_name); + + std::vector> fs; + for (auto& set : trainer_to_rows) { + fs.push_back(set->Update(update_rows)); + } + for (auto& f : fs) { + f.wait(); + } + } + + void GetAndClear(const std::string& param_name, int trainer_id, + std::vector* result) { + VLOG(3) << "GetAndClear param: " << param_name + << " for trainer: " << trainer_id; + PADDLE_ENFORCE_LT(trainer_id, trainer_num_); + param_to_updated_rows_.at(param_name)[trainer_id] + ->GetAndClear(result) + .wait(); + } + + bool HasParam(const std::string& param_name) { + return param_to_grad_.find(param_name) != param_to_grad_.end(); + } + + bool HasGrad(const std::string& grad_name) { + return grad_to_param_.find(grad_name) != grad_to_param_.end(); + } + + private: + const int trainer_num_; + std::unordered_map grad_to_param_; + std::unordered_map param_to_grad_; + std::unordered_map param_to_updated_rows_; + + // init recorder + public: + static void Init( + int trainer_num, + const std::unordered_map& grad_to_param) { + InitImpl(trainer_num, grad_to_param); + } + + static AsyncSparseParamUpdateRecorder* GetInstance() { + return recorder_.get(); + } + + private: + // Init is called by GetInstance. + static void InitImpl( + int trainer_num, + const std::unordered_map& grad_to_param) { + if (recorder_ == nullptr) { + recorder_.reset( + new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param)); + } + } + + static std::once_flag init_flag_; + static std::unique_ptr recorder_; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..67e8fd8a0edc4510d0abe885c821e75b528254f8 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" + +#include + +#include "gtest/gtest.h" + +namespace paddle { +namespace operators { +namespace distributed { + +TEST(ConcurrentSet, All) { + ConcurrentSet concurrent_set; + std::vector in1 = {1, 2, 3, 4}; + std::vector in2 = {2, 3, 5, 6}; + + std::vector> futures; + futures.push_back(concurrent_set.Update(in1)); + futures.push_back(concurrent_set.Update(in2)); + + for (auto &f : futures) { + f.wait(); + } + + std::unordered_set in; + std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); + std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); + + std::vector ret; + concurrent_set.GetAndClear(&ret).wait(); + + std::unordered_set out; + std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); + + EXPECT_EQ(in, out); + + concurrent_set.GetAndClear(&ret).wait(); + EXPECT_EQ(ret.size(), 0); +} + +TEST(AsyncSparseParamUpdateRecorder, All) { + std::unordered_map grad_to_param; + grad_to_param["grad1"] = "param1"; + grad_to_param["grad2"] = "param2"; + + int trainer_num = 10; + + AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param); + std::vector in1 = {1, 2, 3, 4}; + std::vector in2 = {2, 3, 5, 6}; + + std::unordered_set in; + std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); + std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); + + recorder.Update("grad1", in1); + recorder.Update("grad1", in2); + + EXPECT_TRUE(recorder.HasParam("param1")); + EXPECT_TRUE(recorder.HasParam("param2")); + EXPECT_FALSE(recorder.HasParam("param3")); + + EXPECT_TRUE(recorder.HasGrad("grad1")); + EXPECT_TRUE(recorder.HasGrad("grad2")); + EXPECT_FALSE(recorder.HasGrad("grad3")); + + std::vector ret; + EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret)); + + for (int i = 0; i < trainer_num; ++i) { + std::vector ret; + std::unordered_set out; + + recorder.GetAndClear("param1", i, &ret); + std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); + + EXPECT_EQ(in, out); + + recorder.GetAndClear("param1", i, &ret); + EXPECT_EQ(ret.size(), 0); + } +} + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.cc b/paddle/fluid/operators/distributed/brpc/brpc_client.cc index a1a3443348129b5cdf057592fced8fdff238ac09..4c22ad8eb4d4b2e23d8a6720e726eb9e2998314e 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.cc @@ -234,6 +234,7 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep, const framework::Scope& scope, const std::string& var_name, const std::string& out_var_name, + const std::string& table_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC, time_out); diff --git a/paddle/fluid/operators/distributed/brpc/brpc_client.h b/paddle/fluid/operators/distributed/brpc/brpc_client.h index 501a593b11d35c160348e42ee47216a85647aac4..51864dfdca53eb4b1d9045188a6347781130e785 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_client.h +++ b/paddle/fluid/operators/distributed/brpc/brpc_client.h @@ -21,8 +21,10 @@ limitations under the License. */ #include #include #include +#include #include // NOLINT #include +#include #include #include "brpc/channel.h" @@ -66,6 +68,7 @@ class BRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_var_name, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetMonomerBarrier( @@ -107,13 +110,11 @@ class BRPCClient : public RPCClient { void SendComplete() override; private: - VarHandlePtr _AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& out_var_name, - const std::string& method_name, - int64_t time_out = FLAGS_rpc_deadline); + VarHandlePtr _AsyncGetVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + const std::string& out_var_name, const std::string& method_name, + const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline); void Proceed(); ChannelQueuePtr GetChannel(const std::string& ep); diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index eba18c67771fa26eed855b0f19591e06101f424d..b528bcdd32b11d686f44596d9a1bb663b21691f4 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -32,6 +32,9 @@ DEFINE_int32(communicator_send_queue_size, 20, DEFINE_int32(communicator_max_send_grad_num_before_recv, 20, "max grad num to send before recv parameters"); DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv"); +DEFINE_int32(communicator_send_wait_times, 5, + "times that send thread will wait if merge num does not reach " + "max_merge_var_num"); DEFINE_int32(communicator_max_merge_var_num, 20, "max var num to merge and send"); DEFINE_bool(communicator_fake_rpc, false, @@ -65,6 +68,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, << FLAGS_communicator_max_send_grad_num_before_recv; VLOG(0) << "communicator_thread_pool_size: " << FLAGS_communicator_thread_pool_size; + VLOG(0) << "communicator_send_wait_times: " + << FLAGS_communicator_send_wait_times; VLOG(0) << "communicator_max_merge_var_num: " << FLAGS_communicator_max_merge_var_num; VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; @@ -101,20 +106,32 @@ void Communicator::SendThread() { VLOG(3) << var_name << " merge and send"; std::vector> vars; size_t merged_var_num = 0; - while (var_queue->Size() > 0 && - merged_var_num < FLAGS_communicator_max_merge_var_num) { - vars.push_back(var_queue->Pop()); - // only count the send number of the first var - if (var_name == send_varname_to_queue_.begin()->first) { - grad_num_.fetch_add(1, std::memory_order_relaxed); + size_t wait_times = 0; + while (merged_var_num < FLAGS_communicator_max_merge_var_num) { + if (var_queue->Size() == 0) { + VLOG(3) << "wait_times -> " << wait_times; + if (wait_times >= FLAGS_communicator_send_wait_times) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } else { + wait_times = 0; + + vars.push_back(var_queue->Pop()); + // only count the send number of the first var + if (var_name == send_varname_to_queue_.begin()->first) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } + merged_var_num++; } - merged_var_num++; } auto before_merge = GetCurrentUS(); MergeVars(var_name, vars, send_scope_.get()); auto after_merge = GetCurrentUS(); - VLOG(3) << "merge " << var_name << " use time " - << after_merge - before_merge; + VLOG(3) << "merge " << merged_var_num << " " << var_name + << " use time " << after_merge - before_merge; auto send_functor = distributed::ParameterSend(); auto &ctx = send_varname_to_ctx_.at(var_name); if (!FLAGS_communicator_fake_rpc) { diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 41155bfc31bb31520fdcf5bd50b203f2e1f2c516..37c39eb15112f745f6a25e95ce65d431d825182e 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -109,7 +109,7 @@ inline void MergeVars(const std::string& var_name, auto* out_var = scope->Var(var_name); if (var0->IsType()) { auto dims = var0->Get().dims(); - VLOG(3) << "merge " << var_name << " LoDTensor " << dims; + VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims; // init output tensor auto* out_t = out_var->GetMutable(); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 61e94dae3c7a107e10fa5e5518651014cec078bc..8504110c6e9dbfe22b78063999ed4a9e36850e2c 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, - "/sendrecv.SendRecvService/GetVariable", time_out); + "/sendrecv.SendRecvService/GetVariable", table_name, + time_out); } VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( @@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( return _AsyncGetVar( ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname, - "/sendrecv.SendRecvService/GetVariableNoBarrier", time_out); + "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out); } VarHandlePtr GRPCClient::AsyncGetMonomerVariable( @@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable( const framework::Scope& scope, const std::string& var_name, int64_t time_out) { return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name, - "/sendrecv.SendRecvService/GetMonomerVariable", time_out); + "/sendrecv.SendRecvService/GetMonomerVariable", "", + time_out); } VarHandlePtr GRPCClient::_AsyncGetVar( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& method, const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, int64_t time_out) { + const std::string& rpc_path, const std::string& table_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const std::string out_varname_val = out_varname; + const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); @@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar( VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); s->Prepare(h, time_out); - framework::AsyncIO( - [var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(out_varname_val); - req.set_trainer_id(trainer_id_); - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); + framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method, + p_ctx, h, rpc_path, this] { + // prepare input + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + req.set_out_varname(out_varname_val); + req.set_trainer_id(trainer_id_); + req.set_table_name(table_name_val); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - // stub context - s->response_call_back_ = ProcGetResponse; + // stub context + s->response_call_back_ = ProcGetResponse; - platform::RecordRPCEvent record_event(method); + platform::RecordRPCEvent record_event(method); - auto call = - s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + auto call = + s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); req_count_++; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index ce0d2152aa27c62b6e12881aaf2ae458597e67e6..ad2f04a6d1dda34e35b67b21dce8ac612ff697a0 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -23,9 +23,11 @@ limitations under the License. */ #include #include #include +#include #include // NOLINT #include #include // NOLINT +#include #include #include "grpc++/channel.h" @@ -187,6 +189,7 @@ class GRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetVarNoBarrier( @@ -239,7 +242,8 @@ class GRPCClient : public RPCClient { const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& method, const std::string& var_name, const std::string& out_varname, - const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline); + const std::string& rpc_path, const std::string& table_name = "", + int64_t time_out = FLAGS_rpc_deadline); private: grpc::CompletionQueue cq_; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index 0eb313f75dfa64f8722faa365128f3111f72bd0b..75526bed0f0eadada65279ec05757da7a469f984 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -137,6 +137,7 @@ class RequestGet final : public RequestBase { // proc request. std::string varname = request_.varname(); std::string out_varname = request_.out_varname(); + std::string table_name = request_.table_name(); int trainer_id = request_.trainer_id(); VLOG(4) << "RequestGet " << out_varname << " from " << varname; @@ -145,19 +146,23 @@ class RequestGet final : public RequestBase { framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; - request_handler_->Handle(varname, scope, invar, &outvar, trainer_id, - out_varname); + tmp_scope_ = std::move(scope->NewTmpScope()); + request_handler_->Handle(varname, tmp_scope_.get(), invar, &outvar, + trainer_id, out_varname, table_name); + VLOG(1) << "before SerializeToByteBuffer"; if (outvar) { SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), &reply_); } + VLOG(1) << "after SerializeToByteBuffer"; Finish(reply_, &responder_); } protected: sendrecv::VariableMessage request_; ::grpc::ByteBuffer reply_; + std::unique_ptr tmp_scope_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; }; diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index e7d4c262aa9fad10a23adc61b94ba0c38577c0e8..da73167ae603fb8c8ba9deabe118269891d1f52a 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -42,27 +42,23 @@ using DDim = framework::DDim; template void ParameterRecv::operator()(const RpcContext &rpc_ctx, const framework::Scope &scope) { - VLOG(3) << "ParameterRecv in"; + VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name; std::unique_ptr local_scope = scope.NewTmpScope(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &cpu_ctx = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(0); + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); auto *recv_var = scope.FindVar(rpc_ctx.var_name); - std::vector recved_tensors; - // recv all vars to local scope if (recv_var->IsType()) { std::vector rets; for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { auto &recv_var_name = rpc_ctx.splited_var_names[i]; - framework::Tensor *t = - local_scope->Var(recv_var_name)->GetMutable(); - recved_tensors.push_back(t); + local_scope->Var(recv_var_name); VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name, @@ -78,23 +74,61 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, // concat recved tensor into one var { size_t output_offset = 0; + size_t row_offset = 0; framework::Tensor *recv_tensor = recv_var->GetMutable(); auto dev_ctx = paddle::platform::CPUDeviceContext(); int64_t recv_numel = 0; - for (auto *in : recved_tensors) { - recv_numel += in->numel(); - auto in_stride = framework::stride_numel(in->dims()); - auto out_stride = framework::stride_numel(recv_tensor->dims()); - StridedNumelCopyWithAxis( - dev_ctx, 0, recv_tensor->data() + output_offset, out_stride, - in->data(), in_stride, in_stride[0]); - output_offset += in_stride[0]; + for (auto &recv_var_name : rpc_ctx.splited_var_names) { + auto *recv_var = local_scope->FindVar(recv_var_name); + if (recv_var->IsType()) { + auto &in = recv_var->Get(); + recv_numel += in.numel(); + auto in_stride = framework::stride_numel(in.dims()); + auto out_stride = framework::stride_numel(recv_tensor->dims()); + StridedNumelCopyWithAxis( + dev_ctx, 0, recv_tensor->data() + output_offset, out_stride, + in.data(), in_stride, in_stride[0]); + output_offset += in_stride[0]; + } else if (recv_var->IsType()) { + auto &recv_slr = recv_var->Get(); + auto &recv_dims = recv_tensor->dims(); + int64_t width = recv_dims[1]; + recv_numel += recv_slr.height() * width; + PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width); + PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size()); + VLOG(3) << "recv slr " << recv_var_name << " dims " + << recv_slr.value().dims(); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto &row_id : recv_slr.rows()) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " " + << sstream.str(); + } + + for (auto i = 0; i < recv_slr.rows().size(); ++i) { + auto row_id = recv_slr.rows()[i] + row_offset; + PADDLE_ENFORCE_LT(row_id, recv_dims[0]); + memcpy(recv_tensor->data() + row_id * width, + recv_slr.value().data() + i * width, sizeof(T) * width); + } + row_offset += recv_slr.height(); + } else { + PADDLE_THROW("unsupported recieved var type"); + } + } + auto numel = recv_tensor->numel(); + if (recv_numel != numel) { + LOG(FATAL) << "recv_numel: " << recv_numel << " acture numel: " << numel; } - PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel()); + PADDLE_ENFORCE_EQ(recv_numel, numel); } - VLOG(3) << "ParameterRecv out"; + VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; } template struct ParameterRecv; diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 9ce424445229cde0a7e775c95f4af8839f4d4d68..dfabad567af590b65b9e777824d476fce2b17238 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -47,7 +47,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, auto &cpu_ctx = *pool.Get(platform::CPUPlace()); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance(0); + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); auto *send_var = scope.FindVar(rpc_ctx.var_name); size_t out_num = rpc_ctx.splited_var_names.size(); diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 991158ac72007efc1233f852caed4f90f35fe1cd..de8f30184611aeb961e2ab69b05779c56371b976 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -18,7 +18,9 @@ #include // NOLINT #include +#include #include +#include #include #include @@ -180,6 +182,10 @@ class RequestHandler { grad_to_prepared_ctx_ = g; } + void SetSparseGradToParam(std::unordered_map* g) { + sparse_grad_to_param_ = g; + } + void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; } // Get attributes. @@ -228,6 +234,7 @@ class RequestHandler { std::unordered_map>* grad_to_prepared_ctx_; + std::unordered_map* sparse_grad_to_param_; RPCServer* rpc_server_; }; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index e289ec929dbd6643a2518b92c1a25b7d63e790a9..a41536368abc925531d1a54615546a100482a7eb 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/printf.h" @@ -59,6 +60,12 @@ bool RequestSendHandler::Handle(const std::string& varname, "async mode should not recv BATCH_BARRIER_MESSAGE or " "COMPLETE_MESSAGE"); } + if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(varname)) { + auto& grad_slr = + scope->FindVar(varname)->Get(); + AsyncSparseParamUpdateRecorder::GetInstance()->Update(varname, + grad_slr.rows()); + } executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), scope); return true; @@ -82,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname, const int trainer_id, const std::string& out_var_name, const std::string& table_name) { - VLOG(4) << "RequestGetHandler:" << varname - << " out_var_name: " << out_var_name; + VLOG(3) << "RequestGetHandler:" << varname + << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id + << " table_name: " << table_name; if (sync_mode_) { if (varname == FETCH_BARRIER_MESSAGE) { @@ -108,7 +116,42 @@ bool RequestGetHandler::Handle(const std::string& varname, VLOG(3) << "copying " << varname << " to " << param_bak_name; framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); } - *outvar = scope_->FindVar(varname); + if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && + !table_name.empty()) { + std::vector updated_rows; + AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( + varname, trainer_id, &updated_rows); + if (VLOG_IS_ON(3)) { + std::ostringstream sstream; + sstream << "["; + for (auto& row_id : updated_rows) { + sstream << row_id << ", "; + } + sstream << "]"; + VLOG(3) << "updated_rows size: " << updated_rows.size() << " " + << sstream.str(); + } + auto& origin_tensor = + scope_->FindVar(varname)->Get(); + auto* origin_tensor_data = origin_tensor.data(); + auto& dims = origin_tensor.dims(); + *outvar = scope->Var(); + auto* out_slr = (*outvar)->GetMutable(); + out_slr->set_rows(updated_rows); + out_slr->set_height(dims[0]); + auto out_dims = framework::make_ddim( + {static_cast(updated_rows.size()), dims[1]}); + auto* data = out_slr->mutable_value()->mutable_data( + out_dims, origin_tensor.place()); + auto width = dims[1]; + for (auto i = 0; i < updated_rows.size(); ++i) { + PADDLE_ENFORCE_LT(updated_rows[i], dims[0]); + memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width, + sizeof(float) * width); + } + } else { + *outvar = scope_->FindVar(varname); + } } } return true; diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index ea54e0c2951253fc009672f4cd2e5233ed56944e..d4be2c28fdbaa4beef62402155de5b677ed67e9b 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -15,6 +15,7 @@ #pragma once #include // NOLINT +#include #include #include "gflags/gflags.h" @@ -44,6 +45,7 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, const std::string& out_varname, + const std::string& table_name = "", int64_t time_out = FLAGS_rpc_deadline) = 0; virtual VarHandlePtr AsyncGetVarNoBarrier( @@ -96,6 +98,7 @@ class RPCClient { // Init is called by GetInstance. template static void Init(int trainer_id) { + VLOG(0) << "init rpc client with trainer_id " << trainer_id; trainer_id_ = trainer_id; if (rpc_client_.get() == nullptr) { rpc_client_.reset(new T()); diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h index 3de89c2ae89d29edc317ca123882d1c55038b6ca..eb127bf4ad5a5c9a28210e2fbcdb69b07543f4b9 100644 --- a/paddle/fluid/operators/distributed/rpc_common.h +++ b/paddle/fluid/operators/distributed/rpc_common.h @@ -27,23 +27,26 @@ struct RpcContext { RpcContext(const std::string &name, const std::vector &names, const std::vector &emap, - const std::vector §ions) + const std::vector §ions, int id) : var_name(name), splited_var_names(names), epmap(emap), - height_sections(sections) {} + height_sections(sections), + trainer_id(id) {} RpcContext(const RpcContext &ctx) { var_name = ctx.var_name; splited_var_names = ctx.splited_var_names; epmap = ctx.epmap; height_sections = ctx.height_sections; + trainer_id = ctx.trainer_id; } std::string var_name; std::vector splited_var_names; std::vector epmap; std::vector height_sections; + int trainer_id; }; inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index a1ef1af39ff2ab1456706ebafbd3d7ce1acc0c07..1096f3773c6d44560d370502b1c550d67d40ca64 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index 5b30ed472d51a37a0705d1717395da9e4ff7d743..a672fb2a9141a81383d947dcc961a112aee3f7ac 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -24,8 +24,10 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" + #include "paddle/fluid/platform/profiler.h" DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send"); @@ -292,6 +294,8 @@ static void FillRequestCtx( std::unordered_map> *prefetch_ctx, + std::unordered_map + *sparse_grad_name_to_param_name, std::shared_ptr checkpoint_ctx, distributed::RPCServer *rpc_server) { h->SetScope(scope); @@ -299,6 +303,7 @@ static void FillRequestCtx( h->SetExecutor(executor); h->SetProgram(program); h->SetPrefetchPreparedCtx(prefetch_ctx); + h->SetSparseGradToParam(sparse_grad_name_to_param_name); h->SetRPCServer(rpc_server); h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); } @@ -414,10 +419,24 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - auto f = - std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, - &executor, program, &prefetch_var_name_to_prepared_ctx, - ckpt_pre_context, rpc_service_.get()); + // parse attr of kSparseGradToParam sparse_grad_name -> param_name + std::unordered_map sparse_grad_name_to_param_name; + auto sparse_grad_name_to_param_name_str = + Attr>(kSparseGradToParam); + for (const auto &sparse_grad_name_and_param_name : + sparse_grad_name_to_param_name_str) { + std::vector pieces; + split(sparse_grad_name_and_param_name, ':', &pieces); + PADDLE_ENFORCE_EQ(pieces.size(), 2); + VLOG(3) << "after split, sparse_grad_name = " << pieces[0] + << ", param_name = " << pieces[1]; + sparse_grad_name_to_param_name[pieces[0]] = pieces[1]; + } + + auto f = std::bind( + FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor, + program, &prefetch_var_name_to_prepared_ctx, + &sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get()); @@ -445,6 +464,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, prefetch_block_id_list, checkpoint_block_id); } else { + distributed::AsyncSparseParamUpdateRecorder::Init( + fan_in, sparse_grad_name_to_param_name); RunAsyncLoop(&executor, program, &recv_scope); } } @@ -475,6 +496,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>(kPrefetchVarNameToBlockId, "prefetch blocks to run on server side.") .SetDefault({}); + AddAttr>( + kSparseGradToParam, + "sparse grad name to param name. like: 'emb@Grad:emb'") + .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); AddAttr(kCheckpointBlockId, diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h index f20442bad7c5bd96173b9d6efc4dceb13feacf5b..1cf2130d7a593077d1145b4f3be379c32557dd53 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h @@ -16,8 +16,10 @@ limitations under the License. */ #include #include +#include #include #include +#include #include #include @@ -35,6 +37,7 @@ namespace operators { constexpr char kOptimizeBlocks[] = "optimize_blocks"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kCheckpointBlockId[] = "checkpint_block_id"; +constexpr char kSparseGradToParam[] = "sparse_grad_to_param"; void RunServer(std::shared_ptr service); diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 3fd0700a077321d931e87b1d94c3637d167c9eff..8e9846b1fc89953526149be3838103526d5c441b 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &ctx = *pool.Get(place); + auto trainer_id = Attr("trainer_id"); distributed::RPCClient *rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(trainer_id); std::vector recv_varnames = Attr>("recv_varnames"); if (recv_varnames.size() > 0) { auto recv_functor = distributed::ParameterRecv(); - auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); + auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}, + trainer_id); recv_functor(rpc_ctx, scope); } else { if (with_barrier) { diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index b08cd0942f8c89b60d722c931d0cec2063b96578..5731bcc15a07074b3d77873c5cdcbb70dc41aba8 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase { auto epmap = Attr>("epmap"); int sync_send = Attr("sync_mode"); + auto trainer_id = Attr("trainer_id"); auto send_varnames = Attr>("send_varnames"); auto height_sections = Attr>("sections"); @@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase { if (distributed::Communicator::GetInstance() == nullptr) { auto send_functor = distributed::ParameterSend(); auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, - height_sections); + height_sections, trainer_id); send_functor(rpc_ctx, scope, true); } else { distributed::Communicator::GetInstance()->Send(ins[0], scope); @@ -62,8 +63,7 @@ class SendOp : public framework::OperatorBase { auto& ctx = *pool.Get(place); distributed::RPCClient* rpc_client = - distributed::RPCClient::GetInstance( - Attr("trainer_id")); + distributed::RPCClient::GetInstance(trainer_id); std::vector rets; for (size_t i = 0; i < ins.size(); i++) { diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 983d8243b1d8aa6c8d01855d6dbeab76c335f70c..3dc2b0c895116155f41df3ca66125fff3ede5ead 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -175,6 +175,7 @@ def __bootstrap__(): read_env_flags.append('communicator_thread_pool_size') read_env_flags.append('communicator_max_merge_var_num') read_env_flags.append('communicator_fake_rpc') + read_env_flags.append('communicator_send_wait_times') if core.is_compiled_with_brpc(): read_env_flags.append('max_body_size') #set brpc max body size diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 41e5f47976c566306ad141f655a0f6516831d690..19a1f8bf74060905ecb4b81b44f7080db79c45e4 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -658,6 +658,7 @@ class DistributeTranspiler(object): outputs={"Out": splited_var}, attrs={ "epmap": eps, + "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) @@ -669,6 +670,7 @@ class DistributeTranspiler(object): outputs={"Out": fetch_barrier_out}, attrs={ "endpoints": self.pserver_endpoints, + "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) @@ -791,11 +793,15 @@ class DistributeTranspiler(object): global_ops = [] + # sparse grad name to param name + sparse_grad_to_param = [] + def __append_optimize_op__(op, block, grad_to_block_id, merged_var, lr_ops): if self._is_optimizer_op(op): self._append_pserver_ops(block, op, endpoint, grad_to_block_id, - self.origin_program, merged_var) + self.origin_program, merged_var, + sparse_grad_to_param) elif op not in lr_ops: self._append_pserver_non_opt_ops(block, op) @@ -911,6 +917,7 @@ class DistributeTranspiler(object): "Fanin": self.trainer_num, "sync_mode": self.sync_mode, "grad_to_block_id": grad_to_block_id, + "sparse_grad_to_param": sparse_grad_to_param, } if self.has_distributed_lookup_table: @@ -1779,7 +1786,8 @@ class DistributeTranspiler(object): return o4 def _append_pserver_ops(self, optimize_block, opt_op, endpoint, - grad_to_block_id, origin_program, merged_var): + grad_to_block_id, origin_program, merged_var, + sparse_grad_to_param): program = optimize_block.program pserver_block = program.global_block() new_inputs = collections.OrderedDict() @@ -1863,6 +1871,12 @@ class DistributeTranspiler(object): outputs=outputs, attrs=opt_op.all_attrs()) + # record sparse grad to param name + if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS: + sparse_grad_to_param.append( + str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"] + .name)) + def _get_pserver_grad_param_var(self, var, var_dict): """ Return pserver side grad/param variable, return None