提交 7a1d6ae5 编写于 作者: Y Yancey1989

Fix send_recv unit test

上级 fe7c1814
......@@ -129,6 +129,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
if (exit_flag) {
rpc_service_->ShutDown();
rpc_service_->SetCond(1);
break;
}
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
......
......@@ -95,7 +95,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(f::proto::DataType::FP32);
var->SetDataType(f::proto::VarType::FP32);
}
}
......@@ -122,33 +122,37 @@ void StartServerNet(bool is_sparse) {
// sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program;
f::BlockDesc *block = program.MutableBlock(0);
f::BlockDesc *optimize_block = program.MutableBlock(0);
// X for server side tensors, RX for received tensers, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block);
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
attrs.insert({"Fanin", 1});
attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", block});
attrs.insert({"OptimizeBlock", optimize_block});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {}, {}, attrs);
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
listen_and_serv_op->Run(scope, place);
}
TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false);
sleep(10); // wait server to start
sleep(5); // wait server to start
// local net
f::Scope scope;
p::CPUPlace place;
InitTensorsInScope(scope, place);
// create rpc client var
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
{{"Out", {"Out"}}}, attrs);
auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place);
auto in_var = scope.Var("x1");
......@@ -175,11 +179,13 @@ TEST(SendRecvOp, CPUSparse) {
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(scope, place);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})});
attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})});
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
{{"Out", {"Out"}}}, attrs);
auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册