diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 46c3dcca12a2751831119d1cce2fcefbd4c58ceb..6fd4095d0d28f7478b59ca6a8ab2afb02e4161f7 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -225,10 +225,6 @@ cc_test( SRCS reader_test.cc DEPS reader) -cc_library( - threadpool - SRCS threadpool.cc - DEPS enforce) cc_test( threadpool_test SRCS threadpool_test.cc diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 88ffeb59503d3566944c0ebcc3e861a4ba96cf97..a54110add67a8a583013efafb2704e8ba2e78ec0 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -155,7 +155,7 @@ void Scope::DeleteScope(Scope* scope) const { if (FLAGS_benchmark || FLAGS_eager_delete_scope) { delete scope; } else { - Async([scope] { delete scope; }); + phi::Async([scope] { delete scope; }); } } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 7fecf07475b1457bf6d19279493517c55f057194..ac0ee6b9333127cb3631a77738ccc68decda5f4e 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -14,138 +14,16 @@ limitations under the License. */ #pragma once -#include // NOLINT -#include -#include // NOLINT -#include -#include // NOLINT -#include -#include // NOLINT -#include -#include - -#include "glog/logging.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/phi/core/threadpool.h" namespace paddle { namespace framework { -struct ExceptionHandler { - mutable std::future> future_; - explicit ExceptionHandler( - std::future>&& f) - : future_(std::move(f)) {} - void operator()() const { - auto ex = this->future_.get(); - if (ex != nullptr) { - PADDLE_THROW(platform::errors::Fatal( - "The exception is thrown inside the thread pool. You " - "should use RunAndGetException to handle the exception." - "The exception is:\n %s.", - ex->what())); - } - } -}; - -// ThreadPool maintains a queue of tasks, and runs them using a fixed -// number of threads. -class ThreadPool { - public: - explicit ThreadPool(int num_threads); - - using Task = std::packaged_task()>; - - // Returns the singleton of ThreadPool. - static ThreadPool* GetInstance(); - - ~ThreadPool(); - - // Run pushes a function to the task queue and returns a std::future - // object. To wait for the completion of the task, call - // std::future::wait(). - template - std::future Run(Callback fn) { - auto f = this->RunAndGetException(fn); - return std::async(std::launch::deferred, ExceptionHandler(std::move(f))); - } - - template - std::future> RunAndGetException( - Callback fn) { - Task task([fn]() -> std::unique_ptr { - try { - fn(); - } catch (platform::EnforceNotMet& ex) { - return std::unique_ptr( - new platform::EnforceNotMet(ex)); - } catch (const std::exception& e) { - PADDLE_THROW(platform::errors::Fatal( - "Unexpected exception is catched in thread pool. All " - "throwable exception in Paddle should be an EnforceNotMet." - "The exception is:\n %s.", - e.what())); - } - return nullptr; - }); - std::future> f = task.get_future(); - { - std::unique_lock lock(mutex_); - if (!running_) { - PADDLE_THROW(platform::errors::Unavailable( - "Task is enqueued into stopped ThreadPool.")); - } - tasks_.push(std::move(task)); - } - scheduled_.notify_one(); - return f; - } - - private: - DISABLE_COPY_AND_ASSIGN(ThreadPool); - - // The constructor starts threads to run TaskLoop, which retrieves - // and runs tasks from the queue. - void TaskLoop(); - - // Init is called by GetInstance. - static void Init(); - - private: - static std::unique_ptr threadpool_; - static std::once_flag init_flag_; - - std::vector> threads_; - - std::queue tasks_; - std::mutex mutex_; - bool running_; - std::condition_variable scheduled_; -}; - -class ThreadPoolIO : ThreadPool { - public: - static ThreadPool* GetInstanceIO(); - static void InitIO(); - - private: - // NOTE: threadpool in base will be inhereted here. - static std::unique_ptr io_threadpool_; - static std::once_flag io_init_flag_; -}; +using ExceptionHandler = phi::ExceptionHandler; -// Run a function asynchronously. -// NOTE: The function must return void. If the function need to return a value, -// you can use lambda to capture a value pointer. -template -std::future Async(Callback callback) { - return ThreadPool::GetInstance()->Run(callback); -} +using ThreadPool = phi::ThreadPool; -template -std::future AsyncIO(Callback callback) { - return ThreadPoolIO::GetInstanceIO()->Run(callback); -} +using ThreadPoolIO = phi::ThreadPoolIO; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 59fc31c485f3b1f96e1a34cbfc848235cbcea602..25155a0f7e87c2520111f198407c58a662253ef7 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -26,7 +26,7 @@ void do_sum(std::vector>* fs, int cnt) { for (int i = 0; i < cnt; ++i) { std::lock_guard l(*mu); - fs->push_back(framework::Async([sum]() { sum->fetch_add(1); })); + fs->push_back(phi::Async([sum]() { sum->fetch_add(1); })); } } diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index d34f5f658b87b27d5c8ef54034536b687f674d4c..90f5d38bfc93b0559e20a5aadfda8e7affa4cd4b 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -43,6 +43,10 @@ cc_library( lod_utils SRCS lod_utils.cc DEPS phi_enforce) +cc_library( + threadpool + SRCS threadpool.cc + DEPS phi_enforce) cc_library( dense_tensor diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/phi/core/threadpool.cc similarity index 71% rename from paddle/fluid/framework/threadpool.cc rename to paddle/phi/core/threadpool.cc index 1a1e017b59e097abd1826627fafe43eefe0a1c95..db1f3091031fc156ae4fcea96864b9c395ea6aef 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/phi/core/threadpool.cc @@ -1,33 +1,32 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/fluid/framework/threadpool.h" +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/threadpool.h" #include #include "gflags/gflags.h" #include "glog/logging.h" -#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/enforce.h" +DECLARE_int32(dist_threadpool_size); DEFINE_int32(io_threadpool_size, 100, "number of threads used for doing IO, default 100"); -DECLARE_int32(dist_threadpool_size); +namespace phi { -namespace paddle { -namespace framework { std::unique_ptr ThreadPool::threadpool_(nullptr); std::once_flag ThreadPool::init_flag_; @@ -47,7 +46,7 @@ void ThreadPool::Init() { PADDLE_ENFORCE_GT( num_threads, 0, - platform::errors::InvalidArgument("The number of threads is 0.")); + phi::errors::InvalidArgument("The number of threads is 0.")); threadpool_.reset(new ThreadPool(num_threads)); } } @@ -88,8 +87,8 @@ void ThreadPool::TaskLoop() { } if (tasks_.empty()) { - PADDLE_THROW(platform::errors::Unavailable( - "Current thread has no task to Run.")); + PADDLE_THROW( + phi::errors::Unavailable("Current thread has no task to Run.")); } // pop a task from the task queue @@ -115,6 +114,4 @@ void ThreadPoolIO::InitIO() { io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size)); } } - -} // namespace framework -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/core/threadpool.h b/paddle/phi/core/threadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..b45991f9a7f825fdc3d98de27e26bfd2972d4190 --- /dev/null +++ b/paddle/phi/core/threadpool.h @@ -0,0 +1,150 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include // NOLINT +#include +#include // NOLINT +#include +#include // NOLINT +#include +#include // NOLINT +#include +#include + +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace phi { + +struct ExceptionHandler { + mutable std::future> future_; + explicit ExceptionHandler( + std::future>&& f) + : future_(std::move(f)) {} + void operator()() const { + auto ex = this->future_.get(); + if (ex != nullptr) { + PADDLE_THROW(phi::errors::Fatal( + "The exception is thrown inside the thread pool. You " + "should use RunAndGetException to handle the exception." + "The exception is:\n %s.", + ex->what())); + } + } +}; + +// ThreadPool maintains a queue of tasks, and runs them using a fixed +// number of threads. +class ThreadPool { + public: + explicit ThreadPool(int num_threads); + + using Task = + std::packaged_task()>; + + // Returns the singleton of ThreadPool. + static ThreadPool* GetInstance(); + + ~ThreadPool(); + + // Run pushes a function to the task queue and returns a std::future + // object. To wait for the completion of the task, call + // std::future::wait(). + template + std::future Run(Callback fn) { + auto f = this->RunAndGetException(fn); + return std::async(std::launch::deferred, ExceptionHandler(std::move(f))); + } + + template + std::future> RunAndGetException( + Callback fn) { + Task task([fn]() -> std::unique_ptr { + try { + fn(); + } catch (phi::enforce::EnforceNotMet& ex) { + return std::unique_ptr( + new phi::enforce::EnforceNotMet(ex)); + } catch (const std::exception& e) { + PADDLE_THROW(phi::errors::Fatal( + "Unexpected exception is catched in thread pool. All " + "throwable exception in Paddle should be an EnforceNotMet." + "The exception is:\n %s.", + e.what())); + } + return nullptr; + }); + std::future> f = + task.get_future(); + { + std::unique_lock lock(mutex_); + if (!running_) { + PADDLE_THROW(phi::errors::Unavailable( + "Task is enqueued into stopped ThreadPool.")); + } + tasks_.push(std::move(task)); + } + scheduled_.notify_one(); + return f; + } + + private: + DISABLE_COPY_AND_ASSIGN(ThreadPool); + + // The constructor starts threads to run TaskLoop, which retrieves + // and runs tasks from the queue. + void TaskLoop(); + + // Init is called by GetInstance. + static void Init(); + + private: + static std::unique_ptr threadpool_; + static std::once_flag init_flag_; + + std::vector> threads_; + + std::queue tasks_; + std::mutex mutex_; + bool running_; + std::condition_variable scheduled_; +}; + +class ThreadPoolIO : ThreadPool { + public: + static ThreadPool* GetInstanceIO(); + static void InitIO(); + + private: + // NOTE: threadpool in base will be inhereted here. + static std::unique_ptr io_threadpool_; + static std::once_flag io_init_flag_; +}; + +// Run a function asynchronously. +// NOTE: The function must return void. If the function need to return a value, +// you can use lambda to capture a value pointer. +template +std::future Async(Callback callback) { + return ThreadPool::GetInstance()->Run(callback); +} + +template +std::future AsyncIO(Callback callback) { + return ThreadPoolIO::GetInstanceIO()->Run(callback); +} + +} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc index 2e7fe555feffc5deb29d8dcada7e9d4974ee5e12..de8b4eae4f6600879d1fc0eb19a19ab365d21bb6 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc @@ -15,10 +15,10 @@ #include "paddle/phi/kernels/selected_rows/adam_kernel.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/threadpool.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/core/threadpool.h" #include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h" @@ -201,12 +201,12 @@ void AdamDenseParamSparseGradKernel( if (end > static_cast(param_row_count)) { end = static_cast(param_row_count); } - fs.push_back(paddle::framework::Async([&functor, - &row_id_to_grad_row_offset, - &grad_data, - row_numel, - start, - end]() { + fs.push_back(phi::Async([&functor, + &row_id_to_grad_row_offset, + &grad_data, + row_numel, + start, + end]() { for (int64_t row_id = start; row_id < end; ++row_id) { auto iter = row_id_to_grad_row_offset.find(row_id); if (iter != row_id_to_grad_row_offset.end()) {