From 255b36dad2a3500a108977cee2b5eb041b086d2b Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 6 Mar 2019 14:39:14 +0800 Subject: [PATCH] can run --- .../details/async_ssa_graph_executor.cc | 13 +++++-- .../operators/distributed/CMakeLists.txt | 2 +- .../operators/distributed/communicator.cc | 6 ++++ .../operators/distributed/communicator.h | 2 +- .../fluid/operators/distributed/rpc_common.h | 36 ++++++++++++++++--- .../operators/distributed_ops/CMakeLists.txt | 4 +-- .../operators/distributed_ops/send_op.cc | 11 +++--- 7 files changed, 60 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 43391804c54..18fba0d19bb 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -59,6 +59,8 @@ void ProcessGraph(std::vector graphs, Scope *scope) { send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(send_var_name, send_varnames, epmap, height_section); + VLOG(3) << "find and init an send op: " + << send_varname_to_ctx[send_var_name]; } else if (node->Op()->Type() == "recv") { auto recv_var_name = node->Op()->Input("X")[0]; auto recv_varnames = boost::get>( @@ -68,13 +70,19 @@ void ProcessGraph(std::vector graphs, Scope *scope) { recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(recv_var_name, recv_varnames, epmap, {}); + graphs[i]->RemoveNode(node); + VLOG(3) << "find and remove an recv op: " + << recv_varname_to_ctx[recv_var_name]; } } } } // init communicator here - operators::distributed::Communicator::Init(send_varname_to_ctx, - recv_varname_to_ctx, scope); + if (send_varname_to_ctx.size() > 0) { + VLOG(3) << "this is distribute mode, will use "; + operators::distributed::Communicator::Init(send_varname_to_ctx, + recv_varname_to_ctx, scope); + } } AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( @@ -110,6 +118,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( for (auto *scope : local_scopes_) { NewTempScopeAndInitVars(var_infos_, scope); } + ProcessGraph(graphs_, local_scopes_[0]); } void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() { diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 1301467fa74..6a269a4fbe6 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -30,7 +30,7 @@ if(WITH_GRPC) else() set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc) - set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index a88b7644748..e800cd5f417 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -63,6 +63,9 @@ static inline void MergeVars(const std::string &var_name, } } +std::unique_ptr Communicator::communicator_(nullptr); +std::once_flag Communicator::init_flag_; + void Communicator::SendThread() { while (running_) { std::vector> task_futures; @@ -117,6 +120,7 @@ void Communicator::RecvThread() { void Communicator::Send(const std::string &var_name, const framework::Scope &scope) { + VLOG(3) << "communicator send " << var_name; // push var into send queue by var_name auto *grad_var = scope.FindVar(var_name); PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited"); @@ -125,6 +129,8 @@ void Communicator::Send(const std::string &var_name, send_varname_to_queue_[var_name]->Push(tmp_grad_var); } +Communicator *Communicator::GetInstance() { return communicator_.get(); } + void Communicator::Start() { running_ = true; // start send and recv thread diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 44e2aa3be73..bc753bb75ef 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -144,7 +144,7 @@ class Communicator { InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope); } - static Communicator* GetInstance() { return communicator_.get(); } + static Communicator* GetInstance(); private: // Init is called by GetInstance. diff --git a/paddle/fluid/operators/distributed/rpc_common.h b/paddle/fluid/operators/distributed/rpc_common.h index 39eb2d078c8..3de89c2ae89 100644 --- a/paddle/fluid/operators/distributed/rpc_common.h +++ b/paddle/fluid/operators/distributed/rpc_common.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include @@ -22,15 +23,17 @@ namespace operators { namespace distributed { struct RpcContext { - RpcContext(const std::string& name, const std::vector& names, - const std::vector& emap, - const std::vector& sections) + RpcContext() = default; + + RpcContext(const std::string &name, const std::vector &names, + const std::vector &emap, + const std::vector §ions) : var_name(name), splited_var_names(names), epmap(emap), height_sections(sections) {} - RpcContext(const RpcContext& ctx) { + RpcContext(const RpcContext &ctx) { var_name = ctx.var_name; splited_var_names = ctx.splited_var_names; epmap = ctx.epmap; @@ -43,6 +46,31 @@ struct RpcContext { std::vector height_sections; }; +inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { + os << "{"; + os << "var_name: " << rpc_ctx.var_name << "\n"; + + os << "splited_var_names: ["; + for (auto &name : rpc_ctx.splited_var_names) { + os << name << ", "; + } + os << "]\n"; + + os << "epmap: ["; + for (auto &ep : rpc_ctx.epmap) { + os << ep << ", "; + } + os << "]\n"; + + os << "height_sections: ["; + for (auto §ion : rpc_ctx.height_sections) { + os << section << ", "; + } + os << "]\n"; + os << "}"; + return os; +} + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index 3bcfc532e86..a1ef1af39ff 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -2,9 +2,9 @@ include(operators) set(DISTRIBUTE_DEPS "") if(WITH_GRPC) - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) else() - set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node) if(WITH_BRPC_RDMA) find_library(IBVERBS_LIBRARY NAMES ibverbs) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) diff --git a/paddle/fluid/operators/distributed_ops/send_op.cc b/paddle/fluid/operators/distributed_ops/send_op.cc index 8b09cf86d7d..347395b7ccd 100644 --- a/paddle/fluid/operators/distributed_ops/send_op.cc +++ b/paddle/fluid/operators/distributed_ops/send_op.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/distributed/communicator.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/parameter_send.h" #include "paddle/fluid/operators/distributed/rpc_common.h" @@ -47,10 +48,12 @@ class SendOp : public framework::OperatorBase { if (send_varnames.size() > 0) { PADDLE_ENFORCE_EQ(ins.size(), 1, ""); - auto send_functor = distributed::ParameterSend(); - auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, - height_sections); - send_functor(rpc_ctx, scope, static_cast(sync_send)); + // auto send_functor = distributed::ParameterSend(); + // auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, + // epmap, + // height_sections); + // send_functor(rpc_ctx, scope, static_cast(sync_send)); + distributed::Communicator::GetInstance()->Send(ins[0], scope); } else { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); -- GitLab