未验证 提交 f3f52fc1 编写于 作者: G gongweibao 提交者: GitHub

Retry when failed to bind address. (#20642)

上级 3e831b60
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <limits>
#include <memory>
#include <string>
......@@ -22,6 +23,7 @@ limitations under the License. */
using ::grpc::ServerAsyncResponseWriter;
DECLARE_bool(rpc_disable_reuse_port);
DECLARE_int32(rpc_retry_bind_port);
namespace paddle {
namespace operators {
......@@ -452,25 +454,42 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
};
void AsyncGRPCServer::StartServer() {
::grpc::ServerBuilder builder;
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
&selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
if (FLAGS_rpc_disable_reuse_port) {
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
}
builder.RegisterService(&service_);
for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) {
::grpc::ServerBuilder builder;
std::unique_ptr<GrpcService::AsyncService> service(
new GrpcService::AsyncService());
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
&selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
if (FLAGS_rpc_disable_reuse_port) {
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
}
builder.RegisterService(service.get());
for (auto t : rpc_call_map_) {
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
}
server_ = builder.BuildAndStart();
if (selected_port_ != 0) {
LOG(INFO) << "Server listening on " << bind_address_
<< " successful, selected port: " << selected_port_;
service_.reset(service.release());
break;
}
LOG(WARNING) << "Server listening on " << bind_address_
<< " failed, selected port: " << selected_port_
<< ", retry after 3 seconds!";
for (auto t : rpc_call_map_) {
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
sleep(3);
}
server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << bind_address_
<< " selected port: " << selected_port_;
PADDLE_ENFORCE_NE(selected_port_, 0, "can't bind to address:%s",
bind_address_);
std::function<void(const std::string&, int)> f =
std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
......@@ -547,24 +566,24 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
RequestBase* b = nullptr;
if (rpc_name == kRequestSend) {
b = new RequestSend(&service_, cq.get(), handler, req_id);
b = new RequestSend(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id);
b = new RequestGet(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id,
this);
} else if (rpc_name == kRequestGetMonomerBarrier) {
b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id,
b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id,
this);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) {
b = new RequestNotify(&service_, cq.get(), handler, req_id);
b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread> // NOLINT
......@@ -67,7 +68,7 @@ class AsyncGRPCServer final : public RPCServer {
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
GrpcService::AsyncService service_;
std::unique_ptr<GrpcService::AsyncService> service_;
std::unique_ptr<::grpc::Server> server_;
// condition of the sub program
......
......@@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not.");
DEFINE_int32(rpc_retry_bind_port, 3,
"Retry to bind the address if address is already used.");
namespace paddle {
namespace operators {
......
......@@ -192,6 +192,7 @@ def __bootstrap__():
read_env_flags.append('rpc_get_thread_num')
read_env_flags.append('rpc_prefetch_thread_num')
read_env_flags.append('rpc_disable_reuse_port')
read_env_flags.append('rpc_retry_bind_port')
read_env_flags.append('worker_update_interval_secs')
......
......@@ -846,6 +846,7 @@ class TestDistBase(unittest.TestCase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_rpc_deadline": "30000", # 5sec to fail fast
"FLAGS_rpc_retry_bind_port": "50",
"FLAGS_cudnn_deterministic": "1",
"http_proxy": "",
"NCCL_P2P_DISABLE": "1",
......
......@@ -105,7 +105,7 @@ def gen_complete_file_flag(flag_file):
class TestListenAndServOp(unittest.TestCase):
def setUp(self):
self.ps_timeout = 5
self.ps_timeout = 200
self.ip = "127.0.0.1"
self.port = "0"
self.trainers = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册