提交 c2cce6ba 编写于 作者: Q Qiao Longfei

simplify parameter send and recv

上级 3c6b733d
...@@ -74,7 +74,7 @@ void Communicator::SendThread() { ...@@ -74,7 +74,7 @@ void Communicator::SendThread() {
merged_var_num++; merged_var_num++;
} }
MergeVars(var_name, vars, send_scope_.get()); MergeVars(var_name, vars, send_scope_.get());
auto send_functor = distributed::ParameterSend<float>(); // auto send_functor = distributed::ParameterSend<float>();
// send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx, // send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
// send_scope_, true); // send_scope_, true);
} }
...@@ -85,7 +85,7 @@ void Communicator::RecvThread() { ...@@ -85,7 +85,7 @@ void Communicator::RecvThread() {
for (auto &iter : recv_varname_to_ctx_) { for (auto &iter : recv_varname_to_ctx_) {
auto &var_name = iter.first; auto &var_name = iter.first;
VLOG(3) << "recv var " << iter.first; VLOG(3) << "recv var " << iter.first;
auto recv_functor = distributed::ParameterRecv<float>(); // auto recv_functor = distributed::ParameterRecv<float>();
// recv_functor(var_name, iter.second, exe_ctx, recv_scope_); // recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
} }
} }
......
...@@ -54,12 +54,6 @@ class RecvOp : public framework::OperatorBase { ...@@ -54,12 +54,6 @@ class RecvOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("recv_varnames"); Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) { if (recv_varnames.size() > 0) {
framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(place);
auto exe_ctx =
framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
auto recv_functor = distributed::ParameterRecv<float>(); auto recv_functor = distributed::ParameterRecv<float>();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {}); auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
recv_functor(rpc_ctx, scope); recv_functor(rpc_ctx, scope);
......
...@@ -47,12 +47,6 @@ class SendOp : public framework::OperatorBase { ...@@ -47,12 +47,6 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) { if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, ""); PADDLE_ENFORCE_EQ(ins.size(), 1, "");
framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
auto exe_ctx =
framework::ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr);
auto send_functor = distributed::ParameterSend<float>(); auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap, auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections); height_sections);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册