提交 255b36da 编写于 作者: Q Qiao Longfei

can run

上级 8c38aca9
......@@ -59,6 +59,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Op()->Type() == "recv") {
auto recv_var_name = node->Op()->Input("X")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
......@@ -68,13 +70,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
graphs[i]->RemoveNode(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
}
}
// init communicator here
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use ";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
}
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
......@@ -110,6 +118,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
......
......@@ -30,7 +30,7 @@ if(WITH_GRPC)
else()
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
......
......@@ -63,6 +63,9 @@ static inline void MergeVars(const std::string &var_name,
}
}
std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_;
void Communicator::SendThread() {
while (running_) {
std::vector<std::future<void>> task_futures;
......@@ -117,6 +120,7 @@ void Communicator::RecvThread() {
void Communicator::Send(const std::string &var_name,
const framework::Scope &scope) {
VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name
auto *grad_var = scope.FindVar(var_name);
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
......@@ -125,6 +129,8 @@ void Communicator::Send(const std::string &var_name,
send_varname_to_queue_[var_name]->Push(tmp_grad_var);
}
Communicator *Communicator::GetInstance() { return communicator_.get(); }
void Communicator::Start() {
running_ = true;
// start send and recv thread
......
......@@ -144,7 +144,7 @@ class Communicator {
InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
}
static Communicator* GetInstance() { return communicator_.get(); }
static Communicator* GetInstance();
private:
// Init is called by GetInstance.
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
......@@ -22,15 +23,17 @@ namespace operators {
namespace distributed {
struct RpcContext {
RpcContext(const std::string& name, const std::vector<std::string>& names,
const std::vector<std::string>& emap,
const std::vector<int64_t>& sections)
RpcContext() = default;
RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections)
: var_name(name),
splited_var_names(names),
epmap(emap),
height_sections(sections) {}
RpcContext(const RpcContext& ctx) {
RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap;
......@@ -43,6 +46,31 @@ struct RpcContext {
std::vector<int64_t> height_sections;
};
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os << "{";
os << "var_name: " << rpc_ctx.var_name << "\n";
os << "splited_var_names: [";
for (auto &name : rpc_ctx.splited_var_names) {
os << name << ", ";
}
os << "]\n";
os << "epmap: [";
for (auto &ep : rpc_ctx.epmap) {
os << ep << ", ";
}
os << "]\n";
os << "height_sections: [";
for (auto &section : rpc_ctx.height_sections) {
os << section << ", ";
}
os << "]\n";
os << "}";
return os;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
......@@ -47,10 +48,12 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
// auto send_functor = distributed::ParameterSend<float>();
// auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames,
// epmap,
// height_sections);
// send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
distributed::Communicator::GetInstance()->Send(ins[0], scope);
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册