From 9e11ca8096a57cda6d91741c064b362180ff2a50 Mon Sep 17 00:00:00 2001 From: gangliao Date: Mon, 10 Oct 2016 10:15:07 +0800 Subject: [PATCH] Use C++ 11 atomic_flag in MacOS as spin lock (#175) * Use C++ 11 atomic_flag in MacOS as spin lock * Add unittest for it. --- paddle/trainer/tests/test_CompareSparse.cpp | 2 +- paddle/utils/arch/osx/Locks.cpp | 24 ++------- paddle/utils/tests/CMakeLists.txt | 1 + paddle/utils/tests/test_SpinLock.cpp | 57 +++++++++++++++++++++ 4 files changed, 63 insertions(+), 21 deletions(-) create mode 100644 paddle/utils/tests/test_SpinLock.cpp diff --git a/paddle/trainer/tests/test_CompareSparse.cpp b/paddle/trainer/tests/test_CompareSparse.cpp index ff37d7b364..311dd333a1 100644 --- a/paddle/trainer/tests/test_CompareSparse.cpp +++ b/paddle/trainer/tests/test_CompareSparse.cpp @@ -57,7 +57,7 @@ std::vector trainerOnePassTest(const string& configFile, << " sparseUpdate=" << sparseUpdate; srand(FLAGS_seed); *ThreadLocalRand::getSeed() = FLAGS_seed; - + ThreadLocalRandomEngine::get().seed(FLAGS_seed); if (useGpu) { CHECK_LE(trainerCount, gNumDevices); } diff --git a/paddle/utils/arch/osx/Locks.cpp b/paddle/utils/arch/osx/Locks.cpp index 44bab7198d..b3ec454976 100644 --- a/paddle/utils/arch/osx/Locks.cpp +++ b/paddle/utils/arch/osx/Locks.cpp @@ -15,12 +15,9 @@ limitations under the License. */ #include "paddle/utils/Locks.h" #include "paddle/utils/Logging.h" #include +#include #include -#if MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12 -#include -#endif - namespace paddle { class SemaphorePrivate { @@ -55,12 +52,7 @@ void Semaphore::post() { class SpinLockPrivate { public: -#if MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12 - os_unfair_lock lock_; -#else - SpinLockPrivate(): lock_(OS_SPINLOCK_INIT) {} - OSSpinLock lock_; -#endif + std::atomic_flag lock_ = ATOMIC_FLAG_INIT; char padding_[64 - sizeof(lock_)]; // Padding to cache line size }; @@ -68,19 +60,11 @@ SpinLock::SpinLock(): m(new SpinLockPrivate()) {} SpinLock::~SpinLock() { delete m; } void SpinLock::lock() { -#if MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12 - os_unfair_lock_lock(&m->lock_); -#else - OSSpinLockLock(&m->lock_); -#endif + while (m->lock_.test_and_set(std::memory_order_acquire)) {} } void SpinLock::unlock() { -#if MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12 - os_unfair_lock_unlock(&m->lock_); -#else - OSSpinLockUnlock(&m->lock_); -#endif + m->lock_.clear(std::memory_order_release); } diff --git a/paddle/utils/tests/CMakeLists.txt b/paddle/utils/tests/CMakeLists.txt index 51f1889392..adf489fafe 100644 --- a/paddle/utils/tests/CMakeLists.txt +++ b/paddle/utils/tests/CMakeLists.txt @@ -4,6 +4,7 @@ add_simple_unittest(test_Thread) add_simple_unittest(test_StringUtils) add_simple_unittest(test_CustomStackTrace) add_simple_unittest(test_ThreadBarrier) +add_simple_unittest(test_SpinLock) add_executable( test_CustomStackTracePrint diff --git a/paddle/utils/tests/test_SpinLock.cpp b/paddle/utils/tests/test_SpinLock.cpp new file mode 100644 index 0000000000..ebc84e0f52 --- /dev/null +++ b/paddle/utils/tests/test_SpinLock.cpp @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/utils/Logging.h" +#include "paddle/utils/CommandLineParser.h" +#include "paddle/utils/Util.h" +#include "paddle/utils/Locks.h" + +P_DEFINE_int32(test_thread_num, 100, "testing thread number"); + +void testNormalImpl(size_t thread_num, const std::function + & callback) { + paddle::SpinLock mutex; + std::vector threads; + threads.reserve(thread_num); + + size_t count = 0; + for (size_t i = 0; i < thread_num; ++i) { + threads.emplace_back([&thread_num, &count, &mutex, &callback]{ + callback(thread_num, count, mutex); + }); + } + for (auto& thread : threads) { + thread.join(); + } + // Check whether all threads reach this point or not + CHECK_EQ(count, thread_num); +} + +TEST(ThreadSpinLock, normalTest) { + for (auto &thread_num : {10, 30, 50 , 100 , 300, 1000}) { + testNormalImpl(thread_num, [](size_t thread_num, + size_t& count, paddle::SpinLock& mutex){ + std::lock_guard lock(mutex); + ++count; + }); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + return RUN_ALL_TESTS(); +} -- GitLab