// Copyright (c) 2022 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/utils/multi_threading.h" #include #include #include #include #include #include "paddle/cinn/utils/string.h" namespace cinn { namespace utils { SequenceDispatcher::SequenceDispatcher(int begin, int end, int step) : end_(end), step_(step), index_(begin) { CHECK_LE(begin, end) << StringFormat("begin[%d] > end[%d]", begin, end); CHECK_GT(step, 0) << "step is less than 0"; } int SequenceDispatcher::Next() const { int idx = -1; if (index_ >= end_ || (idx = index_.fetch_add(step_)) >= end_) { return -1; } return idx; } void parallel_run(const WorkerFuncType& fn, JobDispatcher&& dispatcher, int num_threads) { if (num_threads == -1 || num_threads > std::thread::hardware_concurrency()) { num_threads = std::thread::hardware_concurrency(); } CHECK_GT(num_threads, 0) << "num_threads should be greater than 0"; // worker function of a thread auto worker = [&fn, &dispatcher](int tid) -> int { int index = -1, counter = 0; while ((index = dispatcher.Next()) != -1) { VLOG(5) << "Thread-" << tid << " process at index: " << index; fn(index); ++counter; } return counter; }; std::vector> futures; std::vector threads; // The first thread runs inplace, and other `num_threads - 1` threads launched // with std::future to run asynchronously if (num_threads > 1) { futures.reserve(num_threads - 1); threads.reserve(num_threads - 1); for (int tid = 1; tid < num_threads; ++tid) { std::packaged_task task(worker); futures.emplace_back(task.get_future()); threads.emplace_back(std::move(task), tid); } } // wait results and catch exceptions try { int tid = 0; int counter = worker(tid); VLOG(4) << "Thread-0 process " << counter << " tasks."; for (auto&& future : futures) { counter = future.get(); ++tid; VLOG(4) << "Thread-" << tid << " process " << counter << " tasks."; } } catch (const std::exception& e) { LOG(FATAL) << "parallel_run incurs error: " << e.what(); } // join threads for (auto&& thread : threads) { thread.join(); } } } // namespace utils } // namespace cinn