diff --git a/paddle/phi/backends/callback_manager.cc b/paddle/phi/backends/callback_manager.cc index e21e8502d8f8c43e7484982354c4ea69253a195f..4a958ef73bfc67d73bcf73f7d50d224beb6b8ae4 100644 --- a/paddle/phi/backends/callback_manager.cc +++ b/paddle/phi/backends/callback_manager.cc @@ -16,16 +16,18 @@ #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/enforce.h" +#include + namespace phi { CallbackManager::CallbackManager(stream::Stream *stream) - : stream_(stream), thread_pool_(1) {} + : stream_(stream), thread_pool_(new ::ThreadPool(1)) {} void CallbackManager::AddCallback(std::function callback) const { auto *callback_func = new std::function(std::move(callback)); auto *func = new std::function([this, callback_func] { std::lock_guard lock(mtx_); - last_future_ = thread_pool_.enqueue([callback_func] { + last_future_ = thread_pool_->enqueue([callback_func] { std::unique_ptr> releaser(callback_func); (*callback_func)(); }); diff --git a/paddle/phi/backends/callback_manager.h b/paddle/phi/backends/callback_manager.h index 359958b7c93e2c4041532a377f35836ca8ae89bc..2bb26745288dfebf7cb669e631b691c490fcbfd6 100644 --- a/paddle/phi/backends/callback_manager.h +++ b/paddle/phi/backends/callback_manager.h @@ -14,8 +14,6 @@ #pragma once -#include - #ifdef PADDLE_WITH_CUDA #include #include @@ -30,6 +28,8 @@ #include #include // NOLINT +class ThreadPool; + namespace phi { namespace stream { @@ -50,7 +50,7 @@ class CallbackManager { private: stream::Stream* stream_; - mutable ::ThreadPool thread_pool_; + mutable std::shared_ptr<::ThreadPool> thread_pool_; mutable std::mutex mtx_; mutable std::future last_future_; }; diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index b4964708dfb9797c75e6f69ccb8bae6853b424a9..8cc6e498068fa65d697f6f002bec17b075b42866 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -14,6 +14,8 @@ #pragma once #ifdef PADDLE_WITH_CUSTOM_DEVICE +#include + #include "paddle/phi/backends/event.h" #include "paddle/phi/backends/stream.h"