From c8f74b62153bf7f3de12b7fe4a7afba58ca86a0b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 13 Feb 2020 14:08:08 -0800 Subject: [PATCH] Fix a race condition in tensorflow::ProfilerServer. server_ is accessed concurrently from two threads without any locking. Amongst other things this could lead to a deadlock if a server was started and then immediately destroyed, because the if (server_) ... condition in ~ProfilerServer could run before server_ had been assigned in the worker thread. Rather than add locking, a simpler solution is just to remove the thread altogether. There is no need to use a thread here: we simply need to call Wait() in the destructor, because gRPC has its own threading internally. PiperOrigin-RevId: 294996411 Change-Id: Idd507d1ea97642e540fdec520900a3c491d34e5a --- .../core/profiler/rpc/profiler_server.cc | 28 ++++++++----------- .../core/profiler/rpc/profiler_server.h | 1 - 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc index 477a2490028..2d488916196 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.cc +++ b/tensorflow/core/profiler/rpc/profiler_server.cc @@ -29,21 +29,14 @@ limitations under the License. namespace tensorflow { void ProfilerServer::StartProfilerServer(int32 port) { - Env* env = Env::Default(); - auto start_server = [port, this]() { - string server_address = absl::StrCat("0.0.0.0:", port); - std::unique_ptr service = - CreateProfilerService(); - ::grpc::ServerBuilder builder; - builder.AddListeningPort(server_address, - ::grpc::InsecureServerCredentials()); - builder.RegisterService(service.get()); - server_ = builder.BuildAndStart(); - LOG(INFO) << "Profiling Server listening on " << server_address; - server_->Wait(); - }; - server_thread_ = - WrapUnique(env->StartThread({}, "ProfilerServer", start_server)); + string server_address = absl::StrCat("0.0.0.0:", port); + std::unique_ptr service = + CreateProfilerService(); + ::grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); + builder.RegisterService(service.get()); + server_ = builder.BuildAndStart(); + LOG(INFO) << "Profiling Server listening on " << server_address; } void ProfilerServer::MaybeStartProfilerServer() { @@ -67,7 +60,10 @@ void ProfilerServer::MaybeStartProfilerServer() { } ProfilerServer::~ProfilerServer() { - if (server_) server_->Shutdown(); + if (server_) { + server_->Shutdown(); + server_->Wait(); + } } } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/profiler_server.h b/tensorflow/core/profiler/rpc/profiler_server.h index 26e9606e2c5..81b3a5b7f3b 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.h +++ b/tensorflow/core/profiler/rpc/profiler_server.h @@ -35,7 +35,6 @@ class ProfilerServer { private: std::unique_ptr<::grpc::Server> server_; - std::unique_ptr server_thread_; }; } // namespace tensorflow -- GitLab