From 32b55573292b40452f99616db40f2f8c8d7809da Mon Sep 17 00:00:00 2001 From: liaogang Date: Fri, 23 Sep 2016 21:43:53 +0800 Subject: [PATCH] Add thread Barrier unit test --- paddle/utils/tests/CMakeLists.txt | 1 + paddle/utils/tests/test_ThreadBarrier.cpp | 68 +++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 paddle/utils/tests/test_ThreadBarrier.cpp diff --git a/paddle/utils/tests/CMakeLists.txt b/paddle/utils/tests/CMakeLists.txt index 5b31cd393dd..51f18893928 100644 --- a/paddle/utils/tests/CMakeLists.txt +++ b/paddle/utils/tests/CMakeLists.txt @@ -3,6 +3,7 @@ add_simple_unittest(test_Logging) add_simple_unittest(test_Thread) add_simple_unittest(test_StringUtils) add_simple_unittest(test_CustomStackTrace) +add_simple_unittest(test_ThreadBarrier) add_executable( test_CustomStackTracePrint diff --git a/paddle/utils/tests/test_ThreadBarrier.cpp b/paddle/utils/tests/test_ThreadBarrier.cpp new file mode 100644 index 00000000000..241cdda7bd1 --- /dev/null +++ b/paddle/utils/tests/test_ThreadBarrier.cpp @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/utils/Logging.h" +#include "paddle/utils/CommandLineParser.h" +#include "paddle/utils/Util.h" +#include "paddle/utils/Locks.h" + +P_DEFINE_int32(test_thread_num, 100, "testing thread number"); + +void testNormalImpl(size_t thread_num, + const std::function&, + paddle::ThreadBarrier&)>& callback) { + std::mutex mutex; + std::set tids; + paddle::ThreadBarrier barrier(thread_num); + + std::vector threads; + threads.reserve(thread_num); + for (int32_t i = 0; i < thread_num; ++i) { + threads.emplace_back([&thread_num, &mutex, + &tids, &barrier, &callback]{ + callback(thread_num, mutex, tids, barrier); + }); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ThreadBarrier, normalTest) { + for (auto &thread_num : {10, 30, 50 , 100 , 300, 1000}) { + testNormalImpl(thread_num, + [](size_t thread_num, std::mutex& mutex, + std::set& tids, + paddle::ThreadBarrier& barrier){ + { + std::lock_guard guard(mutex); + tids.insert(std::this_thread::get_id()); + } + barrier.wait(); + // Check whether all threads reach this point or not + CHECK_EQ(tids.size(), thread_num); + }); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + return RUN_ALL_TESTS(); +} -- GitLab