提交 d111bfc6 编写于 作者: L liyin

Fix tile count calculation

上级 aef261d1
...@@ -54,6 +54,10 @@ class CountDownLatch { ...@@ -54,6 +54,10 @@ class CountDownLatch {
count_.store(count, std::memory_order_release); count_.store(count, std::memory_order_release);
} }
int count() const {
return count_;
}
private: private:
int64_t spin_timeout_; int64_t spin_timeout_;
std::atomic<int> count_; std::atomic<int> count_;
......
...@@ -39,6 +39,7 @@ TEST_F(CountDownLatchTest, TestWait) { ...@@ -39,6 +39,7 @@ TEST_F(CountDownLatchTest, TestWait) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
threads[i].join(); threads[i].join();
} }
MACE_CHECK(latch.count() == 0);
} }
TEST_F(CountDownLatchTest, TestSpinWait) { TEST_F(CountDownLatchTest, TestSpinWait) {
...@@ -53,6 +54,7 @@ TEST_F(CountDownLatchTest, TestSpinWait) { ...@@ -53,6 +54,7 @@ TEST_F(CountDownLatchTest, TestSpinWait) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
threads[i].join(); threads[i].join();
} }
MACE_CHECK(latch.count() == 0);
} }
} // namespace } // namespace
......
...@@ -134,7 +134,7 @@ ThreadPool::ThreadPool(const size_t thread_count_hint, ...@@ -134,7 +134,7 @@ ThreadPool::ThreadPool(const size_t thread_count_hint,
default_tile_count_ = thread_count; default_tile_count_ = thread_count;
if (cores_to_use.size() >= 2 if (cores_to_use.size() >= 2
&& cpu_max_freqs[0] != cpu_max_freqs[cores_to_use.back()]) { && cpu_max_freqs[cores_to_use[0]] != cpu_max_freqs[cores_to_use.back()]) {
default_tile_count_ = thread_count * kTileCountPerThread; default_tile_count_ = thread_count * kTileCountPerThread;
} }
MACE_CHECK(default_tile_count_ > 0, "default tile count should > 0"); MACE_CHECK(default_tile_count_ > 0, "default tile count should > 0");
...@@ -213,8 +213,8 @@ void ThreadPool::Destroy() { ...@@ -213,8 +213,8 @@ void ThreadPool::Destroy() {
if (threads_[i].joinable()) { if (threads_[i].joinable()) {
threads_[i].join(); threads_[i].join();
} else { } else {
LOG(ERROR) << "Thread: " << threads_[i].get_id() << " not joinable" VLOG(2) << "Thread: " << threads_[i].get_id() << " not joinable"
<< std::endl; << std::endl;
} }
} }
} }
...@@ -318,7 +318,7 @@ void ThreadPool::Compute1D(const std::function<void(size_t, ...@@ -318,7 +318,7 @@ void ThreadPool::Compute1D(const std::function<void(size_t,
} }
size_t step_tile_size = step * tile_size; size_t step_tile_size = step * tile_size;
size_t tile_count = RoundUp(items, tile_size); size_t tile_count = RoundUpDiv(items, tile_size);
Run([&](size_t tile_idx) { Run([&](size_t tile_idx) {
size_t tile_start = start + tile_idx * step_tile_size; size_t tile_start = start + tile_idx * step_tile_size;
size_t tile_end = std::min(end, tile_start + step_tile_size); size_t tile_end = std::min(end, tile_start + step_tile_size);
...@@ -366,8 +366,8 @@ void ThreadPool::Compute2D(const std::function<void(size_t /* start */, ...@@ -366,8 +366,8 @@ void ThreadPool::Compute2D(const std::function<void(size_t /* start */,
size_t step_tile_size0 = step0 * tile_size0; size_t step_tile_size0 = step0 * tile_size0;
size_t step_tile_size1 = step1 * tile_size1; size_t step_tile_size1 = step1 * tile_size1;
size_t tile_count0 = RoundUp(items0, tile_size0); size_t tile_count0 = RoundUpDiv(items0, tile_size0);
size_t tile_count1 = RoundUp(items1, tile_size1); size_t tile_count1 = RoundUpDiv(items1, tile_size1);
Run([&](size_t tile_idx) { Run([&](size_t tile_idx) {
size_t tile_idx0 = tile_idx / tile_count1; size_t tile_idx0 = tile_idx / tile_count1;
...@@ -438,9 +438,9 @@ void ThreadPool::Compute3D(const std::function<void(size_t /* start */, ...@@ -438,9 +438,9 @@ void ThreadPool::Compute3D(const std::function<void(size_t /* start */,
size_t step_tile_size0 = step0 * tile_size0; size_t step_tile_size0 = step0 * tile_size0;
size_t step_tile_size1 = step1 * tile_size1; size_t step_tile_size1 = step1 * tile_size1;
size_t step_tile_size2 = step2 * tile_size2; size_t step_tile_size2 = step2 * tile_size2;
size_t tile_count0 = RoundUp(items0, tile_size0); size_t tile_count0 = RoundUpDiv(items0, tile_size0);
size_t tile_count1 = RoundUp(items1, tile_size1); size_t tile_count1 = RoundUpDiv(items1, tile_size1);
size_t tile_count2 = RoundUp(items2, tile_size2); size_t tile_count2 = RoundUpDiv(items2, tile_size2);
size_t tile_count12 = tile_count1 * tile_count2; size_t tile_count12 = tile_count1 * tile_count2;
Run([&](size_t tile_idx) { Run([&](size_t tile_idx) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册