提交 a1a401eb 编写于 作者: F fengjiayi

fix

上级 d11b8e56
...@@ -113,7 +113,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, ...@@ -113,7 +113,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
op->SetAttrMap(attrs); op->SetAttrMap(attrs);
} }
void StartServerNet(bool is_sparse) { void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
if (is_sparse) { if (is_sparse) {
...@@ -121,7 +121,6 @@ void StartServerNet(bool is_sparse) { ...@@ -121,7 +121,6 @@ void StartServerNet(bool is_sparse) {
} else { } else {
InitTensorsInScope(place, &scope); InitTensorsInScope(place, &scope);
} }
// sub program run in listen_and_serv_op, for simple test we use sum // sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program; f::ProgramDesc program;
const auto &root_block = program.Block(0); const auto &root_block = program.Block(0);
...@@ -129,7 +128,6 @@ void StartServerNet(bool is_sparse) { ...@@ -129,7 +128,6 @@ void StartServerNet(bool is_sparse) {
auto *prefetch_block = program.AppendBlock(root_block); auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape. // X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:0")}); attrs.insert({"endpoint", std::string("127.0.0.1:0")});
attrs.insert({"Fanin", 1}); attrs.insert({"Fanin", 1});
...@@ -141,12 +139,16 @@ void StartServerNet(bool is_sparse) { ...@@ -141,12 +139,16 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"sync_mode", true}); attrs.insert({"sync_mode", true});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
*initialized = true;
listen_and_serv_op->Run(scope, place); listen_and_serv_op->Run(scope, place);
LOG(INFO) << "server exit"; LOG(INFO) << "server exit";
} }
TEST(SendRecvOp, CPUDense) { TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false); std::atomic<bool> initialized{false};
std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) {
}
sleep(5); // wait server to start sleep(5); // wait server to start
// local net // local net
f::Scope scope; f::Scope scope;
...@@ -156,9 +158,11 @@ TEST(SendRecvOp, CPUDense) { ...@@ -156,9 +158,11 @@ TEST(SendRecvOp, CPUDense) {
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>( auto *listen_and_serv_op_ptr =
listen_and_serv_op.get()) static_cast<paddle::operators::ListenAndServOp *>(
->GetSelectedPort(); listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
...@@ -184,8 +188,12 @@ TEST(SendRecvOp, CPUDense) { ...@@ -184,8 +188,12 @@ TEST(SendRecvOp, CPUDense) {
} }
TEST(SendRecvOp, CPUSparse) { TEST(SendRecvOp, CPUSparse) {
std::thread server_thread(StartServerNet, true); std::atomic<bool> initialized;
sleep(3); // wait server to start initialized = false;
std::thread server_thread(StartServerNet, true, &initialized);
while (!initialized) {
}
sleep(5); // wait server to start
// local net // local net
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
...@@ -193,9 +201,11 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -193,9 +201,11 @@ TEST(SendRecvOp, CPUSparse) {
InitSelectedRowsInScope(place, &scope); InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>( auto *listen_and_serv_op_ptr =
listen_and_serv_op.get()) static_cast<paddle::operators::ListenAndServOp *>(
->GetSelectedPort(); listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册