未验证 提交 0250e808 编写于 作者: 武毅 提交者: GitHub

Merge pull request #8586 from Yancey1989/fix_dist_unittest

Fix send_recv unit test
...@@ -129,6 +129,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -129,6 +129,8 @@ class ListenAndServOp : public framework::OperatorBase {
} }
if (exit_flag) { if (exit_flag) {
rpc_service_->ShutDown(); rpc_service_->ShutDown();
rpc_service_->SetCond(1);
break;
} }
try { try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
......
...@@ -95,7 +95,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, ...@@ -95,7 +95,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
for (auto kv : outputs) { for (auto kv : outputs) {
for (auto v : kv.second) { for (auto v : kv.second) {
auto var = block->Var(v); auto var = block->Var(v);
var->SetDataType(f::proto::DataType::FP32); var->SetDataType(f::proto::VarType::FP32);
} }
} }
...@@ -122,33 +122,37 @@ void StartServerNet(bool is_sparse) { ...@@ -122,33 +122,37 @@ void StartServerNet(bool is_sparse) {
// sub program run in listen_and_serv_op, for simple test we use sum // sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program; f::ProgramDesc program;
f::BlockDesc *block = program.MutableBlock(0); f::BlockDesc *optimize_block = program.MutableBlock(0);
// X for server side tensors, RX for received tensers, must be of same shape. // X for server side tensors, RX for received tensers, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block); AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
attrs.insert({"Fanin", 1});
attrs.insert({"ParamList", std::vector<std::string>({"Out"})}); attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", block}); attrs.insert({"OptimizeBlock", optimize_block});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
listen_and_serv_op->Run(scope, place); listen_and_serv_op->Run(scope, place);
} }
TEST(SendRecvOp, CPUDense) { TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false); std::thread server_thread(StartServerNet, false);
sleep(10); // wait server to start sleep(5); // wait server to start
// local net // local net
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
InitTensorsInScope(scope, place); InitTensorsInScope(scope, place);
// create rpc client var
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})}); attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, auto send_op = f::OpRegistry::CreateOp(
{{"Out", {"Out"}}}, attrs); "send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto in_var = scope.Var("x1"); auto in_var = scope.Var("x1");
...@@ -175,11 +179,13 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -175,11 +179,13 @@ TEST(SendRecvOp, CPUSparse) {
p::CPUPlace place; p::CPUPlace place;
p::CPUDeviceContext ctx(place); p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(scope, place); InitSelectedRowsInScope(scope, place);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})}); attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, auto send_op = f::OpRegistry::CreateOp(
{{"Out", {"Out"}}}, attrs); "send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>(); auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册