diff --git a/mace/core/runtime/cpu/cpu_runtime.cc b/mace/core/runtime/cpu/cpu_runtime.cc
index e5742ec11aa483c1fc5f01738bad311ef1df0050..ae3689c2f35cceb13d68e7a91b415dfadeb9fc37 100644
--- a/mace/core/runtime/cpu/cpu_runtime.cc
+++ b/mace/core/runtime/cpu/cpu_runtime.cc
@@ -52,7 +52,7 @@ MaceStatus SetOpenMPThreadsAndAffinityCPUs(int omp_num_threads,
                                            const std::vector<size_t> &cpu_ids,
                                            SchedulePolicy schedule_policy) {
   MaceOpenMPThreadCount = omp_num_threads;
-
+  SchedSetAffinity(cpu_ids);
 #ifdef MACE_ENABLE_OPENMP
   VLOG(1) << "Set OpenMP threads number: " << omp_num_threads
           << ", CPU core IDs: " << MakeString(cpu_ids);
diff --git a/mace/core/runtime/cpu/cpu_runtime.h b/mace/core/runtime/cpu/cpu_runtime.h
index ab067ebaae698e2296dcee5469c93961f654b628..08584dd91865b33c23b8cdf42e696b43390b14b9 100644
--- a/mace/core/runtime/cpu/cpu_runtime.h
+++ b/mace/core/runtime/cpu/cpu_runtime.h
@@ -22,6 +22,7 @@
 #include "public/gemmlowp.h"
 #endif  // MACE_ENABLE_QUANTIZE
 
+#include "mace/utils/thread_pool.h"
 #include "mace/utils/macros.h"
 #include "mace/public/mace.h"
 #include "mace/utils/logging.h"
@@ -37,7 +38,8 @@ class CPURuntime {
              bool use_gemmlowp)
       : num_threads_(num_threads),
         policy_(policy),
-        gemm_context_(nullptr) {
+        gemm_context_(nullptr),
+        thread_pool_(static_cast<size_t>(num_threads), policy) {
 #ifdef MACE_ENABLE_QUANTIZE
     if (use_gemmlowp) {
       MACE_CHECK_NOTNULL(GetGemmlowpContext());
@@ -48,6 +50,9 @@ class CPURuntime {
     SetOpenMPThreadsAndAffinityPolicy(num_threads_,
                                       policy_,
                                       gemm_context_);
+    // TODO(liyin): After we replace OpenMP with thread_pool, uncomment the
+    // following line.
+    // thread_pool_.Init();
   }
 
 #ifdef MACE_ENABLE_QUANTIZE
@@ -79,6 +84,10 @@ class CPURuntime {
     return gemm_context_ != nullptr;
   }
 
+  utils::ThreadPool &thread_pool() {
+    return thread_pool_;
+  }
+
  private:
   MaceStatus SetOpenMPThreadsAndAffinityPolicy(
       int omp_num_threads_hint,
@@ -88,6 +97,7 @@ class CPURuntime {
   int num_threads_;
   CPUAffinityPolicy policy_;
   void *gemm_context_;
+  utils::ThreadPool thread_pool_;
 };
 
 }  // namespace mace
diff --git a/mace/port/BUILD.bazel b/mace/port/BUILD.bazel
index d23633a6a6290f109c1061191bcbf48d81aa2fa9..722d3e700098eef7e2a3db6a97ddd43540244fe3 100644
--- a/mace/port/BUILD.bazel
+++ b/mace/port/BUILD.bazel
@@ -19,6 +19,8 @@ cc_library(
         "env.h",
         "file_system.h",
         "logger.h",
+        "port.h",
+        "port-arch.h",
     ],
     deps = [
         "//mace/public",
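Note: `CPURuntime` now owns a `utils::ThreadPool` and exposes it through `thread_pool()`, but `Init()` stays commented out until the OpenMP path is actually replaced. For context, this is roughly how an op would drive the pool once that happens; a minimal sketch, where `runtime` (a `CPURuntime *`) and `ComputeOneRow()` are illustrative stand-ins, not symbols from this diff:

```cpp
// Hypothetical op body: tile the output-row loop over the pool instead of
// "#pragma omp parallel for". Compute1D is declared in mace/utils/thread_pool.h.
utils::ThreadPool &pool = runtime->thread_pool();
pool.Compute1D([=](size_t start, size_t end, size_t step) {
  for (size_t row = start; row < end; row += step) {
    ComputeOneRow(row);  // per-row work of the op (hypothetical helper)
  }
}, 0, output_rows, 1);
```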
diff --git a/mace/port/port-arch.h b/mace/port/port-arch.h
new file mode 100644
index 0000000000000000000000000000000000000000..113c55160048bcd4173986d00575b59f40616ef8
--- /dev/null
+++ b/mace/port/port-arch.h
@@ -0,0 +1,30 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_PORT_PORT_ARCH_H_
+#define MACE_PORT_PORT_ARCH_H_
+
+#if defined __APPLE__
+# define MACE_OS_MAC 1
+# if TARGET_OS_IPHONE
+#  define MACE_OS_IOS 1
+# endif
+#elif defined __linux__
+# define MACE_OS_LINUX 1
+# if defined(__ANDROID__) || defined(ANDROID)
+#  define MACE_OS_LINUX_ANDROID 1
+# endif
+#endif
+
+#endif  // MACE_PORT_PORT_ARCH_H_
diff --git a/mace/port/port.h b/mace/port/port.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e215c1b5de9f206b43a52b4bbd7433d10f1a3c7
--- /dev/null
+++ b/mace/port/port.h
@@ -0,0 +1,26 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_PORT_PORT_H_
+#define MACE_PORT_PORT_H_
+
+#include "mace/port/port-arch.h"
+#include "mace/public/mace.h"
+#include "mace/utils/logging.h"
+
+#if defined(MACE_OS_LINUX_ANDROID)
+#define MACE_THREAD_POOL_USE_SPIN 1
+#endif  // MACE_OS_LINUX_ANDROID
+
+#endif  // MACE_PORT_PORT_H_
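Note: `MACE_THREAD_POOL_USE_SPIN` is only defined for Android builds, yet the pool below passes `kThreadPoolSpinWaitTime` unconditionally, so the macro appears to be groundwork for a follow-up that disables spinning on other platforms. A sketch of that presumed use (an assumption, not code from this diff):

```cpp
#include "mace/port/port.h"

// Hypothetical helper: spin only where port.h opts in.
inline int64_t SpinTimeoutNs() {
#ifdef MACE_THREAD_POOL_USE_SPIN
  return 2000000;  // busy-wait up to 2 ms before blocking on the condvar
#else
  return 0;        // CountDownLatch::Wait skips the spin phase when <= 0
#endif
}
```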
diff --git a/mace/utils/count_down_latch.h b/mace/utils/count_down_latch.h
new file mode 100644
index 0000000000000000000000000000000000000000..08d6095248275bd264ecec4fe7380c4ec818e324
--- /dev/null
+++ b/mace/utils/count_down_latch.h
@@ -0,0 +1,68 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_UTILS_COUNT_DOWN_LATCH_H_
+#define MACE_UTILS_COUNT_DOWN_LATCH_H_
+
+#include <atomic>  // NOLINT(build/c++11)
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <mutex>  // NOLINT(build/c++11)
+
+#include "mace/utils/spinlock.h"
+
+namespace mace {
+namespace utils {
+
+class CountDownLatch {
+ public:
+  explicit CountDownLatch(int64_t spin_timeout)
+      : spin_timeout_(spin_timeout), count_(0) {}
+  CountDownLatch(int64_t spin_timeout, int count)
+      : spin_timeout_(spin_timeout), count_(count) {}
+
+  void Wait() {
+    if (spin_timeout_ > 0) {
+      SpinWaitUntil(count_, 0, spin_timeout_);
+    }
+    if (count_.load(std::memory_order_acquire) != 0) {
+      std::unique_lock<std::mutex> m(mutex_);
+      while (count_.load(std::memory_order_acquire) != 0) {
+        cond_.wait(m);
+      }
+    }
+  }
+
+  void CountDown() {
+    if (count_.fetch_sub(1, std::memory_order_release) == 1) {
+      std::unique_lock<std::mutex> m(mutex_);
+      cond_.notify_all();
+    }
+  }
+
+  void Reset(int count) {
+    count_.store(count, std::memory_order_release);
+  }
+
+ private:
+  int64_t spin_timeout_;
+  std::atomic<int> count_;
+  std::mutex mutex_;
+  std::condition_variable cond_;
+};
+
+}  // namespace utils
+}  // namespace mace
+
+#endif  // MACE_UTILS_COUNT_DOWN_LATCH_H_
+
diff --git a/mace/utils/count_down_latch_test.cc b/mace/utils/count_down_latch_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..edeed4955fea87ec90e46725c2f7553c5cbabb02
--- /dev/null
+++ b/mace/utils/count_down_latch_test.cc
@@ -0,0 +1,61 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <thread>
+#include <vector>
+
+#include "mace/utils/count_down_latch.h"
+
+namespace mace {
+namespace utils {
+
+namespace {
+
+class CountDownLatchTest : public ::testing::Test {
+};
+
+TEST_F(CountDownLatchTest, TestWait) {
+  CountDownLatch latch(0, 10);
+  std::vector<std::thread> threads(10);
+  for (int i = 0; i < 10; ++i) {
+    threads[i] = std::thread([&latch]() {
+      latch.CountDown();
+    });
+  }
+  latch.Wait();
+
+  for (int i = 0; i < 10; ++i) {
+    threads[i].join();
+  }
+}
+
+TEST_F(CountDownLatchTest, TestSpinWait) {
+  CountDownLatch latch(100, 10);
+  std::vector<std::thread> threads(10);
+  for (int i = 0; i < 10; ++i) {
+    threads[i] = std::thread([&latch]() {
+      latch.CountDown();
+    });
+  }
+  latch.Wait();
+
+  for (int i = 0; i < 10; ++i) {
+    threads[i].join();
+  }
+}
+
+}  // namespace
+
+}  // namespace utils
+}  // namespace mace
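For reviewers unfamiliar with the latch: `Wait()` spins on `count_` for at most `spin_timeout_` nanoseconds, then falls back to blocking on the condition variable, and `CountDown()` only takes the mutex when the count reaches zero. A minimal, self-contained usage sketch of exactly the API above:

```cpp
#include <thread>
#include <vector>

#include "mace/utils/count_down_latch.h"

// Fan out n workers, then block until every one has checked in.
void RunWorkers(size_t n) {
  mace::utils::CountDownLatch latch(2000000, static_cast<int>(n));  // 2 ms spin
  std::vector<std::thread> workers;
  for (size_t i = 0; i < n; ++i) {
    workers.emplace_back([&latch] {
      // ... this worker's slice of work ...
      latch.CountDown();  // the last caller wakes the waiter
    });
  }
  latch.Wait();  // spins briefly, then sleeps on the condvar if needed
  for (auto &w : workers) w.join();
}
```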
diff --git a/mace/utils/spinlock.h b/mace/utils/spinlock.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ef150febab49d2baf5f9e59c04fa52b9a46651f
--- /dev/null
+++ b/mace/utils/spinlock.h
@@ -0,0 +1,65 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_UTILS_SPINLOCK_H_
+#define MACE_UTILS_SPINLOCK_H_
+
+#include <atomic>  // NOLINT(build/c++11)
+#include <chrono>  // NOLINT(build/c++11)
+#include <thread>  // NOLINT(build/c++11)
+
+#include "mace/port/port.h"
+#include "mace/port/env.h"
+#include "mace/utils/logging.h"
+
+namespace mace {
+namespace utils {
+
+inline void SpinWait(const std::atomic<int> &variable,
+                     const int value,
+                     const int64_t spin_wait_max_time = -1) {
+  auto start_time = std::chrono::high_resolution_clock::now();
+  for (size_t k = 1; variable.load(std::memory_order_acquire) == value; ++k) {
+    if (spin_wait_max_time > 0 && k % 1000 == 0) {
+      auto end_time = std::chrono::high_resolution_clock::now();
+      int64_t elapse =
+          std::chrono::duration_cast<std::chrono::nanoseconds>(
+              end_time - start_time).count();
+      if (elapse > spin_wait_max_time) {
+        break;
+      }
+    }
+  }
+}
+
+inline void SpinWaitUntil(const std::atomic<int> &variable,
+                          const int value,
+                          const int64_t spin_wait_max_time = -1) {
+  auto start_time = std::chrono::high_resolution_clock::now();
+  for (size_t k = 1; variable.load(std::memory_order_acquire) != value; ++k) {
+    if (spin_wait_max_time > 0 && k % 1000 == 0) {
+      auto end_time = std::chrono::high_resolution_clock::now();
+      int64_t elapse =
+          std::chrono::duration_cast<std::chrono::nanoseconds>(
+              end_time - start_time).count();
+      if (elapse > spin_wait_max_time) {
+        break;
+      }
+    }
+  }
+}
+
+}  // namespace utils
+}  // namespace mace
+
+#endif  // MACE_UTILS_SPINLOCK_H_
diff --git a/mace/utils/spinlock_test.cc b/mace/utils/spinlock_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..40f49f4afeb022cc9e69139e3dcb62169af7e241
--- /dev/null
+++ b/mace/utils/spinlock_test.cc
@@ -0,0 +1,44 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <limits>
+#include <thread>
+
+#include "mace/utils/spinlock.h"
+
+namespace mace {
+namespace utils {
+
+namespace {
+
+class SpinTest : public ::testing::Test {
+};
+
+TEST_F(SpinTest, TestWait) {
+  std::atomic<int> lock(1);
+
+  std::thread t([&lock]() {
+    lock = 0;
+  });
+
+  SpinWait(lock, 1, std::numeric_limits<int64_t>::max());
+
+  t.join();
+}
+
+}  // namespace
+
+}  // namespace utils
+}  // namespace mace
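Note that `SpinWait`/`SpinWaitUntil` return either because the predicate flipped or because the timeout expired, so callers must re-check the variable afterwards, exactly as `CountDownLatch::Wait` does above. The pattern in miniature:

```cpp
#include <atomic>

#include "mace/utils/spinlock.h"

std::atomic<int> flag(1);
// Bounded busy-wait: at most 2 ms (the timeout is in nanoseconds).
mace::utils::SpinWait(flag, 1, 2000000);
if (flag.load(std::memory_order_acquire) == 1) {
  // Timed out rather than observed a change: fall back to a blocking wait.
}
```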
diff --git a/mace/utils/thread_pool.cc b/mace/utils/thread_pool.cc
new file mode 100644
index 0000000000000000000000000000000000000000..aaf89e6b9c9b1e016bdf6accba50c8186a3ad1d3
--- /dev/null
+++ b/mace/utils/thread_pool.cc
@@ -0,0 +1,470 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+
+#include "mace/port/port.h"
+#include "mace/port/env.h"
+#include "mace/utils/logging.h"
+#include "mace/utils/spinlock.h"
+#include "mace/utils/math.h"
+#include "mace/utils/thread_pool.h"
+
+namespace mace {
+namespace utils {
+
+constexpr int kThreadPoolSpinWaitTime = 2000000;  // ns
+constexpr int kTileCountPerThread = 2;
+constexpr int kMaxCostUsingSingleThread = 100;
+
+namespace {
+
+enum {
+  kThreadPoolNone = 0,
+  kThreadPoolInit = 1,
+  kThreadPoolRun = 2,
+  kThreadPoolShutdown = 4,
+  kThreadPoolEventMask = 0x7fffffff
+};
+
+struct CPUFreq {
+  size_t core_id;
+  float freq;
+};
+
+void GetCPUCoresToUse(const std::vector<float> &cpu_max_freqs,
+                      const CPUAffinityPolicy policy,
+                      const size_t thread_count_hint,
+                      std::vector<size_t> *cores) {
+  size_t thread_count = thread_count_hint;
+  if (!cpu_max_freqs.empty()) {
+    const size_t cpu_count = cpu_max_freqs.size();
+    if (thread_count == 0 || thread_count > cpu_count) {
+      thread_count = cpu_count;
+    }
+
+    if (policy != CPUAffinityPolicy::AFFINITY_NONE) {
+      std::vector<CPUFreq> cpu_freq(cpu_max_freqs.size());
+      for (size_t i = 0; i < cpu_max_freqs.size(); ++i) {
+        cpu_freq[i].core_id = i;
+        cpu_freq[i].freq = cpu_max_freqs[i];
+      }
+      if (policy == CPUAffinityPolicy::AFFINITY_POWER_SAVE ||
+          policy == CPUAffinityPolicy::AFFINITY_LITTLE_ONLY) {
+        std::sort(cpu_freq.begin(),
+                  cpu_freq.end(),
+                  [](const CPUFreq &lhs, const CPUFreq &rhs) {
+                    return lhs.freq < rhs.freq;
+                  });
+      } else if (policy == CPUAffinityPolicy::AFFINITY_HIGH_PERFORMANCE ||
+                 policy == CPUAffinityPolicy::AFFINITY_BIG_ONLY) {
+        std::sort(cpu_freq.begin(),
+                  cpu_freq.end(),
+                  [](const CPUFreq &lhs, const CPUFreq &rhs) {
+                    return lhs.freq > rhs.freq;
+                  });
+      }
+
+      // decide the number of cores to use
+      size_t cores_to_use = 0;
+      if (policy == CPUAffinityPolicy::AFFINITY_BIG_ONLY
+          || policy == CPUAffinityPolicy::AFFINITY_LITTLE_ONLY) {
+        for (size_t i = 0; i < cpu_max_freqs.size(); ++i) {
+          if (cpu_freq[i].freq != cpu_freq[0].freq) {
+            break;
+          }
+          ++cores_to_use;
+        }
+      } else {
+        cores_to_use = thread_count;
+      }
+      MACE_CHECK(cores_to_use > 0, "number of cores to use should be > 0");
+      cores->resize(cores_to_use);
+      for (size_t i = 0; i < cores_to_use; ++i) {
+        VLOG(2) << "Bind thread to core: " << cpu_freq[i].core_id
+                << " with freq " << cpu_freq[i].freq;
+        (*cores)[i] = cpu_freq[i].core_id;
+      }
+    }
+  } else {
+    LOG(ERROR) << "CPU core is empty";
+  }
+}
+
+}  // namespace
+
+ThreadPool::ThreadPool(const size_t thread_count_hint,
+                       const CPUAffinityPolicy policy)
+    : event_(kThreadPoolNone),
+      count_down_latch_(kThreadPoolSpinWaitTime) {
+  size_t thread_count = thread_count_hint;
+
+  std::vector<float> cpu_max_freqs;
+  if (port::Env::Default()->GetCPUMaxFreq(&cpu_max_freqs)
+      != MaceStatus::MACE_SUCCESS) {
+    LOG(ERROR) << "Fail to get CPU max frequencies";
+  }
+
+  thread_count = std::max(static_cast<size_t>(1),
+                          std::min(thread_count, cpu_max_freqs.size()));
+
+  std::vector<size_t> cores_to_use;
+  GetCPUCoresToUse(cpu_max_freqs, policy, thread_count, &cores_to_use);
+  if (!cores_to_use.empty()) {
+    if (port::Env::Default()->SchedSetAffinity(cores_to_use)
+        != MaceStatus::MACE_SUCCESS) {
+      LOG(ERROR) << "Failed to sched_set_affinity";
+    }
+  }
+  if (!cores_to_use.empty() && thread_count > cores_to_use.size()) {
+    thread_count = cores_to_use.size();
+  }
+  VLOG(2) << "Use " << thread_count << " threads";
+
+  default_tile_count_ = thread_count;
+  if (cores_to_use.size() >= 2
+      && cpu_max_freqs[0] != cpu_max_freqs[cores_to_use.back()]) {
+    default_tile_count_ = thread_count * kTileCountPerThread;
+  }
+  MACE_CHECK(default_tile_count_ > 0, "default tile count should be > 0");
+
+  threads_ = std::vector<std::thread>(thread_count);
+  thread_infos_ = std::vector<ThreadInfo>(thread_count);
+  for (auto &thread_info : thread_infos_) {
+    thread_info.cpu_cores = cores_to_use;
+  }
+}
+
+ThreadPool::~ThreadPool() {
+  Destroy();
+}
+
+void ThreadPool::Init() {
+  VLOG(2) << "Init thread pool";
+  if (threads_.size() <= 1) {
+    return;
+  }
+  count_down_latch_.Reset(threads_.size() - 1);
+  event_ = kThreadPoolInit;
+  for (size_t i = 1; i < threads_.size(); ++i) {
+    threads_[i] = std::thread(&ThreadPool::ThreadLoop, this, i);
+  }
+  count_down_latch_.Wait();
+}
+
+void ThreadPool::Run(const std::function<void(size_t)> &func,
+                     size_t iterations) {
+  const size_t thread_count = threads_.size();
+  const size_t iters_per_thread = iterations / thread_count;
+  const size_t remainder = iterations % thread_count;
+  size_t iters_offset = 0;
+
+  std::unique_lock<std::mutex> run_lock(run_mutex_);
+
+  for (size_t i = 0; i < thread_count; ++i) {
+    size_t count = iters_per_thread + (i < remainder);
+    thread_infos_[i].range_start = iters_offset;
+    size_t range_end = std::min(iterations, iters_offset + count);
+    thread_infos_[i].range_end = range_end;
+    thread_infos_[i].range_len = range_end - iters_offset;
+    thread_infos_[i].func = reinterpret_cast<uintptr_t>(&func);
+    iters_offset += thread_infos_[i].range_len;
+  }
+
+  count_down_latch_.Reset(thread_count - 1);
+  {
+    std::unique_lock<std::mutex> m(event_mutex_);
+    event_.store(kThreadPoolRun | ~(event_ | kThreadPoolEventMask),
+                 std::memory_order::memory_order_release);
+    event_cond_.notify_all();
+  }
+
+  ThreadRun(0);
+  count_down_latch_.Wait();
+}
+
+void ThreadPool::Destroy() {
+  VLOG(2) << "Destroy thread pool";
+  if (threads_.size() <= 1) {
+    return;
+  }
+
+  std::unique_lock<std::mutex> run_lock(run_mutex_);
+
+  count_down_latch_.Wait();
+  {
+    std::unique_lock<std::mutex> m(event_mutex_);
+    event_.store(kThreadPoolShutdown,
+                 std::memory_order::memory_order_release);
+    event_cond_.notify_all();
+  }
+
+  for (size_t i = 1; i < threads_.size(); ++i) {
+    if (threads_[i].joinable()) {
+      threads_[i].join();
+    } else {
+      LOG(ERROR) << "Thread: " << threads_[i].get_id() << " not joinable";
+    }
+  }
+}
+
+// Each event is handled synchronously: a worker consumes one event, then
+// waits for the next one.
+void ThreadPool::ThreadLoop(size_t tid) {
+  if (!thread_infos_[tid].cpu_cores.empty()) {
+    if (port::Env::Default()->SchedSetAffinity(thread_infos_[tid].cpu_cores)
+        != MaceStatus::MACE_SUCCESS) {
+      LOG(ERROR) << "Failed to sched_set_affinity for tid: " << tid;
+    }
+  }
+
+  int last_event = kThreadPoolNone;
+
+  for (;;) {
+    SpinWait(event_, last_event, kThreadPoolSpinWaitTime);
+    if (event_.load(std::memory_order::memory_order_acquire) == last_event) {
+      std::unique_lock<std::mutex> m(event_mutex_);
+      while (event_ == last_event) {
+        event_cond_.wait(m);
+      }
+    }
+
+    int event = event_.load(std::memory_order::memory_order_acquire);
+    switch (event & kThreadPoolEventMask) {
+      case kThreadPoolInit: {
+        count_down_latch_.CountDown();
+        break;
+      }
+
+      case kThreadPoolRun: {
+        ThreadRun(tid);
+        count_down_latch_.CountDown();
+        break;
+      }
+
+      case kThreadPoolShutdown: return;
+      default: break;
+    }
+
+    last_event = event;
+  }
+}
+
+void ThreadPool::ThreadRun(size_t tid) {
+  ThreadInfo &thread_info = thread_infos_[tid];
+  uintptr_t func_ptr = thread_info.func;
+  const std::function<void(size_t)> *func =
+      reinterpret_cast<const std::function<void(size_t)> *>(func_ptr);
+  // do own work
+  size_t range_len;
+  while ((range_len = thread_info.range_len) > 0) {
+    if (thread_info.range_len.compare_exchange_strong(range_len,
+                                                      range_len - 1)) {
+      func->operator()(thread_info.range_start++);
+    }
+  }
+
+  // steal other threads' work
+  size_t thread_count = threads_.size();
+  for (size_t t = (tid + 1) % thread_count; t != tid;
+       t = (t + 1) % thread_count) {
+    ThreadInfo &other_thread_info = thread_infos_[t];
+    uintptr_t other_func_ptr = other_thread_info.func;
+    const std::function<void(size_t)> *other_func =
+        reinterpret_cast<const std::function<void(size_t)> *>(
+            other_func_ptr);
+    while ((range_len = other_thread_info.range_len) > 0) {
+      if (other_thread_info.range_len.compare_exchange_strong(range_len,
+                                                              range_len - 1)) {
+        size_t tail = other_thread_info.range_end--;
+        other_func->operator()(tail - 1);
+      }
+    }
+  }
+}
+
+void ThreadPool::Compute1D(const std::function<void(size_t,
+                                                    size_t,
+                                                    size_t)> &func,
+                           size_t start,
+                           size_t end,
+                           size_t step,
+                           size_t tile_size,
+                           int cost_per_item) {
+  if (start >= end) {
+    return;
+  }
+
+  size_t items = 1 + (end - start - 1) / step;
+  if (threads_.size() <= 1 || (cost_per_item >= 0
+      && items * cost_per_item < kMaxCostUsingSingleThread)) {
+    func(start, end, step);
+    return;
+  }
+
+  if (tile_size == 0) {
+    tile_size = std::max(static_cast<size_t>(1), items / default_tile_count_);
+  }
+
+  size_t step_tile_size = step * tile_size;
+  size_t tile_count = RoundUpDiv(items, tile_size);
+  Run([&](size_t tile_idx) {
+    size_t tile_start = start + tile_idx * step_tile_size;
+    size_t tile_end = std::min(end, tile_start + step_tile_size);
+    func(tile_start, tile_end, step);
+  }, tile_count);
+}
+
+void ThreadPool::Compute2D(const std::function<void(size_t, size_t, size_t,
+                                                    size_t, size_t,
+                                                    size_t)> &func,
+                           size_t start0,
+                           size_t end0,
+                           size_t step0,
+                           size_t start1,
+                           size_t end1,
+                           size_t step1,
+                           size_t tile_size0,
+                           size_t tile_size1,
+                           int cost_per_item) {
+  if (start0 >= end0 || start1 >= end1) {
+    return;
+  }
+
+  size_t items0 = 1 + (end0 - start0 - 1) / step0;
+  size_t items1 = 1 + (end1 - start1 - 1) / step1;
+  if (threads_.size() <= 1 || (cost_per_item >= 0
+      && items0 * items1 * cost_per_item < kMaxCostUsingSingleThread)) {
+    func(start0, end0, step0, start1, end1, step1);
+    return;
+  }
+
+  if (tile_size0 == 0 || tile_size1 == 0) {
+    if (items0 >= default_tile_count_) {
+      tile_size0 = items0 / default_tile_count_;
+      tile_size1 = items1;
+    } else {
+      tile_size0 = 1;
+      tile_size1 = std::max(static_cast<size_t>(1),
+                            items1 * items0 / default_tile_count_);
+    }
+  }
+
+  size_t step_tile_size0 = step0 * tile_size0;
+  size_t step_tile_size1 = step1 * tile_size1;
+  size_t tile_count0 = RoundUpDiv(items0, tile_size0);
+  size_t tile_count1 = RoundUpDiv(items1, tile_size1);
+
+  Run([&](size_t tile_idx) {
+    size_t tile_idx0 = tile_idx / tile_count1;
+    size_t tile_idx1 = tile_idx - tile_idx0 * tile_count1;
+    size_t tile_start0 = start0 + tile_idx0 * step_tile_size0;
+    size_t tile_end0 = std::min(end0, tile_start0 + step_tile_size0);
+    size_t tile_start1 = start1 + tile_idx1 * step_tile_size1;
+    size_t tile_end1 = std::min(end1, tile_start1 + step_tile_size1);
+    func(tile_start0, tile_end0, step0, tile_start1, tile_end1, step1);
+  }, tile_count0 * tile_count1);
+}
+
+void ThreadPool::Compute3D(const std::function<void(size_t, size_t, size_t,
+                                                    size_t, size_t, size_t,
+                                                    size_t, size_t,
+                                                    size_t)> &func,
+                           size_t start0,
+                           size_t end0,
+                           size_t step0,
+                           size_t start1,
+                           size_t end1,
+                           size_t step1,
+                           size_t start2,
+                           size_t end2,
+                           size_t step2,
+                           size_t tile_size0,
+                           size_t tile_size1,
+                           size_t tile_size2,
+                           int cost_per_item) {
+  if (start0 >= end0 || start1 >= end1 || start2 >= end2) {
+    return;
+  }
+
+  size_t items0 = 1 + (end0 - start0 - 1) / step0;
+  size_t items1 = 1 + (end1 - start1 - 1) / step1;
+  size_t items2 = 1 + (end2 - start2 - 1) / step2;
+  if (threads_.size() <= 1 || (cost_per_item >= 0
+      && items0 * items1 * items2 * cost_per_item
+          < kMaxCostUsingSingleThread)) {
+    func(start0, end0, step0, start1, end1, step1, start2, end2, step2);
+    return;
+  }
+
+  if (tile_size0 == 0 || tile_size1 == 0 || tile_size2 == 0) {
+    if (items0 >= default_tile_count_) {
+      tile_size0 = items0 / default_tile_count_;
+      tile_size1 = items1;
+      tile_size2 = items2;
+    } else {
+      tile_size0 = 1;
+      size_t items01 = items1 * items0;
+      if (items01 >= default_tile_count_) {
+        tile_size1 = items01 / default_tile_count_;
+        tile_size2 = items2;
+      } else {
+        tile_size1 = 1;
+        tile_size2 = std::max(static_cast<size_t>(1),
+                              items01 * items2 / default_tile_count_);
+      }
+    }
+  }
+
+  size_t step_tile_size0 = step0 * tile_size0;
+  size_t step_tile_size1 = step1 * tile_size1;
+  size_t step_tile_size2 = step2 * tile_size2;
+  size_t tile_count0 = RoundUpDiv(items0, tile_size0);
+  size_t tile_count1 = RoundUpDiv(items1, tile_size1);
+  size_t tile_count2 = RoundUpDiv(items2, tile_size2);
+  size_t tile_count12 = tile_count1 * tile_count2;
+
+  Run([&](size_t tile_idx) {
+    size_t tile_idx0 = tile_idx / tile_count12;
+    size_t tile_idx12 = tile_idx - tile_idx0 * tile_count12;
+    size_t tile_idx1 = tile_idx12 / tile_count2;
+    size_t tile_idx2 = tile_idx12 - tile_idx1 * tile_count2;
+    size_t tile_start0 = start0 + tile_idx0 * step_tile_size0;
+    size_t tile_end0 = std::min(end0, tile_start0 + step_tile_size0);
+    size_t tile_start1 = start1 + tile_idx1 * step_tile_size1;
+    size_t tile_end1 = std::min(end1, tile_start1 + step_tile_size1);
+    size_t tile_start2 = start2 + tile_idx2 * step_tile_size2;
+    size_t tile_end2 = std::min(end2, tile_start2 + step_tile_size2);
+    func(tile_start0, tile_end0, step0,
+         tile_start1, tile_end1, step1,
+         tile_start2, tile_end2, step2);
+  }, tile_count0 * tile_count12);
+}
+
+}  // namespace utils
+}  // namespace mace
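A note on the scheduler above: `Run()` hands worker `i` a half-open slice `[range_start, range_end)`; the owning thread pops from the front (`range_start++`) while idle threads steal from the back (`range_end--`), and the CAS on `range_len` is what guarantees each index runs exactly once. The tiling arithmetic is easy to check by hand; a worked example for `Compute1D` (the `default_tile_count_` value is hypothetical):

```cpp
// Compute1D(func, /*start=*/0, /*end=*/100, /*step=*/2), default_tile_count_ == 8:
//   items          = 1 + (100 - 0 - 1) / 2 = 50   iterations
//   tile_size      = max(1, 50 / 8)        = 6    iterations per tile
//   step_tile_size = 2 * 6                 = 12   index units per tile
//   tile_count     = RoundUpDiv(50, 6)     = 9    tiles handed to Run()
// The last tile starts at 0 + 8 * 12 = 96 and is clamped: func(96, 100, 2).
```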
diff --git a/mace/utils/thread_pool.h b/mace/utils/thread_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..67fa89cf112cc3b3fc55124a73ef8bb63633e57c
--- /dev/null
+++ b/mace/utils/thread_pool.h
@@ -0,0 +1,118 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_UTILS_THREAD_POOL_H_
+#define MACE_UTILS_THREAD_POOL_H_
+
+#include <atomic>
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <mutex>  // NOLINT(build/c++11)
+#include <thread>  // NOLINT(build/c++11)
+#include <functional>
+#include <vector>
+
+#include "mace/public/mace.h"
+#include "mace/port/port.h"
+#include "mace/utils/count_down_latch.h"
+
+namespace mace {
+namespace utils {
+
+class ThreadPool {
+ public:
+  ThreadPool(const size_t thread_count,
+             const CPUAffinityPolicy affinity_policy);
+  ~ThreadPool();
+
+  void Init();
+
+  void Run(const std::function<void(size_t)> &func, size_t iterations);
+
+  void Compute1D(const std::function<void(size_t, size_t, size_t)> &func,
+                 size_t start,
+                 size_t end,
+                 size_t step,
+                 size_t tile_size = 0,
+                 int cost_per_item = -1);
+
+  void Compute2D(const std::function<void(size_t, size_t, size_t,
+                                          size_t, size_t, size_t)> &func,
+                 size_t start0,
+                 size_t end0,
+                 size_t step0,
+                 size_t start1,
+                 size_t end1,
+                 size_t step1,
+                 size_t tile_size0 = 0,
+                 size_t tile_size1 = 0,
+                 int cost_per_item = -1);
+
+  void Compute3D(const std::function<void(size_t, size_t, size_t,
+                                          size_t, size_t, size_t,
+                                          size_t, size_t, size_t)> &func,
+                 size_t start0,
+                 size_t end0,
+                 size_t step0,
+                 size_t start1,
+                 size_t end1,
+                 size_t step1,
+                 size_t start2,
+                 size_t end2,
+                 size_t step2,
+                 size_t tile_size0 = 0,
+                 size_t tile_size1 = 0,
+                 size_t tile_size2 = 0,
+                 int cost_per_item = -1);
+
+ private:
+  void Destroy();
+  void ThreadLoop(size_t tid);
+  void ThreadRun(size_t tid);
+
+  std::atomic<int> event_;
+  CountDownLatch count_down_latch_;
+
+  std::mutex event_mutex_;
+  std::condition_variable event_cond_;
+  std::mutex run_mutex_;
+
+  struct ThreadInfo {
+    size_t range_start;
+    std::atomic<size_t> range_end;
+    std::atomic<size_t> range_len;
+    uintptr_t func;
+    std::vector<size_t> cpu_cores;
+  };
+  std::vector<ThreadInfo> thread_infos_;
+  std::vector<std::thread> threads_;
+
+  size_t default_tile_count_;
+};
+
+}  // namespace utils
+}  // namespace mace
+
+#endif  // MACE_UTILS_THREAD_POOL_H_
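Driving the pool directly looks like this; a minimal sketch (the thread count and policy are arbitrary choices, and `Init()` must run before any `Compute*` call, otherwise `Run()` waits on workers that were never started):

```cpp
#include "mace/utils/thread_pool.h"

void DoubleBuffer(float *data, size_t n) {
  mace::utils::ThreadPool pool(4, mace::CPUAffinityPolicy::AFFINITY_NONE);
  pool.Init();  // spawn the worker threads
  pool.Compute1D([data](size_t start, size_t end, size_t step) {
    for (size_t i = start; i < end; i += step) {
      data[i] *= 2.0f;
    }
  }, 0, n, 1);  // tile_size and cost_per_item keep their defaults
}
```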
diff --git a/mace/utils/thread_pool_test.cc b/mace/utils/thread_pool_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..281b335b42b24d0b86e5c6816d03a50cb8845633
--- /dev/null
+++ b/mace/utils/thread_pool_test.cc
@@ -0,0 +1,108 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <functional>
+#include <vector>
+
+#include "mace/utils/thread_pool.h"
+
+namespace mace {
+namespace utils {
+namespace {
+
+class ThreadPoolTest : public ::testing::Test {
+ public:
+  ThreadPoolTest()
+      : thread_pool(4, CPUAffinityPolicy::AFFINITY_BIG_ONLY) {
+    thread_pool.Init();
+  }
+  ThreadPool thread_pool;
+};
+
+void Test1D(size_t start, size_t end, size_t step, std::vector<int> *res) {
+  for (size_t i = start; i < end; i += step) {
+    (*res)[i]++;
+  }
+}
+
+void Test2D(size_t start0, size_t end0, size_t step0,
+            size_t start1, size_t end1, size_t step1,
+            std::vector<int> *res) {
+  for (size_t i = start0; i < end0; i += step0) {
+    for (size_t j = start1; j < end1; j += step1) {
+      (*res)[i * 100 + j]++;
+    }
+  }
+}
+
+void Test3D(size_t start0, size_t end0, size_t step0,
+            size_t start1, size_t end1, size_t step1,
+            size_t start2, size_t end2, size_t step2,
+            std::vector<int> *res) {
+  for (size_t i = start0; i < end0; i += step0) {
+    for (size_t j = start1; j < end1; j += step1) {
+      for (size_t k = start2; k < end2; k += step2) {
+        (*res)[(i * 100 + j) * 100 + k]++;
+      }
+    }
+  }
+}
+
+TEST_F(ThreadPoolTest, Compute1D) {
+  size_t test_size = 100;
+  std::vector<int> actual(test_size, 0);
+  thread_pool.Compute1D([&](size_t start, size_t end, size_t step) {
+    Test1D(start, end, step, &actual);
+  }, 0, test_size, 2);
+  std::vector<int> expected(test_size, 0);
+  Test1D(0, test_size, 2, &expected);
+
+  for (size_t i = 0; i < test_size; ++i) {
+    EXPECT_EQ(expected[i], actual[i]);
+  }
+}
+
+TEST_F(ThreadPoolTest, Compute2D) {
+  size_t test_size = 100;
+  std::vector<int> actual(test_size * test_size, 0);
+  thread_pool.Compute2D([&](size_t start0, size_t end0, size_t step0,
+                            size_t start1, size_t end1, size_t step1) {
+    Test2D(start0, end0, step0, start1, end1, step1, &actual);
+  }, 0, test_size, 2, 0, test_size, 2);
+  std::vector<int> expected(test_size * test_size, 0);
+  Test2D(0, test_size, 2, 0, test_size, 2, &expected);
+
+  for (size_t i = 0; i < test_size * test_size; ++i) {
+    EXPECT_EQ(expected[i], actual[i]);
+  }
+}
+
+TEST_F(ThreadPoolTest, Compute3D) {
+  size_t test_size = 100;
+  std::vector<int> actual(test_size * test_size * test_size, 0);
+  thread_pool.Compute3D([&](size_t start0, size_t end0, size_t step0,
+                            size_t start1, size_t end1, size_t step1,
+                            size_t start2, size_t end2, size_t step2) {
+    Test3D(start0, end0, step0, start1, end1, step1, start2, end2, step2,
+           &actual);
+  }, 0, test_size, 2, 0, test_size, 2, 0, test_size, 2);
+  std::vector<int> expected(test_size * test_size * test_size, 0);
+  Test3D(0, test_size, 2, 0, test_size, 2, 0, test_size, 2, &expected);
+
+  for (size_t i = 0; i < test_size * test_size * test_size; ++i) {
+    EXPECT_EQ(expected[i], actual[i]);
+  }
+}
+
+}  // namespace
+}  // namespace utils
+}  // namespace mace
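One behavior the tests above do not pin down is the single-thread short-circuit: with a non-negative `cost_per_item`, any job whose estimated total cost stays under `kMaxCostUsingSingleThread` (100) runs inline on the caller. A sketch against the same fixture (the cost value is arbitrary):

```cpp
// 4 items at an estimated cost of 10 each: 4 * 10 = 40 < 100, so the functor
// runs on the calling thread with no tiling and no worker wake-up.
thread_pool.Compute1D([&](size_t start, size_t end, size_t step) {
  Test1D(start, end, step, &actual);
}, 0, 4, 1, /*tile_size=*/0, /*cost_per_item=*/10);
```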