未验证 提交 d3a43477 编写于 作者: R ronnywang 提交者: GitHub

[custom runtime] clear headers (#40845)

上级 01339433
...@@ -16,16 +16,18 @@ ...@@ -16,16 +16,18 @@
#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include <ThreadPool.h>
namespace phi { namespace phi {
CallbackManager::CallbackManager(stream::Stream *stream) CallbackManager::CallbackManager(stream::Stream *stream)
: stream_(stream), thread_pool_(1) {} : stream_(stream), thread_pool_(new ::ThreadPool(1)) {}
void CallbackManager::AddCallback(std::function<void()> callback) const { void CallbackManager::AddCallback(std::function<void()> callback) const {
auto *callback_func = new std::function<void()>(std::move(callback)); auto *callback_func = new std::function<void()>(std::move(callback));
auto *func = new std::function<void()>([this, callback_func] { auto *func = new std::function<void()>([this, callback_func] {
std::lock_guard<std::mutex> lock(mtx_); std::lock_guard<std::mutex> lock(mtx_);
last_future_ = thread_pool_.enqueue([callback_func] { last_future_ = thread_pool_->enqueue([callback_func] {
std::unique_ptr<std::function<void()>> releaser(callback_func); std::unique_ptr<std::function<void()>> releaser(callback_func);
(*callback_func)(); (*callback_func)();
}); });
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#pragma once #pragma once
#include <ThreadPool.h>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -30,6 +28,8 @@ ...@@ -30,6 +28,8 @@
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
class ThreadPool;
namespace phi { namespace phi {
namespace stream { namespace stream {
...@@ -50,7 +50,7 @@ class CallbackManager { ...@@ -50,7 +50,7 @@ class CallbackManager {
private: private:
stream::Stream* stream_; stream::Stream* stream_;
mutable ::ThreadPool thread_pool_; mutable std::shared_ptr<::ThreadPool> thread_pool_;
mutable std::mutex mtx_; mutable std::mutex mtx_;
mutable std::future<void> last_future_; mutable std::future<void> last_future_;
}; };
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
#include <vector>
#include "paddle/phi/backends/event.h" #include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h" #include "paddle/phi/backends/stream.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册