From d111bfc610cbccc79a44d43f2aca0d12dd1a7025 Mon Sep 17 00:00:00 2001 From: liyin Date: Tue, 2 Apr 2019 19:36:00 +0800 Subject: [PATCH] Fix tile count calculation --- mace/utils/count_down_latch.h | 4 ++++ mace/utils/count_down_latch_test.cc | 2 ++ mace/utils/thread_pool.cc | 18 +++++++++--------- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/mace/utils/count_down_latch.h b/mace/utils/count_down_latch.h index 08d60952..a1d79d4f 100644 --- a/mace/utils/count_down_latch.h +++ b/mace/utils/count_down_latch.h @@ -54,6 +54,10 @@ class CountDownLatch { count_.store(count, std::memory_order_release); } + int count() const { + return count_; + } + private: int64_t spin_timeout_; std::atomic count_; diff --git a/mace/utils/count_down_latch_test.cc b/mace/utils/count_down_latch_test.cc index edeed495..9db4baa9 100644 --- a/mace/utils/count_down_latch_test.cc +++ b/mace/utils/count_down_latch_test.cc @@ -39,6 +39,7 @@ TEST_F(CountDownLatchTest, TestWait) { for (int i = 0; i < 10; ++i) { threads[i].join(); } + MACE_CHECK(latch.count() == 0); } TEST_F(CountDownLatchTest, TestSpinWait) { @@ -53,6 +54,7 @@ TEST_F(CountDownLatchTest, TestSpinWait) { for (int i = 0; i < 10; ++i) { threads[i].join(); } + MACE_CHECK(latch.count() == 0); } } // namespace diff --git a/mace/utils/thread_pool.cc b/mace/utils/thread_pool.cc index aaf89e6b..92c128a7 100644 --- a/mace/utils/thread_pool.cc +++ b/mace/utils/thread_pool.cc @@ -134,7 +134,7 @@ ThreadPool::ThreadPool(const size_t thread_count_hint, default_tile_count_ = thread_count; if (cores_to_use.size() >= 2 - && cpu_max_freqs[0] != cpu_max_freqs[cores_to_use.back()]) { + && cpu_max_freqs[cores_to_use[0]] != cpu_max_freqs[cores_to_use.back()]) { default_tile_count_ = thread_count * kTileCountPerThread; } MACE_CHECK(default_tile_count_ > 0, "default tile count should > 0"); @@ -213,8 +213,8 @@ void ThreadPool::Destroy() { if (threads_[i].joinable()) { threads_[i].join(); } else { - LOG(ERROR) << "Thread: " << threads_[i].get_id() << " not joinable" - << std::endl; + VLOG(2) << "Thread: " << threads_[i].get_id() << " not joinable" + << std::endl; } } } @@ -318,7 +318,7 @@ void ThreadPool::Compute1D(const std::function