未验证 提交 88d79dfe 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #10292 from typhoonzero/fix_grpc_server_ready_condition

Fix grpc server ready condition
......@@ -211,6 +211,11 @@ void AsyncGRPCServer::WaitClientGet(int count) {
}
}
void AsyncGRPCServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
}
void AsyncGRPCServer::RunSyncUpdate() {
::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
......@@ -244,6 +249,12 @@ void AsyncGRPCServer::RunSyncUpdate() {
t_prefetch_.reset(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
// wait server
server_->Wait();
t_send_->join();
......
......@@ -45,8 +45,9 @@ class RequestBase;
class AsyncGRPCServer final {
public:
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode) {}
: address_(address), sync_mode_(sync_mode), ready_(0) {}
void WaitServerReady();
void RunSyncUpdate();
// functions to sync server barrier status.
......@@ -118,6 +119,10 @@ class AsyncGRPCServer final {
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
}; // namespace detail
......
......@@ -66,12 +66,7 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service->GetSelectedPort();
port_file.close();
}
std::atomic_int ListenAndServOp::selected_port_{0};
ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
......@@ -79,15 +74,27 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
int ListenAndServOp::GetSelectedPort() const {
return rpc_service_->GetSelectedPort();
}
void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
server_thread_->join();
}
void ListenAndServOp::SavePort(const std::string &file_path) const {
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_ = rpc_service_->GetSelectedPort();
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_.load();
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void ListenAndServOp::WaitServerReady() {
while (selected_port_.load() == 0) {
}
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
......@@ -318,9 +325,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
sleep(5);
rpc_service_->WaitServerReady();
// Write to a file of server selected port for python use.
SavePort(rpc_service_);
std::string file_path =
string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid()));
SavePort(file_path);
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <ostream>
#include <string>
......@@ -39,8 +40,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
int GetSelectedPort() const;
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
......@@ -49,14 +48,25 @@ class ListenAndServOp : public framework::OperatorBase {
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program) const;
void SavePort(
const std::string& file_path = "/tmp/paddle.selected_port") const;
void WaitServerReady();
int GetSelectedPort() { return selected_port_; }
void Stop() override;
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
static void ResetPort() { selected_port_ = 0; }
protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
mutable std::shared_ptr<std::thread> server_thread_;
// FIXME(wuyi): it's static so that the operator can be cloned.
static std::atomic_int selected_port_;
};
} // namespace operators
......
......@@ -116,6 +116,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
f::Scope scope;
p::CPUPlace place;
VLOG(4) << "before init tensor";
if (is_sparse) {
InitSelectedRowsInScope(place, &scope);
} else {
......@@ -137,6 +138,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true});
VLOG(4) << "before init op";
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
*initialized = true;
......@@ -149,7 +151,9 @@ TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) {
}
sleep(5); // wait server to start
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
......@@ -185,6 +189,7 @@ TEST(SendRecvOp, CPUDense) {
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset(nullptr);
paddle::operators::ListenAndServOp::ResetPort();
}
TEST(SendRecvOp, CPUSparse) {
......@@ -193,7 +198,12 @@ TEST(SendRecvOp, CPUSparse) {
std::thread server_thread(StartServerNet, true, &initialized);
while (!initialized) {
}
sleep(5); // wait server to start
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
listen_and_serv_op_ptr->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
......@@ -201,10 +211,6 @@ TEST(SendRecvOp, CPUSparse) {
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
......@@ -236,4 +242,5 @@ TEST(SendRecvOp, CPUSparse) {
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset();
paddle::operators::ListenAndServOp::ResetPort();
}
......@@ -34,7 +34,7 @@ class TestSendOp(unittest.TestCase):
p.start()
time.sleep(10)
with open("/tmp/paddle.selected_port", "r") as fn:
with open("/tmp/paddle.%d.selected_port" % p.pid, "r") as fn:
selected_port = int(fn.readlines()[0])
self.init_client(place, selected_port)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册