diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index a4ef70aab6647d4ab81fda187e656c05b87b53e8..adaa5dfd76b341fc677f0611b4d11924f54a266c 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. */ using ::grpc::ServerAsyncResponseWriter; DECLARE_bool(rpc_disable_reuse_port); +DECLARE_int32(rpc_retry_bind_port); namespace paddle { namespace operators { @@ -452,25 +454,42 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { }; void AsyncGRPCServer::StartServer() { - ::grpc::ServerBuilder builder; - builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(), - &selected_port_); - - builder.SetMaxSendMessageSize(std::numeric_limits::max()); - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - if (FLAGS_rpc_disable_reuse_port) { - builder.SetOption( - std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); - } - builder.RegisterService(&service_); + for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) { + ::grpc::ServerBuilder builder; + std::unique_ptr service( + new GrpcService::AsyncService()); + builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(), + &selected_port_); + + builder.SetMaxSendMessageSize(std::numeric_limits::max()); + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + if (FLAGS_rpc_disable_reuse_port) { + builder.SetOption( + std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); + } + builder.RegisterService(service.get()); + + for (auto t : rpc_call_map_) { + rpc_cq_[t.first].reset(builder.AddCompletionQueue().release()); + } + + server_ = builder.BuildAndStart(); + if (selected_port_ != 0) { + LOG(INFO) << "Server listening on " << bind_address_ + << " successful, selected port: " << selected_port_; + service_.reset(service.release()); + break; + } + + LOG(WARNING) << "Server listening on " << bind_address_ + << " failed, selected port: " << selected_port_ + << ", retry after 3 seconds!"; - for (auto t : rpc_call_map_) { - rpc_cq_[t.first].reset(builder.AddCompletionQueue().release()); + sleep(3); } - server_ = builder.BuildAndStart(); - LOG(INFO) << "Server listening on " << bind_address_ - << " selected port: " << selected_port_; + PADDLE_ENFORCE_NE(selected_port_, 0, "can't bind to address:%s", + bind_address_); std::function f = std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, @@ -547,24 +566,24 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, RequestBase* b = nullptr; if (rpc_name == kRequestSend) { - b = new RequestSend(&service_, cq.get(), handler, req_id); + b = new RequestSend(service_.get(), cq.get(), handler, req_id); } else if (rpc_name == kRequestGet) { - b = new RequestGet(&service_, cq.get(), handler, req_id); + b = new RequestGet(service_.get(), cq.get(), handler, req_id); } else if (rpc_name == kRequestGetNoBarrier) { - b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id); + b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id); } else if (rpc_name == kRequestGetMonomerVariable) { - b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id, + b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id, this); } else if (rpc_name == kRequestGetMonomerBarrier) { - b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id, + b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id, this); } else if (rpc_name == kRequestPrefetch) { - b = new RequestPrefetch(&service_, cq.get(), handler, req_id); + b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id); } else if (rpc_name == kRequestCheckpoint) { - b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); + b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id); } else if (rpc_name == kRequestNotify) { - b = new RequestNotify(&service_, cq.get(), handler, req_id); + b = new RequestNotify(service_.get(), cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); } diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.h b/paddle/fluid/operators/distributed/grpc/grpc_server.h index 2fd3a7a74073b52770158cf47b1c86cedae78291..ee6950205b31d9e2d3cd8722daf1c12117a17029 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include #include // NOLINT @@ -67,7 +68,7 @@ class AsyncGRPCServer final : public RPCServer { std::mutex cq_mutex_; volatile bool is_shut_down_ = false; - GrpcService::AsyncService service_; + std::unique_ptr service_; std::unique_ptr<::grpc::Server> server_; // condition of the sub program diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.cc b/paddle/fluid/operators/distributed/sendrecvop_utils.cc index 9bd2c9928ccdb6416976b76e776fb22b28ea1f5d..548277139eb856e2ebd2cac2ef33154e767aa570 100644 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.cc +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.cc @@ -24,6 +24,8 @@ limitations under the License. */ #include "paddle/fluid/platform/port.h" DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not."); +DEFINE_int32(rpc_retry_bind_port, 3, + "Retry to bind the address if address is already used."); namespace paddle { namespace operators { diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 527a7c45cb1343035bd95f041c1ebc9a30a7edf2..295dae7217299e8258a49aea05130755d615b4bc 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -192,6 +192,7 @@ def __bootstrap__(): read_env_flags.append('rpc_get_thread_num') read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_disable_reuse_port') + read_env_flags.append('rpc_retry_bind_port') read_env_flags.append('worker_update_interval_secs') diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 4080eebbb51555f96d0d1e444aa26f575733980e..c0febf88a4dcc61730556fdbe9f466370efb7562 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -846,6 +846,7 @@ class TestDistBase(unittest.TestCase): "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_rpc_deadline": "30000", # 5sec to fail fast + "FLAGS_rpc_retry_bind_port": "50", "FLAGS_cudnn_deterministic": "1", "http_proxy": "", "NCCL_P2P_DISABLE": "1", diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py index 07a0ae9a82eb05416f821baaaa4c4a84cc30f6e2..3e63282542a5da3a8388ab48e0dd9899e58f8ad3 100644 --- a/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py +++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py @@ -105,7 +105,7 @@ def gen_complete_file_flag(flag_file): class TestListenAndServOp(unittest.TestCase): def setUp(self): - self.ps_timeout = 5 + self.ps_timeout = 200 self.ip = "127.0.0.1" self.port = "0" self.trainers = 1