提交 255b36da 编写于 作者: Q Qiao Longfei

can run

上级 8c38aca9
...@@ -59,6 +59,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -59,6 +59,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
send_varname_to_ctx[send_var_name] = send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames, operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section); epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
auto recv_var_name = node->Op()->Input("X")[0]; auto recv_var_name = node->Op()->Input("X")[0];
auto recv_varnames = boost::get<std::vector<std::string>>( auto recv_varnames = boost::get<std::vector<std::string>>(
...@@ -68,13 +70,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -68,13 +70,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
recv_varname_to_ctx[recv_var_name] = recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames, operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {}); epmap, {});
graphs[i]->RemoveNode(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
} }
} }
} }
} }
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use ";
operators::distributed::Communicator::Init(send_varname_to_ctx, operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope); recv_varname_to_ctx, scope);
}
} }
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
...@@ -110,6 +118,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( ...@@ -110,6 +118,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
for (auto *scope : local_scopes_) { for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope); NewTempScopeAndInitVars(var_infos_, scope);
} }
ProcessGraph(graphs_, local_scopes_[0]);
} }
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() { void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
......
...@@ -30,7 +30,7 @@ if(WITH_GRPC) ...@@ -30,7 +30,7 @@ if(WITH_GRPC)
else() else()
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc) set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib) set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
......
...@@ -63,6 +63,9 @@ static inline void MergeVars(const std::string &var_name, ...@@ -63,6 +63,9 @@ static inline void MergeVars(const std::string &var_name,
} }
} }
// Out-of-class definition of the static holder for the Communicator
// instance. It starts out empty (nullptr).
std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
// Out-of-class definition of the static once-flag.
// NOTE(review): presumably used with std::call_once to make initialization
// of communicator_ happen only once — confirm against communicator.h.
std::once_flag Communicator::init_flag_;
void Communicator::SendThread() { void Communicator::SendThread() {
while (running_) { while (running_) {
std::vector<std::future<void>> task_futures; std::vector<std::future<void>> task_futures;
...@@ -117,6 +120,7 @@ void Communicator::RecvThread() { ...@@ -117,6 +120,7 @@ void Communicator::RecvThread() {
void Communicator::Send(const std::string &var_name, void Communicator::Send(const std::string &var_name,
const framework::Scope &scope) { const framework::Scope &scope) {
VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name // push var into send queue by var_name
auto *grad_var = scope.FindVar(var_name); auto *grad_var = scope.FindVar(var_name);
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited"); PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
...@@ -125,6 +129,8 @@ void Communicator::Send(const std::string &var_name, ...@@ -125,6 +129,8 @@ void Communicator::Send(const std::string &var_name,
send_varname_to_queue_[var_name]->Push(tmp_grad_var); send_varname_to_queue_[var_name]->Push(tmp_grad_var);
} }
// Returns the raw, non-owning pointer currently stored in the static
// communicator_ holder. This accessor never creates an instance itself, so
// the result is nullptr while communicator_ is still empty; callers that
// dereference it should make sure the communicator has been set up first.
Communicator *Communicator::GetInstance() { return communicator_.get(); }
void Communicator::Start() { void Communicator::Start() {
running_ = true; running_ = true;
// start send and recv thread // start send and recv thread
......
...@@ -144,7 +144,7 @@ class Communicator { ...@@ -144,7 +144,7 @@ class Communicator {
InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope); InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
} }
static Communicator* GetInstance() { return communicator_.get(); } static Communicator* GetInstance();
private: private:
// Init is called by GetInstance. // Init is called by GetInstance.
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -22,15 +23,17 @@ namespace operators { ...@@ -22,15 +23,17 @@ namespace operators {
namespace distributed { namespace distributed {
struct RpcContext { struct RpcContext {
RpcContext(const std::string& name, const std::vector<std::string>& names, RpcContext() = default;
const std::vector<std::string>& emap,
const std::vector<int64_t>& sections) RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections)
: var_name(name), : var_name(name),
splited_var_names(names), splited_var_names(names),
epmap(emap), epmap(emap),
height_sections(sections) {} height_sections(sections) {}
RpcContext(const RpcContext& ctx) { RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name; var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names; splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap; epmap = ctx.epmap;
...@@ -43,6 +46,31 @@ struct RpcContext { ...@@ -43,6 +46,31 @@ struct RpcContext {
std::vector<int64_t> height_sections; std::vector<int64_t> height_sections;
}; };
// Streams a multi-line, brace-delimited debug dump of an RpcContext to `os`:
// the variable name, then the split variable names, the endpoint map and the
// height sections, each printed as a bracketed list in which every element
// (including the last one) is followed by ", ".
// Returns `os` so that the call can be chained.
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
  os << "{"
     << "var_name: " << rpc_ctx.var_name << "\n";

  os << "splited_var_names: [";
  for (const auto &split_name : rpc_ctx.splited_var_names) {
    os << split_name << ", ";
  }
  os << "]\n"
     << "epmap: [";
  for (const auto &endpoint : rpc_ctx.epmap) {
    os << endpoint << ", ";
  }
  os << "]\n"
     << "height_sections: [";
  for (const auto &height : rpc_ctx.height_sections) {
    os << height << ", ";
  }
  return os << "]\n"
            << "}";
}
} // namespace distributed } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -2,9 +2,9 @@ include(operators) ...@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else() else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA) if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs) find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h" #include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed/rpc_common.h" #include "paddle/fluid/operators/distributed/rpc_common.h"
...@@ -47,10 +48,12 @@ class SendOp : public framework::OperatorBase { ...@@ -47,10 +48,12 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) { if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, ""); PADDLE_ENFORCE_EQ(ins.size(), 1, "");
auto send_functor = distributed::ParameterSend<float>(); // auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, // auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames,
height_sections); // epmap,
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send)); // height_sections);
// send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
distributed::Communicator::GetInstance()->Send(ins[0], scope);
} else { } else {
platform::DeviceContextPool& pool = platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册