提交 1edc0423 编写于 作者: Q Qiao Longfei

update send_op

上级 74040cb4
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -37,10 +38,24 @@ class SendOp : public framework::OperatorBase { ...@@ -37,10 +38,24 @@ class SendOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode"); int sync_send = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("height_sections");
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
framework::RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
auto exe_ctx = framework::ExecutionContext(*this, scope, *dev_ctx, ctx);
distributed::send<float>(ins[0], send_varnames, epmap, height_sections,
exe_ctx, scope, static_cast<bool>(sync_send));
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
...@@ -51,7 +66,8 @@ class SendOp : public framework::OperatorBase { ...@@ -51,7 +66,8 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
...@@ -64,6 +80,7 @@ class SendOp : public framework::OperatorBase { ...@@ -64,6 +80,7 @@ class SendOp : public framework::OperatorBase {
} }
} }
} }
}
}; };
class SendOpMaker : public framework::OpProtoAndCheckerMaker { class SendOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册