提交 be0c4823 编写于 作者: Q Qiao Longfei

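update trainer_id: store trainer_id in RpcContext and use it, instead of the hard-coded 0, when ParameterSend/ParameterRecv call RPCClient::GetInstance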

上级 c60f312d
......@@ -60,9 +60,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
epmap, height_section,
trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
......@@ -71,9 +74,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
epmap, {}, trainer_id);
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
......
......@@ -48,7 +48,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *recv_var = scope.FindVar(rpc_ctx.var_name);
......@@ -112,7 +112,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// FIXME(qiao): use a trick to avoid the bug of recv an selected rows
for (auto i = 1; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[1]);
PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width);
}
......
......@@ -46,7 +46,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *send_var = scope.FindVar(rpc_ctx.var_name);
size_t out_num = rpc_ctx.splited_var_names.size();
......
......@@ -27,23 +27,26 @@ struct RpcContext {
RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections)
const std::vector<int64_t> &sections, int id)
: var_name(name),
splited_var_names(names),
epmap(emap),
height_sections(sections) {}
height_sections(sections),
trainer_id(id) {}
RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
}
std::string var_name;
std::vector<std::string> splited_var_names;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
int trainer_id;
};
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
......
......@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
auto trainer_id = Attr<int>("trainer_id");
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
std::vector<std::string> recv_varnames =
Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) {
auto recv_functor = distributed::ParameterRecv<float>();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {},
trainer_id);
recv_functor(rpc_ctx, scope);
} else {
if (with_barrier) {
......
......@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("sections");
......@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
/*
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections);
height_sections, trainer_id);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
*/
VLOG(3) << "send " << ins[0];
......@@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册