未验证 提交 3ca7328f 编写于 作者: P PuQing 提交者: GitHub

[PHI decoupling] move "thread pool" from fluid to phi (#48075)

* move threadpool

fix cmake

* fix make
上级 468f8815
...@@ -225,10 +225,6 @@ cc_test( ...@@ -225,10 +225,6 @@ cc_test(
SRCS reader_test.cc SRCS reader_test.cc
DEPS reader) DEPS reader)
cc_library(
threadpool
SRCS threadpool.cc
DEPS enforce)
cc_test( cc_test(
threadpool_test threadpool_test
SRCS threadpool_test.cc SRCS threadpool_test.cc
......
...@@ -155,7 +155,7 @@ void Scope::DeleteScope(Scope* scope) const { ...@@ -155,7 +155,7 @@ void Scope::DeleteScope(Scope* scope) const {
if (FLAGS_benchmark || FLAGS_eager_delete_scope) { if (FLAGS_benchmark || FLAGS_eager_delete_scope) {
delete scope; delete scope;
} else { } else {
Async([scope] { delete scope; }); phi::Async([scope] { delete scope; });
} }
} }
} }
......
...@@ -14,138 +14,16 @@ limitations under the License. */ ...@@ -14,138 +14,16 @@ limitations under the License. */
#pragma once #pragma once
#include <condition_variable> // NOLINT #include "paddle/phi/core/threadpool.h"
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <mutex> // NOLINT
#include <queue>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct ExceptionHandler { using ExceptionHandler = phi::ExceptionHandler;
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
PADDLE_THROW(platform::errors::Fatal(
"The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception."
"The exception is:\n %s.",
ex->what()));
}
}
};
// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
class ThreadPool {
public:
explicit ThreadPool(int num_threads);
using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
// Returns the singleton of ThreadPool.
static ThreadPool* GetInstance();
~ThreadPool();
// Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call
// std::future::wait().
template <typename Callback>
std::future<void> Run(Callback fn) {
auto f = this->RunAndGetException(fn);
return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
}
template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) {
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try {
fn();
} catch (platform::EnforceNotMet& ex) {
return std::unique_ptr<platform::EnforceNotMet>(
new platform::EnforceNotMet(ex));
} catch (const std::exception& e) {
PADDLE_THROW(platform::errors::Fatal(
"Unexpected exception is catched in thread pool. All "
"throwable exception in Paddle should be an EnforceNotMet."
"The exception is:\n %s.",
e.what()));
}
return nullptr;
});
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
{
std::unique_lock<std::mutex> lock(mutex_);
if (!running_) {
PADDLE_THROW(platform::errors::Unavailable(
"Task is enqueued into stopped ThreadPool."));
}
tasks_.push(std::move(task));
}
scheduled_.notify_one();
return f;
}
private:
DISABLE_COPY_AND_ASSIGN(ThreadPool);
// The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue.
void TaskLoop();
// Init is called by GetInstance.
static void Init();
private:
static std::unique_ptr<ThreadPool> threadpool_;
static std::once_flag init_flag_;
std::vector<std::unique_ptr<std::thread>> threads_;
std::queue<Task> tasks_;
std::mutex mutex_;
bool running_;
std::condition_variable scheduled_;
};
class ThreadPoolIO : ThreadPool {
public:
static ThreadPool* GetInstanceIO();
static void InitIO();
private:
// NOTE: threadpool in base will be inhereted here.
static std::unique_ptr<ThreadPool> io_threadpool_;
static std::once_flag io_init_flag_;
};
// Run a function asynchronously. using ThreadPool = phi::ThreadPool;
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
template <typename Callback>
std::future<void> Async(Callback callback) {
return ThreadPool::GetInstance()->Run(callback);
}
template <typename Callback> using ThreadPoolIO = phi::ThreadPoolIO;
std::future<void> AsyncIO(Callback callback) {
return ThreadPoolIO::GetInstanceIO()->Run(callback);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,7 @@ void do_sum(std::vector<std::future<void>>* fs, ...@@ -26,7 +26,7 @@ void do_sum(std::vector<std::future<void>>* fs,
int cnt) { int cnt) {
for (int i = 0; i < cnt; ++i) { for (int i = 0; i < cnt; ++i) {
std::lock_guard<std::mutex> l(*mu); std::lock_guard<std::mutex> l(*mu);
fs->push_back(framework::Async([sum]() { sum->fetch_add(1); })); fs->push_back(phi::Async([sum]() { sum->fetch_add(1); }));
} }
} }
......
...@@ -43,6 +43,10 @@ cc_library( ...@@ -43,6 +43,10 @@ cc_library(
lod_utils lod_utils
SRCS lod_utils.cc SRCS lod_utils.cc
DEPS phi_enforce) DEPS phi_enforce)
cc_library(
threadpool
SRCS threadpool.cc
DEPS phi_enforce)
cc_library( cc_library(
dense_tensor dense_tensor
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/fluid/framework/threadpool.h" #include "paddle/phi/core/threadpool.h"
#include <thread> #include <thread>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/core/enforce.h"
DECLARE_int32(dist_threadpool_size);
DEFINE_int32(io_threadpool_size, DEFINE_int32(io_threadpool_size,
100, 100,
"number of threads used for doing IO, default 100"); "number of threads used for doing IO, default 100");
DECLARE_int32(dist_threadpool_size); namespace phi {
namespace paddle {
namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr); std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
std::once_flag ThreadPool::init_flag_; std::once_flag ThreadPool::init_flag_;
...@@ -47,7 +46,7 @@ void ThreadPool::Init() { ...@@ -47,7 +46,7 @@ void ThreadPool::Init() {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
num_threads, num_threads,
0, 0,
platform::errors::InvalidArgument("The number of threads is 0.")); phi::errors::InvalidArgument("The number of threads is 0."));
threadpool_.reset(new ThreadPool(num_threads)); threadpool_.reset(new ThreadPool(num_threads));
} }
} }
...@@ -88,8 +87,8 @@ void ThreadPool::TaskLoop() { ...@@ -88,8 +87,8 @@ void ThreadPool::TaskLoop() {
} }
if (tasks_.empty()) { if (tasks_.empty()) {
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(
"Current thread has no task to Run.")); phi::errors::Unavailable("Current thread has no task to Run."));
} }
// pop a task from the task queue // pop a task from the task queue
...@@ -115,6 +114,4 @@ void ThreadPoolIO::InitIO() { ...@@ -115,6 +114,4 @@ void ThreadPoolIO::InitIO() {
io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size)); io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
} }
} }
} // namespace phi
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable> // NOLINT
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <mutex> // NOLINT
#include <queue>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace phi {
struct ExceptionHandler {
mutable std::future<std::unique_ptr<phi::enforce::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<phi::enforce::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
PADDLE_THROW(phi::errors::Fatal(
"The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception."
"The exception is:\n %s.",
ex->what()));
}
}
};
// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
class ThreadPool {
 public:
  explicit ThreadPool(int num_threads);

  // A queued unit of work. Its result is nullptr on success, or a copy of
  // the EnforceNotMet raised by the user callback.
  using Task =
      std::packaged_task<std::unique_ptr<phi::enforce::EnforceNotMet>()>;

  // Returns the singleton of ThreadPool.
  static ThreadPool* GetInstance();

  ~ThreadPool();

  // Run pushes a function to the task queue and returns a std::future
  // object. To wait for the completion of the task, call
  // std::future::wait().
  //
  // The returned future is a deferred one: the ExceptionHandler attached to
  // it only executes when the caller waits on / gets the future, at which
  // point an EnforceNotMet reported by the task is re-thrown as a Fatal
  // error.
  template <typename Callback>
  std::future<void> Run(Callback fn) {
    // `fn` is already our own copy and is not used again here, so hand it
    // over by move instead of copying it a second time.
    auto f = this->RunAndGetException(std::move(fn));
    return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
  }

  // Enqueues `fn` and returns a future carrying the task outcome:
  //   - nullptr if `fn` returned normally;
  //   - a copy of the EnforceNotMet if `fn` threw one.
  // Any other std::exception thrown by `fn` is converted into a Fatal error
  // raised from inside the task body.
  //
  // Throws an Unavailable error if the pool is no longer running when the
  // task is enqueued.
  template <typename Callback>
  std::future<std::unique_ptr<phi::enforce::EnforceNotMet>> RunAndGetException(
      Callback fn) {
    Task task([fn]() -> std::unique_ptr<phi::enforce::EnforceNotMet> {
      try {
        fn();
      } catch (const phi::enforce::EnforceNotMet& ex) {
        // Copy the exception so it can travel through the future.
        return std::unique_ptr<phi::enforce::EnforceNotMet>(
            new phi::enforce::EnforceNotMet(ex));
      } catch (const std::exception& e) {
        PADDLE_THROW(phi::errors::Fatal(
            "Unexpected exception is catched in thread pool. All "
            "throwable exception in Paddle should be an EnforceNotMet."
            "The exception is:\n %s.",
            e.what()));
      }
      return nullptr;
    });
    std::future<std::unique_ptr<phi::enforce::EnforceNotMet>> f =
        task.get_future();
    {
      // Only the queue push needs the lock; no condition-variable wait
      // happens here, so a plain lock_guard is sufficient.
      std::lock_guard<std::mutex> lock(mutex_);
      if (!running_) {
        PADDLE_THROW(phi::errors::Unavailable(
            "Task is enqueued into stopped ThreadPool."));
      }
      tasks_.push(std::move(task));
    }
    // Signal after the lock has been released.
    scheduled_.notify_one();
    return f;
  }

 private:
  DISABLE_COPY_AND_ASSIGN(ThreadPool);

  // The constructor starts threads to run TaskLoop, which retrieves
  // and runs tasks from the queue.
  void TaskLoop();

  // Init is called by GetInstance.
  static void Init();

 private:
  static std::unique_ptr<ThreadPool> threadpool_;
  static std::once_flag init_flag_;

  std::vector<std::unique_ptr<std::thread>> threads_;

  std::queue<Task> tasks_;
  std::mutex mutex_;
  bool running_;
  std::condition_variable scheduled_;
};
// Static accessor for a second pool instance, obtained via GetInstanceIO()
// and kept separate from the ThreadPool::GetInstance() singleton.
//
// The base class is inherited privately (the default for `class`); callers
// only ever receive a plain ThreadPool* from GetInstanceIO().
class ThreadPoolIO : private ThreadPool {
 public:
  // Returns the IO pool as a ThreadPool pointer.
  static ThreadPool* GetInstanceIO();

  // Creates the IO pool instance.
  static void InitIO();

 private:
  // NOTE: threadpool in base will be inherited here.
  static std::unique_ptr<ThreadPool> io_threadpool_;
  static std::once_flag io_init_flag_;
};
// Run a function asynchronously on the ThreadPool singleton.
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
//
// Returns the std::future<void> produced by ThreadPool::Run for this
// callback.
template <typename Callback>
std::future<void> Async(Callback callback) {
  // `callback` is a by-value parameter that is not used again, so move it
  // into Run rather than paying for another copy.
  return ThreadPool::GetInstance()->Run(std::move(callback));
}
// Run a function asynchronously on the pool returned by
// ThreadPoolIO::GetInstanceIO(). The function must return void.
//
// Returns the std::future<void> produced by ThreadPool::Run for this
// callback.
template <typename Callback>
std::future<void> AsyncIO(Callback callback) {
  // `callback` is a by-value parameter that is not used again, so move it
  // into Run rather than paying for another copy.
  return ThreadPoolIO::GetInstanceIO()->Run(std::move(callback));
}
} // namespace phi
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
#include "paddle/phi/kernels/selected_rows/adam_kernel.h" #include "paddle/phi/kernels/selected_rows/adam_kernel.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/threadpool.h"
#include "paddle/phi/kernels/funcs/adam_functors.h" #include "paddle/phi/kernels/funcs/adam_functors.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
...@@ -201,12 +201,12 @@ void AdamDenseParamSparseGradKernel( ...@@ -201,12 +201,12 @@ void AdamDenseParamSparseGradKernel(
if (end > static_cast<int64_t>(param_row_count)) { if (end > static_cast<int64_t>(param_row_count)) {
end = static_cast<int64_t>(param_row_count); end = static_cast<int64_t>(param_row_count);
} }
fs.push_back(paddle::framework::Async([&functor, fs.push_back(phi::Async([&functor,
&row_id_to_grad_row_offset, &row_id_to_grad_row_offset,
&grad_data, &grad_data,
row_numel, row_numel,
start, start,
end]() { end]() {
for (int64_t row_id = start; row_id < end; ++row_id) { for (int64_t row_id = start; row_id < end; ++row_id) {
auto iter = row_id_to_grad_row_offset.find(row_id); auto iter = row_id_to_grad_row_offset.find(row_id);
if (iter != row_id_to_grad_row_offset.end()) { if (iter != row_id_to_grad_row_offset.end()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册