diff --git a/paddle/utils/tests/CMakeLists.txt b/paddle/utils/tests/CMakeLists.txt index 5b31cd393dd1fc319be0ae9a5811f5637617e08d..51f18893928455308a2331fa5061f9849019432c 100644 --- a/paddle/utils/tests/CMakeLists.txt +++ b/paddle/utils/tests/CMakeLists.txt @@ -3,6 +3,7 @@ add_simple_unittest(test_Logging) add_simple_unittest(test_Thread) add_simple_unittest(test_StringUtils) add_simple_unittest(test_CustomStackTrace) +add_simple_unittest(test_ThreadBarrier) add_executable( test_CustomStackTracePrint diff --git a/paddle/utils/tests/test_ThreadBarrier.cpp b/paddle/utils/tests/test_ThreadBarrier.cpp new file mode 100644 index 0000000000000000000000000000000000000000..241cdda7bd1c90335e85c7a559afd0c84c255009 --- /dev/null +++ b/paddle/utils/tests/test_ThreadBarrier.cpp @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/utils/Logging.h" +#include "paddle/utils/CommandLineParser.h" +#include "paddle/utils/Util.h" +#include "paddle/utils/Locks.h" + +P_DEFINE_int32(test_thread_num, 100, "testing thread number"); + +void testNormalImpl(size_t thread_num, + const std::function&, + paddle::ThreadBarrier&)>& callback) { + std::mutex mutex; + std::set tids; + paddle::ThreadBarrier barrier(thread_num); + + std::vector threads; + threads.reserve(thread_num); + for (int32_t i = 0; i < thread_num; ++i) { + threads.emplace_back([&thread_num, &mutex, + &tids, &barrier, &callback]{ + callback(thread_num, mutex, tids, barrier); + }); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ThreadBarrier, normalTest) { + for (auto &thread_num : {10, 30, 50 , 100 , 300, 1000}) { + testNormalImpl(thread_num, + [](size_t thread_num, std::mutex& mutex, + std::set& tids, + paddle::ThreadBarrier& barrier){ + { + std::lock_guard guard(mutex); + tids.insert(std::this_thread::get_id()); + } + barrier.wait(); + // Check whether all threads reach this point or not + CHECK_EQ(tids.size(), thread_num); + }); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + paddle::initMain(argc, argv); + return RUN_ALL_TESTS(); +}