未验证 提交 f3f52fc1 编写于 作者: G gongweibao 提交者: GitHub

Retry when failed to bind address. (#20642)

上级 3e831b60
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <unistd.h>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -22,6 +23,7 @@ limitations under the License. */ ...@@ -22,6 +23,7 @@ limitations under the License. */
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
DECLARE_bool(rpc_disable_reuse_port); DECLARE_bool(rpc_disable_reuse_port);
DECLARE_int32(rpc_retry_bind_port);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -452,25 +454,42 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { ...@@ -452,25 +454,42 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
}; };
void AsyncGRPCServer::StartServer() { void AsyncGRPCServer::StartServer() {
::grpc::ServerBuilder builder; for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) {
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(), ::grpc::ServerBuilder builder;
&selected_port_); std::unique_ptr<GrpcService::AsyncService> service(
new GrpcService::AsyncService());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); &selected_port_);
if (FLAGS_rpc_disable_reuse_port) {
builder.SetOption( builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
} if (FLAGS_rpc_disable_reuse_port) {
builder.RegisterService(&service_); builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
}
builder.RegisterService(service.get());
for (auto t : rpc_call_map_) {
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
}
server_ = builder.BuildAndStart();
if (selected_port_ != 0) {
LOG(INFO) << "Server listening on " << bind_address_
<< " successful, selected port: " << selected_port_;
service_.reset(service.release());
break;
}
LOG(WARNING) << "Server listening on " << bind_address_
<< " failed, selected port: " << selected_port_
<< ", retry after 3 seconds!";
for (auto t : rpc_call_map_) { sleep(3);
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
} }
server_ = builder.BuildAndStart(); PADDLE_ENFORCE_NE(selected_port_, 0, "can't bind to address:%s",
LOG(INFO) << "Server listening on " << bind_address_ bind_address_);
<< " selected port: " << selected_port_;
std::function<void(const std::string&, int)> f = std::function<void(const std::string&, int)> f =
std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
...@@ -547,24 +566,24 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -547,24 +566,24 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
RequestBase* b = nullptr; RequestBase* b = nullptr;
if (rpc_name == kRequestSend) { if (rpc_name == kRequestSend) {
b = new RequestSend(&service_, cq.get(), handler, req_id); b = new RequestSend(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) { } else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id); b = new RequestGet(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) { } else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id); b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) { } else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id, b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id,
this); this);
} else if (rpc_name == kRequestGetMonomerBarrier) { } else if (rpc_name == kRequestGetMonomerBarrier) {
b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id, b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id,
this); this);
} else if (rpc_name == kRequestPrefetch) { } else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) { } else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) { } else if (rpc_name == kRequestNotify) {
b = new RequestNotify(&service_, cq.get(), handler, req_id); b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
} else { } else {
PADDLE_ENFORCE(false, "not supported rpc"); PADDLE_ENFORCE(false, "not supported rpc");
} }
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <map> #include <map>
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
...@@ -67,7 +68,7 @@ class AsyncGRPCServer final : public RPCServer { ...@@ -67,7 +68,7 @@ class AsyncGRPCServer final : public RPCServer {
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
GrpcService::AsyncService service_; std::unique_ptr<GrpcService::AsyncService> service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
// condition of the sub program // condition of the sub program
......
...@@ -24,6 +24,8 @@ limitations under the License. */ ...@@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not."); DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not.");
DEFINE_int32(rpc_retry_bind_port, 3,
"Retry to bind the address if address is already used.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -192,6 +192,7 @@ def __bootstrap__(): ...@@ -192,6 +192,7 @@ def __bootstrap__():
read_env_flags.append('rpc_get_thread_num') read_env_flags.append('rpc_get_thread_num')
read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_prefetch_thread_num')
read_env_flags.append('rpc_disable_reuse_port') read_env_flags.append('rpc_disable_reuse_port')
read_env_flags.append('rpc_retry_bind_port')
read_env_flags.append('worker_update_interval_secs') read_env_flags.append('worker_update_interval_secs')
......
...@@ -846,6 +846,7 @@ class TestDistBase(unittest.TestCase): ...@@ -846,6 +846,7 @@ class TestDistBase(unittest.TestCase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_rpc_deadline": "30000", # 30sec (flag is in ms) to fail fast "FLAGS_rpc_deadline": "30000", # 30sec (flag is in ms) to fail fast
"FLAGS_rpc_retry_bind_port": "50",
"FLAGS_cudnn_deterministic": "1", "FLAGS_cudnn_deterministic": "1",
"http_proxy": "", "http_proxy": "",
"NCCL_P2P_DISABLE": "1", "NCCL_P2P_DISABLE": "1",
......
...@@ -105,7 +105,7 @@ def gen_complete_file_flag(flag_file): ...@@ -105,7 +105,7 @@ def gen_complete_file_flag(flag_file):
class TestListenAndServOp(unittest.TestCase): class TestListenAndServOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.ps_timeout = 5 self.ps_timeout = 200
self.ip = "127.0.0.1" self.ip = "127.0.0.1"
self.port = "0" self.port = "0"
self.trainers = 1 self.trainers = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册