提交 fbd186bd 编写于 作者: Qiao Longfei

complete recv op

上级 a7152613
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -48,6 +49,18 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<std::string> recv_varnames =
Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) {
framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(place);
auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx);
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(outs[0], recv_varnames, epmap, exe_ctx, scope);
} else {
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
......@@ -76,6 +89,7 @@ class RecvOp : public framework::OperatorBase {
}
}
}
}
};
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -519,12 +519,20 @@ class DistributeTranspiler(object):
param_varname, height_sections, eps, table_names)
else:
all_recv_outputs.extend(splited_var)
recv_varnames = []
if self.config.runtime_split_send_recv:
orig_param = program.global_block().vars[param_varname]
recv_varnames = [var.name for var in splited_vars]
splited_var = [orig_param]
program.global_block().append_op(
type="recv",
inputs={"X": [recv_dep_in]},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
"recv_varnames": recv_varnames,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
......@@ -549,6 +557,7 @@ class DistributeTranspiler(object):
continue
orig_param = program.global_block().vars[param_varname]
if param_varname not in self.sparse_param_to_height_sections:
if not self.config.runtime_split_send_recv:
program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册