提交 8c38aca9 编写于 作者: Q Qiao Longfei

tmp commit

上级 fab1b54d
......@@ -82,7 +82,7 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor communicator)
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/communicator.h"
namespace paddle {
namespace framework {
......@@ -39,6 +40,43 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
}
}
// get RpcContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap;
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto i = 0; i < graphs.size(); ++i) {
for (auto &node : graphs[i]->Nodes()) {
if (node->IsOp()) {
if (node->Op()->Type() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
} else if (node->Op()->Type() == "recv") {
auto recv_var_name = node->Op()->Input("X")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
}
}
}
}
// init communicator here
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
......
......@@ -87,12 +87,12 @@ class BlockingQueue {
std::condition_variable send_cv_;
};
using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator {
public:
Communicator(
const std::unordered_map<std::string, RpcContext>& send_varname_to_ctx,
const std::unordered_map<std::string, RpcContext>& recv_varname_to_ctx,
Scope* recv_scope)
Communicator(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope)
: send_varname_to_ctx_(send_varname_to_ctx),
recv_varname_to_ctx_(recv_varname_to_ctx),
recv_scope_(recv_scope) {
......@@ -128,14 +128,38 @@ class Communicator {
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
std::unordered_map<std::string, RpcContext> send_varname_to_ctx_;
std::unordered_map<std::string, RpcContext> recv_varname_to_ctx_;
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::unique_ptr<std::thread> send_thread_;
std::unique_ptr<std::thread> recv_thread_;
Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
// the following code is for initialize the commnunicator
public:
static void Init(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
}
static Communicator* GetInstance() { return communicator_.get(); }
private:
// Init is called by GetInstance.
static void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_ == nullptr) {
communicator_.reset(new Communicator(send_varname_to_ctx,
recv_varname_to_ctx, recv_scope));
}
}
private:
static std::once_flag init_flag_;
static std::unique_ptr<Communicator> communicator_;
};
} // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册