Commit fbd186bd authored by Qiao Longfei

complete recv op

Parent a7152613
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/distributed/distributed.h"
+#include "paddle/fluid/operators/distributed/parameter_recv.h"
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {
@@ -48,6 +49,18 @@ class RecvOp : public framework::OperatorBase {
         distributed::RPCClient::GetInstance<RPCCLIENT_T>(
             Attr<int>("trainer_id"));

+    std::vector<std::string> recv_varnames =
+        Attr<std::vector<std::string>>("recv_varnames");
+
+    if (recv_varnames.size() > 0) {
+      framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
+      platform::DeviceContextPool &pool =
+          platform::DeviceContextPool::Instance();
+      auto *dev_ctx = pool.Get(place);
+      auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx);
+      auto recv_functor = distributed::ParameterRecv<float>();
+      recv_functor(outs[0], recv_varnames, epmap, exe_ctx, scope);
+    } else {
     if (with_barrier) {
       std::vector<distributed::VarHandlePtr> rets;
       for (size_t i = 0; i < outs.size(); i++) {
@@ -76,6 +89,7 @@ class RecvOp : public framework::OperatorBase {
         }
       }
     }
+    }
   }
 };
 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
......
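The branch added above lets a single recv op pull all of a parameter's split pieces through distributed::ParameterRecv<float> and write the merged result into one output variable, instead of issuing one RPC call per split and concatenating afterwards. Below is a minimal usage sketch (not part of this commit) of how that path would be enabled from the Python side; it assumes the DistributeTranspilerConfig object already exposes the runtime_split_send_recv flag that the transpiler change further down reads.

```python
import paddle.fluid as fluid

# Minimal sketch, assuming DistributeTranspilerConfig exposes
# runtime_split_send_recv at this point in the codebase. With the flag on,
# the transpiler emits recv ops carrying a non-empty "recv_varnames"
# attribute, so the ParameterRecv branch above is taken at runtime.
config = fluid.DistributeTranspilerConfig()
config.runtime_split_send_recv = True

t = fluid.DistributeTranspiler(config=config)
# Endpoints and trainer counts below are placeholders.
t.transpile(trainer_id=0,
            pservers="127.0.0.1:6170,127.0.0.1:6171",
            trainers=2)
trainer_prog = t.get_trainer_program()
```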
@@ -519,12 +519,20 @@ class DistributeTranspiler(object):
                     param_varname, height_sections, eps, table_names)
             else:
                 all_recv_outputs.extend(splited_var)
+
+                recv_varnames = []
+                if self.config.runtime_split_send_recv:
+                    orig_param = program.global_block().vars[param_varname]
+                    recv_varnames = [var.name for var in splited_var]
+                    splited_var = [orig_param]
+
                 program.global_block().append_op(
                     type="recv",
                     inputs={"X": [recv_dep_in]},
                     outputs={"Out": splited_var},
                     attrs={
                         "epmap": eps,
+                        "recv_varnames": recv_varnames,
                         "trainer_id": self.trainer_id,
                         RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                         OP_ROLE_VAR_ATTR_NAME:
@@ -549,6 +557,7 @@ class DistributeTranspiler(object):
                 continue
             orig_param = program.global_block().vars[param_varname]
             if param_varname not in self.sparse_param_to_height_sections:
+                if not self.config.runtime_split_send_recv:
                 program.global_block().append_op(
                     type="concat",
                     inputs={"X": splited_var},
......
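For readability, the sketch below (a hypothetical helper, not part of the diff) restates the branching that the transpiler change above introduces: with runtime_split_send_recv enabled, the recv op targets the original parameter directly and carries the split names in its recv_varnames attribute, so the later concat op is skipped; otherwise the existing per-split recv plus concat behaviour is kept.

```python
# Hypothetical helper mirroring the transpiler logic above; the names
# orig_param, splited_var and eps stand for the variables used in the diff.
def build_recv_spec(config, orig_param, splited_var, eps):
    if config.runtime_split_send_recv:
        # One recv op fetches every split remotely and merges them into the
        # original parameter, so no concat op needs to be appended later.
        outputs = [orig_param]
        recv_varnames = [var.name for var in splited_var]
    else:
        # Each split is received into its own variable; a separate concat op
        # stitches the pieces back into orig_param afterwards.
        outputs = list(splited_var)
        recv_varnames = []
    return outputs, {"epmap": eps, "recv_varnames": recv_varnames}
```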