diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index defbd7625d0ca5fdc57f4d7e511a8f10a5452e2f..3041fbe5a844b565d48f013dd562359333703d5b 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -25,16 +25,18 @@ DEFINE_int32(dist_threadpool_size, 0, namespace paddle { namespace framework { - +std::mutex threadpool_mu; std::unique_ptr ThreadPool::threadpool_(nullptr); std::once_flag ThreadPool::init_flag_; ThreadPool* ThreadPool::GetInstance() { + std::lock_guard l(threadpool_mu); std::call_once(init_flag_, &ThreadPool::Init); return threadpool_.get(); } -void ThreadPool::Reset() { +void ThreadPool::TestReset() { + std::lock_guard l(threadpool_mu); threadpool_.reset(nullptr); ThreadPool::Init(); } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 745dcc7c7d4a20ad1c8b4023f360a0ccda076114..1513e35bb53d51e636e44937c181f26a36504e0a 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -56,7 +56,8 @@ class ThreadPool { static ThreadPool* GetInstance(); // delete current thread pool and create a new one. - static void Reset(); + // Only used by test cases to reset the threadpool. + static void TestReset(); ~ThreadPool(); diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 1d55e011c77d07dc6c637bcf5d3650ef5e696bb2..cad45d501a9c4a5e4577b63bd7b37889fe457a77 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -52,6 +52,6 @@ TEST(ThreadPool, ConcurrentRun) { for (auto& t : threads) { t.join(); } - framework::ThreadPool::Reset(); + framework::ThreadPool::TestReset(); EXPECT_EQ(sum, ((n + 1) * n) / 2); }