From 4f1381eac3708ce92b07a01f6cfc9d4131c996af Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Fri, 8 Dec 2017 16:20:09 +0800 Subject: [PATCH] recv_op use serialized program --- paddle/operators/recv_op.cc | 11 +++++++---- paddle/operators/send_recv_op_test.cc | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index c69e416e10f..45222f6b768 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -72,8 +72,10 @@ class RecvOp : public framework::OperatorBase { // FIXME(typhoonzero): do not copy framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor); - auto *block = Attr("OptimizeBlock"); - auto *program = block->Program(); + std::string program_str = Attr("OptimizeProgram"); + framework::Program program_desc; + program_desc.ParseFromString(program_str); + framework::ProgramDescBind program(program_desc); framework::Executor executor(dev_ctx); // Run sub graph to get optimized tensor executor.Run(*program, &recv_scope, block->ID(), @@ -108,8 +110,9 @@ This operator will recv tensor from send_op "IP address to listen on.") .SetDefault("127.0.0.1:6164") .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); - AddAttr("OptimizeBlock", "type BlockDescBind*", - "optimize network run in server"); + AddAttr( + "OptimizeProgram", "type string", + "Serialized ProgramDesc string for recv to run."); } }; diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index ac03eb3752e..c35dc8fa508 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -85,7 +85,7 @@ void StartServerNet() { paddle::framework::AttributeMap attrs; attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); - attrs.insert({"OptimizeBlock", block}); + attrs.insert({"OptimizeProgram", program.Proto()->SerializeToString()}); recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"RX"}}}, {{"Out", {"Out"}}}, attrs); paddle::platform::CPUDeviceContext ctx(place); -- GitLab