提交 a1a401eb 编写于 作者: F fengjiayi

fix

上级 d11b8e56
......@@ -113,7 +113,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
op->SetAttrMap(attrs);
}
void StartServerNet(bool is_sparse) {
void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
f::Scope scope;
p::CPUPlace place;
if (is_sparse) {
......@@ -121,7 +121,6 @@ void StartServerNet(bool is_sparse) {
} else {
InitTensorsInScope(place, &scope);
}
// sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program;
const auto &root_block = program.Block(0);
......@@ -129,7 +128,6 @@ void StartServerNet(bool is_sparse) {
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:0")});
attrs.insert({"Fanin", 1});
......@@ -141,12 +139,16 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"sync_mode", true});
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
*initialized = true;
listen_and_serv_op->Run(scope, place);
LOG(INFO) << "server exit";
}
TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false);
std::atomic<bool> initialized{false};
std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) {
}
sleep(5); // wait server to start
// local net
f::Scope scope;
......@@ -156,9 +158,11 @@ TEST(SendRecvOp, CPUDense) {
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get())
->GetSelectedPort();
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
......@@ -184,8 +188,12 @@ TEST(SendRecvOp, CPUDense) {
}
TEST(SendRecvOp, CPUSparse) {
std::thread server_thread(StartServerNet, true);
sleep(3); // wait server to start
std::atomic<bool> initialized;
initialized = false;
std::thread server_thread(StartServerNet, true, &initialized);
while (!initialized) {
}
sleep(5); // wait server to start
// local net
f::Scope scope;
p::CPUPlace place;
......@@ -193,9 +201,11 @@ TEST(SendRecvOp, CPUSparse) {
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get())
->GetSelectedPort();
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册