From 7a1d6ae5f61e466a403b26d6876463883f0946ae Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 26 Feb 2018 19:15:30 +0800 Subject: [PATCH] Fix send_recv unit test --- paddle/fluid/operators/listen_and_serv_op.cc | 2 ++ paddle/fluid/operators/send_recv_op_test.cc | 26 ++++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index ee0e3533ce0..8e9923c87ce 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -129,6 +129,8 @@ class ListenAndServOp : public framework::OperatorBase { } if (exit_flag) { rpc_service_->ShutDown(); + rpc_service_->SetCond(1); + break; } try { executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index 008c012a32e..e9fb845b475 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -95,7 +95,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, for (auto kv : outputs) { for (auto v : kv.second) { auto var = block->Var(v); - var->SetDataType(f::proto::DataType::FP32); + var->SetDataType(f::proto::VarType::FP32); } } @@ -122,33 +122,37 @@ void StartServerNet(bool is_sparse) { // sub program run in listen_and_serv_op, for simple test we use sum f::ProgramDesc program; - f::BlockDesc *block = program.MutableBlock(0); + f::BlockDesc *optimize_block = program.MutableBlock(0); // X for server side tensors, RX for received tensers, must be of same shape. - AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block); + AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); f::AttributeMap attrs; attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); + attrs.insert({"Fanin", 1}); attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); - attrs.insert({"OptimizeBlock", block}); + attrs.insert({"OptimizeBlock", optimize_block}); listen_and_serv_op = - f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs); + f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); listen_and_serv_op->Run(scope, place); } TEST(SendRecvOp, CPUDense) { std::thread server_thread(StartServerNet, false); - sleep(10); // wait server to start + sleep(5); // wait server to start // local net f::Scope scope; p::CPUPlace place; InitTensorsInScope(scope, place); + // create rpc client var + scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); - auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}}, attrs); + auto send_op = f::OpRegistry::CreateOp( + "send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); send_op->Run(scope, place); auto in_var = scope.Var("x1"); @@ -175,11 +179,13 @@ TEST(SendRecvOp, CPUSparse) { p::CPUPlace place; p::CPUDeviceContext ctx(place); InitSelectedRowsInScope(scope, place); + scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); - auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, - {{"Out", {"Out"}}}, attrs); + auto send_op = f::OpRegistry::CreateOp( + "send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); send_op->Run(scope, place); auto x0 = scope.Var("x0")->GetMutable(); -- GitLab