提交 be0c4823 编写于 作者: Qiao Longfei

update trainer_id

上级 c60f312d
...@@ -60,9 +60,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -60,9 +60,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("epmap")); node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>( auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections")); node->Op()->GetNullableAttr("sections"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames, operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section); epmap, height_section,
trainer_id);
VLOG(3) << "find and init an send op: " VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name]; << send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") { } else if (node->Name() == "recv") {
...@@ -71,9 +74,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -71,9 +74,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("recv_varnames")); node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>( auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap")); node->Op()->GetNullableAttr("epmap"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames, operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {}); epmap, {}, trainer_id);
nodes_to_delete.push_back(node); nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: " VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name]; << recv_varname_to_ctx[recv_var_name];
......
...@@ -48,7 +48,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -48,7 +48,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *recv_var = scope.FindVar(rpc_ctx.var_name); auto *recv_var = scope.FindVar(rpc_ctx.var_name);
...@@ -112,7 +112,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -112,7 +112,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// FIXME(qiao): use a trick to avoid the bug of recv an selected rows // FIXME(qiao): use a trick to avoid the bug of recv an selected rows
for (auto i = 1; i < recv_slr.rows().size(); ++i) { for (auto i = 1; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i] + row_offset; auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[1]); PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
memcpy(recv_tensor->data<T>() + row_id * width, memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width); recv_slr.value().data<T>() + i * width, sizeof(T) * width);
} }
......
...@@ -46,7 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -46,7 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *send_var = scope.FindVar(rpc_ctx.var_name); auto *send_var = scope.FindVar(rpc_ctx.var_name);
size_t out_num = rpc_ctx.splited_var_names.size(); size_t out_num = rpc_ctx.splited_var_names.size();
......
...@@ -27,23 +27,26 @@ struct RpcContext { ...@@ -27,23 +27,26 @@ struct RpcContext {
RpcContext(const std::string &name, const std::vector<std::string> &names, RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap, const std::vector<std::string> &emap,
const std::vector<int64_t> &sections) const std::vector<int64_t> &sections, int id)
: var_name(name), : var_name(name),
splited_var_names(names), splited_var_names(names),
epmap(emap), epmap(emap),
height_sections(sections) {} height_sections(sections),
trainer_id(id) {}
RpcContext(const RpcContext &ctx) { RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name; var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names; splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap; epmap = ctx.epmap;
height_sections = ctx.height_sections; height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
} }
std::string var_name; std::string var_name;
std::vector<std::string> splited_var_names; std::vector<std::string> splited_var_names;
std::vector<std::string> epmap; std::vector<std::string> epmap;
std::vector<int64_t> height_sections; std::vector<int64_t> height_sections;
int trainer_id;
}; };
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
......
...@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase { ...@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place); auto &ctx = *pool.Get(place);
auto trainer_id = Attr<int>("trainer_id");
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
Attr<int>("trainer_id"));
std::vector<std::string> recv_varnames = std::vector<std::string> recv_varnames =
Attr<std::vector<std::string>>("recv_varnames"); Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) { if (recv_varnames.size() > 0) {
auto recv_functor = distributed::ParameterRecv<float>(); auto recv_functor = distributed::ParameterRecv<float>();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {},
trainer_id);
recv_functor(rpc_ctx, scope); recv_functor(rpc_ctx, scope);
} else { } else {
if (with_barrier) { if (with_barrier) {
......
...@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase { ...@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
auto epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode"); int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames"); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("sections"); auto height_sections = Attr<std::vector<int64_t>>("sections");
...@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase { ...@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
/* /*
auto send_functor = distributed::ParameterSend<float>(); auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections); height_sections, trainer_id);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send)); send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
*/ */
VLOG(3) << "send " << ins[0]; VLOG(3) << "send " << ins[0];
...@@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase { ...@@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册