提交 1edc0423 作者: Qiao Longfei

update send_op

上级 74040cb4
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -37,30 +38,46 @@ class SendOp : public framework::OperatorBase { ...@@ -37,30 +38,46 @@ class SendOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode"); int sync_send = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto& ctx = *pool.Get(place); auto height_sections = Attr<std::vector<int64_t>>("height_sections");
distributed::RPCClient* rpc_client = if (send_varnames.size() > 0) {
distributed::RPCClient::GetInstance<RPCCLIENT_T>( PADDLE_ENFORCE_EQ(ins.size(), 1, "");
Attr<int>("trainer_id")); framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool& pool =
std::vector<distributed::VarHandlePtr> rets; platform::DeviceContextPool::Instance();
for (size_t i = 0; i < ins.size(); i++) { auto* dev_ctx = pool.Get(place);
if (NeedSend(scope, ins[i])) { auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx);
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; distributed::send<float>(ins[0], send_varnames, epmap, height_sections,
rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); exe_ctx, scope, static_cast<bool>(sync_send));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
} }
} if (sync_send) {
if (sync_send) { for (size_t i = 0; i < rets.size(); i++) {
for (size_t i = 0; i < rets.size(); i++) { VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; }
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册