Commit fbd186bd authored by Qiao Longfei

complete recv op

Parent a7152613
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
+#include "paddle/fluid/operators/distributed/parameter_recv.h"
 #include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
@@ -48,32 +49,45 @@ class RecvOp : public framework::OperatorBase {
         distributed::RPCClient::GetInstance<RPCCLIENT_T>(
             Attr<int>("trainer_id"));
 
-    if (with_barrier) {
-      std::vector<distributed::VarHandlePtr> rets;
-      for (size_t i = 0; i < outs.size(); i++) {
-        std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
-        VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
-                << varname << " and with AsyncGetVar";
-        rets.push_back(
-            rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
-      }
-      if (sync_mode) {
-        for (size_t i = 0; i < rets.size(); i++) {
-          PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
-        }
-      }
-    } else {
-      std::vector<distributed::VarHandlePtr> rets;
-      for (size_t i = 0; i < outs.size(); i++) {
-        std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
-        VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
-                << varname << " and with AsyncGetVarNoBarrier";
-        rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
-                                                        varname, outs[i]));
-      }
-      for (size_t i = 0; i < rets.size(); i++) {
-        PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
-      }
-    }
+    std::vector<std::string> recv_varnames =
+        Attr<std::vector<std::string>>("recv_varnames");
+
+    if (recv_varnames.size() > 0) {
+      framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
+      platform::DeviceContextPool &pool =
+          platform::DeviceContextPool::Instance();
+      auto *dev_ctx = pool.Get(place);
+      auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx);
+      auto recv_functor = distributed::ParameterRecv<float>();
+      recv_functor(outs[0], recv_varnames, epmap, exe_ctx, scope);
+    } else {
+      if (with_barrier) {
+        std::vector<distributed::VarHandlePtr> rets;
+        for (size_t i = 0; i < outs.size(); i++) {
+          std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
+          VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
+                  << varname << " and with AsyncGetVar";
+          rets.push_back(
+              rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
+        }
+        if (sync_mode) {
+          for (size_t i = 0; i < rets.size(); i++) {
+            PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+          }
+        }
+      } else {
+        std::vector<distributed::VarHandlePtr> rets;
+        for (size_t i = 0; i < outs.size(); i++) {
+          std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
+          VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
+                  << varname << " and with AsyncGetVarNoBarrier";
+          rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
+                                                          varname, outs[i]));
+        }
+        for (size_t i = 0; i < rets.size(); i++) {
+          PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+        }
+      }
+    }
   }
 };
...
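
Note on the C++ change above: when the recv op carries a non-empty recv_varnames attribute, it skips the per-variable AsyncGetVar calls and hands the whole list to distributed::ParameterRecv<float>, which pulls each split piece from its endpoint and merges them into the single output variable. As a rough illustration only (not part of this commit, and assuming a build with distributed ops registered; the variable name, endpoints, and split-piece names are made up), such an op could be built by hand at the Python layer:

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
# Hypothetical parameter that the recv op would fill in.
w = block.create_var(name="w", shape=[1000, 128], dtype="float32", persistable=True)

block.append_op(
    type="recv",
    outputs={"Out": [w]},
    attrs={
        "epmap": ["127.0.0.1:6174", "127.0.0.1:6175"],  # hypothetical pserver endpoints
        "recv_varnames": ["w.block0", "w.block1"],      # split pieces fetched and merged into w
        "trainer_id": 0,
    })
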
@@ -519,12 +519,20 @@ class DistributeTranspiler(object):
                         param_varname, height_sections, eps, table_names)
             else:
                 all_recv_outputs.extend(splited_var)
+
+                recv_varnames = []
+                if self.config.runtime_split_send_recv:
+                    orig_param = program.global_block().vars[param_varname]
+                    recv_varnames = [var.name for var in splited_vars]
+                    splited_var = [orig_param]
+
                 program.global_block().append_op(
                     type="recv",
                     inputs={"X": [recv_dep_in]},
                     outputs={"Out": splited_var},
                     attrs={
                         "epmap": eps,
+                        "recv_varnames": recv_varnames,
                         "trainer_id": self.trainer_id,
                         RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                         OP_ROLE_VAR_ATTR_NAME:
@@ -549,14 +557,15 @@ class DistributeTranspiler(object):
                 continue
             orig_param = program.global_block().vars[param_varname]
             if param_varname not in self.sparse_param_to_height_sections:
-                program.global_block().append_op(
-                    type="concat",
-                    inputs={"X": splited_var},
-                    outputs={"Out": [orig_param]},
-                    attrs={
-                        "axis": 0,
-                        RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
-                    })
+                if not self.config.runtime_split_send_recv:
+                    program.global_block().append_op(
+                        type="concat",
+                        inputs={"X": splited_var},
+                        outputs={"Out": [orig_param]},
+                        attrs={
+                            "axis": 0,
+                            RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                        })
 
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
...
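
On the transpiler side, the new attribute is only emitted when runtime_split_send_recv is enabled in the transpiler config; each parameter then gets a single recv op carrying recv_varnames and the trailing concat op is skipped. A minimal usage sketch, assuming the DistributeTranspilerConfig.runtime_split_send_recv flag referenced in the diff above, with made-up endpoints and trainer counts:

import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()
config.runtime_split_send_recv = True  # emit recv ops with recv_varnames, skip the concat op

t = fluid.DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",  # hypothetical pserver endpoints
    trainers=2)
trainer_prog = t.get_trainer_program()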