From aaf95cf23645d80cd9d4d9047f7375d9a89c7aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Wed, 15 Aug 2018 15:38:52 +0800 Subject: [PATCH] Implement thread pool. --- mace/core/BUILD | 21 ++ mace/core/runtime/cpu/cpu_runtime.cc | 5 +- mace/core/runtime/cpu/cpu_runtime.h | 3 + mace/core/testing/test_benchmark_main.cc | 5 + mace/core/threadpool.cc | 271 +++++++++++++++++++++++ mace/core/threadpool.h | 89 ++++++++ mace/core/threadpool_test.cc | 123 ++++++++++ 7 files changed, 515 insertions(+), 2 deletions(-) create mode 100644 mace/core/threadpool.cc create mode 100644 mace/core/threadpool.h create mode 100644 mace/core/threadpool_test.cc diff --git a/mace/core/BUILD b/mace/core/BUILD index fd9ec5a3..a9d15bde 100644 --- a/mace/core/BUILD +++ b/mace/core/BUILD @@ -64,6 +64,7 @@ cc_library( "//mace/codegen:generated_version", "//mace/proto:mace_cc", "//mace/utils", + "@gemmlowp", ] + if_opencl_enabled([ ":opencl_headers", "//mace/codegen:generated_opencl", @@ -73,6 +74,26 @@ cc_library( ]), ) +cc_test( + name = "core_test", + srcs = glob(["*_test.cc"]), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ], + linkopts = ["-ldl"] + if_android([ + "-pie", + "-lm", + ]), + deps = [ + ":core", + "//mace/utils", + "@gtest", + "@gtest//:gtest_main", + ], +) + cc_library( name = "opencl_headers", hdrs = glob([ diff --git a/mace/core/runtime/cpu/cpu_runtime.cc b/mace/core/runtime/cpu/cpu_runtime.cc index 5bef3805..1802cac4 100644 --- a/mace/core/runtime/cpu/cpu_runtime.cc +++ b/mace/core/runtime/cpu/cpu_runtime.cc @@ -75,6 +75,9 @@ int GetCPUMaxFreq(int cpu_id) { return freq; } +} // namespace + + MaceStatus SetThreadAffinity(cpu_set_t mask) { #if defined(__ANDROID__) pid_t pid = gettid(); @@ -90,8 +93,6 @@ MaceStatus SetThreadAffinity(cpu_set_t mask) { } } -} // namespace - MaceStatus GetCPUBigLittleCoreIDs(std::vector *big_core_ids, std::vector *little_core_ids) { MACE_CHECK_NOTNULL(big_core_ids); diff --git a/mace/core/runtime/cpu/cpu_runtime.h b/mace/core/runtime/cpu/cpu_runtime.h index 333729e1..821ce7fe 100644 --- a/mace/core/runtime/cpu/cpu_runtime.h +++ b/mace/core/runtime/cpu/cpu_runtime.h @@ -15,6 +15,7 @@ #ifndef MACE_CORE_RUNTIME_CPU_CPU_RUNTIME_H_ #define MACE_CORE_RUNTIME_CPU_CPU_RUNTIME_H_ +#include #include #include "mace/public/mace.h" @@ -31,6 +32,8 @@ MaceStatus SetOpenMPThreadsAndAffinityCPUs(int omp_num_threads, MaceStatus SetOpenMPThreadsAndAffinityPolicy(int omp_num_threads_hint, CPUAffinityPolicy policy); +MaceStatus SetThreadAffinity(cpu_set_t mask); + } // namespace mace #endif // MACE_CORE_RUNTIME_CPU_CPU_RUNTIME_H_ diff --git a/mace/core/testing/test_benchmark_main.cc b/mace/core/testing/test_benchmark_main.cc index e730c10e..864c7816 100644 --- a/mace/core/testing/test_benchmark_main.cc +++ b/mace/core/testing/test_benchmark_main.cc @@ -16,6 +16,7 @@ #include "gflags/gflags.h" #include "mace/core/runtime/cpu/cpu_runtime.h" +#include "mace/core/threadpool.h" #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/testing/test_benchmark.h" #include "mace/public/mace.h" @@ -36,6 +37,10 @@ int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // config runtime + // TODO(liyin): Uncomment the following after removing openmp. +// mace::ThreadPoolRegister::ConfigThreadPool(FLAGS_omp_num_threads, +// static_cast(FLAGS_cpu_affinity_policy)); + mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy( FLAGS_omp_num_threads, static_cast(FLAGS_cpu_affinity_policy)); diff --git a/mace/core/threadpool.cc b/mace/core/threadpool.cc new file mode 100644 index 00000000..a03a418c --- /dev/null +++ b/mace/core/threadpool.cc @@ -0,0 +1,271 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "mace/core/threadpool.h" +#include "mace/utils/logging.h" + +namespace mace { + +namespace { +class InitTask : public gemmlowp::Task { + public: + explicit InitTask(const std::vector &cpu_ids) + : Task(), cpu_ids_(cpu_ids) {} + void Run() { + cpu_set_t mask; + CPU_ZERO(&mask); + for (auto cpu_id : cpu_ids_) { + CPU_SET(cpu_id, &mask); + } + SetThreadAffinity(mask); + } + + private: + const std::vector cpu_ids_; +}; + +class BlockTask : public gemmlowp::Task { + public: + BlockTask(const int64_t start, + const int64_t end, + const std::function &fn) + : Task(), + start_(start), + end_(end), + fn_(fn) {} + + void Run() { + fn_(start_, end_); + } + + private: + const int64_t start_; + const int64_t end_; + const std::function fn_; +}; +} // namespace + +ThreadPool::ThreadPool(int num_threads, CPUAffinityPolicy policy) + : num_threads_(num_threads), busy_(false) { + std::vector big_core_ids; + std::vector little_core_ids; + std::vector use_cpu_ids; + MaceStatus res = GetCPUBigLittleCoreIDs(&big_core_ids, &little_core_ids); + if (res == MaceStatus::MACE_SUCCESS) { + if (policy == CPUAffinityPolicy::AFFINITY_BIG_ONLY) { + use_cpu_ids = std::move(big_core_ids); + } else { + use_cpu_ids = std::move(little_core_ids); + } + + if (num_threads <= 0 || + num_threads > static_cast(use_cpu_ids.size())) { + num_threads_ = use_cpu_ids.size(); + } + } else { + LOG(WARNING) << "Failed to get cpu big little cores info"; + } + + // if set mask failed and user set num_threads as -1, use single thread + if (num_threads_ < 0) { + num_threads_ = 1; + } + + VLOG(2) << "Use thread count: " << num_threads_; + + std::vector init_tasks(num_threads_); + for (int i = 0; i < num_threads_; ++i) { + init_tasks[i] = new InitTask(use_cpu_ids); + } + workers_pool_.Execute(init_tasks); +} + +ThreadPool::~ThreadPool() {} + +void ThreadPool::ParallelRun(const int64_t total, + std::function fn) { + if (total <= 0) { + return; + } else if (total == 1 || busy_) { + fn(0, total); + return; + } + busy_ = true; + + const int64_t + shards = std::min(total, static_cast(num_threads_)); + const int64_t work_size_per_shard = total / shards; + const int64_t remain = total - shards * work_size_per_shard; + const int64_t work_size_per_shard_plus_one = work_size_per_shard + 1; + + std::vector tasks(shards); + + int64_t start = 0; + int64_t end = 0; + for (int64_t i = 0; i < remain; ++i) { + end = start + work_size_per_shard_plus_one; + tasks[i] = new BlockTask(start, end, fn); + start = end; + } + for (int64_t i = remain; i < shards; ++i) { + end = start + work_size_per_shard; + tasks[i] = new BlockTask(start, end, fn); + start = end; + } + + workers_pool_.Execute(tasks); + busy_ = false; +} + +void ThreadPool::ParallelFor(const int64_t start, + const int64_t end, + const int64_t step, + std::function fn) { + MACE_CHECK(start <= end && step > 0, "start must be le end and step gt 0"); + const int64_t total = (end - start + step - 1) / step; + ParallelRun(total, [&](const int64_t s, const int64_t t) { + const int64_t start_i = start + s * step; + const int64_t end_i = std::min(end, start + t * step); + for (int64_t i = start_i; i < end_i; i += step) { + fn(i); + } + }); +} + +void ThreadPool::ParallelFor(const int64_t start1, + const int64_t end1, + const int64_t step1, + const int64_t start2, + const int64_t end2, + const int64_t step2, + std::function fn) { + MACE_CHECK(start1 <= end1 && step1 > 0 && start2 <= end2 && step2 > 0, + "start must be le end and step gt 0"); + + const int64_t total1 = ((end1 - start1 + step1 - 1) / step1); + const int64_t total2 = ((end2 - start2 + step2 - 1) / step2); + const int64_t total = total1 * total2; + + if (total == 0) { + return; + } else if (total1 == 1) { + ParallelFor(start2, + end2, + step2, + [&](const int64_t arg2) { + fn(start1, arg2); + }); + } else if (total1 % num_threads_ == 0) { + ParallelFor(start1, + end1, + step1, + [&](const int64_t arg1) { + for (int64_t i = start2; i < end2; i += step2) { + fn(arg1, i); + } + }); + } else { + ParallelRun(total, [&](const int64_t s, const int64_t t) { + for (int64_t idx = s; idx < t; ++idx) { + const int64_t i = idx / total2; + const int64_t j = idx - i * total2; + fn(start1 + step1 * i, start2 + step2 * j); + } + }); + } +} + +void ThreadPool::ParallelFor(const int64_t start1, + const int64_t end1, + const int64_t step1, + const int64_t start2, + const int64_t end2, + const int64_t step2, + const int64_t start3, + const int64_t end3, + const int64_t step3, + std::function fn) { + MACE_CHECK(start1 <= end1 && step1 > 0 && start2 <= end2 && step2 > 0 + && start3 <= end3 && step3 > 0, + "start must be le end and step gt 0"); + + const int64_t total1 = ((end1 - start1 + step1 - 1) / step1); + const int64_t total2 = ((end2 - start2 + step2 - 1) / step2); + const int64_t total3 = ((end3 - start3 + step3 - 1) / step3); + const int64_t total23 = total2 * total3; + const int64_t total = total1 * total23; + + if (total == 0) { + return; + } else if (total1 == 1) { + ParallelFor(start2, + end2, + step2, + start3, + end3, + step3, + [&](const int64_t arg2, const int64_t arg3) { + fn(start1, arg2, arg3); + }); + } else if ((total1 * total2) % num_threads_ == 0) { + ParallelFor(start1, + end1, + step1, + start2, + end2, + step2, + [&](const int64_t arg1, const int64_t arg2) { + for (int64_t i = start3; i < end3; i += step3) { + fn(arg1, arg2, i); + } + }); + } else { + ParallelRun(total, [&](int64_t s, int64_t t) { + for (int64_t idx = s; idx < t; ++idx) { + const int64_t i = idx / total23; + const int64_t mod_i = idx - i * total23; + const int64_t j = mod_i / total3; + const int64_t k = mod_i - j * total3; + fn(start1 + step1 * i, start2 + step2 * j, start3 + step3 * k); + } + }); + } +} + +ThreadPool *ThreadPoolRegister::thread_pool = nullptr; + +ThreadPool *ThreadPoolRegister::GetThreadPool() { + if (thread_pool == nullptr) { + ConfigThreadPool(-1, CPUAffinityPolicy::AFFINITY_NONE); + } + return thread_pool; +} + +void ThreadPoolRegister::ConfigThreadPool(int num_threads, + CPUAffinityPolicy policy) { + MACE_CHECK(thread_pool == nullptr, + "ThreadPool has already been initialized."); + thread_pool = new ThreadPool(num_threads, policy); +} + +} // namespace mace diff --git a/mace/core/threadpool.h b/mace/core/threadpool.h new file mode 100644 index 00000000..2b488e9d --- /dev/null +++ b/mace/core/threadpool.h @@ -0,0 +1,89 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_CORE_THREADPOOL_H_ +#define MACE_CORE_THREADPOOL_H_ + +#include "mace/core/runtime/cpu/cpu_runtime.h" + +// Remove the following macros after removing openmp dependency +#ifdef GEMMLOWP_USE_OPENMP +#define MACE_GEMMLOWP_USE_OPENMP +#undef GEMMLOWP_USE_OPENMP +#endif +#include "internal/multi_thread_gemm.h" +#ifdef MACE_GEMMLOWP_USE_OPENMP +#define GEMMLOWP_USE_OPENMP +#undef MACE_GEMMLOWP_USE_OPENMP +#endif + +namespace mace { + +class ThreadPool { + public: + explicit ThreadPool(int num_threads, CPUAffinityPolicy policy); + + ~ThreadPool(); + + void ParallelRun(const int64_t total, + std::function fn); + + // Parallel for + void ParallelFor(const int64_t start, const int64_t end, const int64_t step, + std::function fn); + + // Parallel for collapse(2) + void ParallelFor(const int64_t start1, + const int64_t end1, + const int64_t step1, + const int64_t start2, + const int64_t end2, + const int64_t step2, + std::function fn); + + // Parallel for collapse(3) + void ParallelFor(const int64_t start1, + const int64_t end1, + const int64_t step1, + const int64_t start2, + const int64_t end2, + const int64_t step2, + const int64_t start3, + const int64_t end3, + const int64_t step3, + std::function fn); + + private: + int num_threads_; + bool busy_; + gemmlowp::WorkersPool workers_pool_; +}; + +class ThreadPoolRegister { + public: + static ThreadPool *GetThreadPool(); + + static void ConfigThreadPool(int num_threads, CPUAffinityPolicy policy); + + private: + static ThreadPool *thread_pool; +}; + +} // namespace mace + + + +#endif // MACE_CORE_THREADPOOL_H_ diff --git a/mace/core/threadpool_test.cc b/mace/core/threadpool_test.cc new file mode 100644 index 00000000..10fdd27a --- /dev/null +++ b/mace/core/threadpool_test.cc @@ -0,0 +1,123 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mace/core/threadpool.h" +#include "mace/utils/logging.h" + +namespace mace { +namespace kernels { +namespace test { + +namespace { + +void TestParallelRun(int64_t size) { + ThreadPool thread_pool(4, CPUAffinityPolicy::AFFINITY_NONE); + std::vector input(size); + thread_pool.ParallelRun(size, [&](const int64_t start, const int64_t end) { + for (int64_t i = start; i < end; ++i) { + input[i] = i; + } + }); + + for (int64_t i = 0; i < size; ++i) { + EXPECT_EQ(input[i], i); + } +} + +void TestParallelForLoop1(int64_t start, int64_t end, int64_t step) { + ThreadPool thread_pool(4, CPUAffinityPolicy::AFFINITY_NONE); + std::vector input(end); + thread_pool.ParallelFor(start, end, step, [&](const int64_t i) { + input[i] = i; + }); + + for (int64_t i = start; i < end; i += step) { + EXPECT_EQ(input[i], i); + } +} + +void TestParallelForLoop2(int64_t start1, int64_t end1, int64_t step1, + int64_t start2, int64_t end2, int64_t step2) { + ThreadPool thread_pool(4, CPUAffinityPolicy::AFFINITY_BIG_ONLY); + std::vector> input(end1, std::vector(end2)); + thread_pool.ParallelFor(start1, end1, step1, + start2, end2, step2, + [&](const int64_t i, const int64_t j) { + input[i][j] = i * j; + }); + + for (int64_t i = start1; i < end1; i += step1) { + for (int64_t j = start2; j < end2; j += step2) { + EXPECT_EQ(input[i][j], i * j); + } + } +} + +void TestParallelForLoop3(int64_t start1, int64_t end1, int64_t step1, + int64_t start2, int64_t end2, int64_t step2, + int64_t start3, int64_t end3, int64_t step3) { + ThreadPool thread_pool(4, CPUAffinityPolicy::AFFINITY_BIG_ONLY); + std::vector>> + input(end1, + std::vector>(end2, + std::vector(end3))); + thread_pool.ParallelFor(start1, end1, step1, + start2, end2, step2, + start3, end3, step3, + [&](const int64_t i, + const int64_t j, + const int64_t k) { + input[i][j][k] = i * j * k; + }); + + for (int64_t i = start1; i < end1; i += step1) { + for (int64_t j = start2; j < end2; j += step2) { + for (int64_t k = start3; k < end3; k += step3) { + EXPECT_EQ(input[i][j][k], i * j * k); + } + } + } +} + +} // namespace + +TEST(ThreadPoolTest, TestParallelRun) { + TestParallelRun(102); +} + +TEST(ThreadPoolTest, TestParallelFor1) { + TestParallelForLoop1(1, 102, 2); +} + +TEST(ThreadPoolTest, TestParallelFor2) { + TestParallelForLoop2(1, 1, 2, 2, 53, 4); + TestParallelForLoop2(1, 102, 2, 2, 53, 4); + TestParallelForLoop2(1, 2, 2, 2, 53, 4); + TestParallelForLoop2(1, 101, 1, 2, 53, 4); +} + +TEST(ThreadPoolTest, TestParallelFor3) { + TestParallelForLoop3(1, 1, 2, 2, 53, 4, 3, 31, 3); + TestParallelForLoop3(1, 102, 2, 2, 53, 4, 3, 31, 3); + TestParallelForLoop3(1, 2, 2, 2, 53, 4, 3, 31, 3); + TestParallelForLoop3(1, 5, 1, 2, 53, 4, 3, 31, 3); + TestParallelForLoop3(1, 5, 2, 2, 50, 24, 3, 31, 3); +} + +} // namespace test +} // namespace kernels +} // namespace mace -- GitLab