未验证 提交 f650429b 编写于 作者: 武毅 提交者: GitHub

Merge pull request #6419 from typhoonzero/recv_op_use_ser_prog

recv_op use serialized program
...@@ -72,11 +72,13 @@ class RecvOp : public framework::OperatorBase { ...@@ -72,11 +72,13 @@ class RecvOp : public framework::OperatorBase {
// FIXME(typhoonzero): do not copy // FIXME(typhoonzero): do not copy
framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor); framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor);
auto *block = Attr<framework::BlockDescBind *>("OptimizeBlock"); std::string program_str = Attr<std::string>("OptimizeProgram");
auto *program = block->Program(); framework::ProgramDesc program_desc;
program_desc.ParseFromString(program_str);
framework::ProgramDescBind program(program_desc);
framework::Executor executor(dev_ctx); framework::Executor executor(dev_ctx);
// Run sub graph to get optimized tensor // Run sub graph to get optimized tensor
executor.Run(*program, &recv_scope, block->ID(), executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/); false /*create_local_scope*/);
auto *out_var = recv_scope.FindVar("Out"); auto *out_var = recv_scope.FindVar("Out");
...@@ -108,8 +110,8 @@ This operator will recv tensor from send_op ...@@ -108,8 +110,8 @@ This operator will recv tensor from send_op
"IP address to listen on.") "IP address to listen on.")
.SetDefault("127.0.0.1:6164") .SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<framework::BlockDescBind *>("OptimizeBlock", "type BlockDescBind*", AddAttr<std::string>("OptimizeProgram", "type string",
"optimize network run in server"); "Serialized ProgramDesc string for recv to run.");
} }
}; };
......
...@@ -85,7 +85,10 @@ void StartServerNet() { ...@@ -85,7 +85,10 @@ void StartServerNet() {
paddle::framework::AttributeMap attrs; paddle::framework::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
attrs.insert({"OptimizeBlock", block}); std::string program_proto;
PADDLE_ENFORCE(program.Proto()->SerializeToString(&program_proto));
attrs.insert({"OptimizeProgram", program_proto});
recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"RX"}}}, recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"RX"}}},
{{"Out", {"Out"}}}, attrs); {{"Out", {"Out"}}}, attrs);
paddle::platform::CPUDeviceContext ctx(place); paddle::platform::CPUDeviceContext ctx(place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册