提交 017498bc 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

tensorflow: support usage of eigen thread pool

Use eigen ThreadPool instead of tensorflow one if TENSORFLOW_USE_EIGEN_THREADPOOL is defined. This will allow to switch to the new non-blocking ThreadPool.
Change: 119512280
上级 eb161ecd
......@@ -164,7 +164,6 @@ TEST_F(DataByExampleTest, VisitUnavailable) {
signal(&updated_data);
});
wait(&completed_visit);
EXPECT_FALSE(thread_pool.HasPendingClosures());
EXPECT_TRUE(errors::IsUnavailable(status));
}
......
......@@ -15,6 +15,16 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#ifdef TENSORFLOW_USE_EIGEN_THREADPOOL
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#else
#include <deque>
#include <thread>
#include <vector>
#endif
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
......@@ -24,26 +34,97 @@ limitations under the License.
namespace tensorflow {
namespace thread {
struct ThreadPool::Waiter {
condition_variable cv;
bool ready;
#ifdef TENSORFLOW_USE_EIGEN_THREADPOOL
struct EigenEnvironment {
typedef Thread EnvThread;
struct Task {
std::function<void()> f;
uint64 trace_id;
};
Env* const env_;
const ThreadOptions thread_options_;
const string name_;
EigenEnvironment(Env* env, const ThreadOptions& thread_options,
const string& name)
: env_(env), thread_options_(thread_options), name_(name) {}
EnvThread* CreateThread(std::function<void()> f) {
return env_->StartThread(thread_options_, name_, [=]() {
// Set the processor flag to flush denormals to zero
port::ScopedFlushDenormal flush;
f();
});
}
Task CreateTask(std::function<void()> f) {
uint64 id = 0;
if (port::Tracing::IsActive()) {
id = port::Tracing::UniqueId();
port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
id);
}
return Task{std::move(f), id};
}
void ExecuteTask(const Task& t) {
if (t.trace_id != 0) {
port::Tracing::ScopedActivity region(
port::Tracing::EventCategory::kRunClosure, t.trace_id);
t.f();
} else {
t.f();
}
}
};
struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
Impl(Env* env, const ThreadOptions& thread_options, const string& name,
int num_threads)
: Eigen::ThreadPoolTempl<EigenEnvironment>(
num_threads, EigenEnvironment(env, thread_options, name)) {}
};
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
: ThreadPool(env, ThreadOptions(), name, num_threads) {}
#else
ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
struct ThreadPool::Impl {
Impl(Env* env, const ThreadOptions& thread_options, const string& name,
int num_threads);
~Impl();
void Schedule(std::function<void()> fn);
private:
struct Waiter {
condition_variable cv;
bool ready;
};
struct Task {
std::function<void()> fn;
uint64 id;
};
void WorkerLoop();
const string name_;
mutex mu_;
std::vector<Thread*> threads_; // All threads
std::vector<Waiter*> waiters_; // Stack of waiting threads.
std::deque<Task> pending_; // Queue of pending work
};
ThreadPool::Impl::Impl(Env* env, const ThreadOptions& thread_options,
const string& name, int num_threads)
: name_(name) {
CHECK_GE(num_threads, 1);
string name_prefix = "tf_" + name_;
for (int i = 0; i < num_threads; i++) {
threads_.push_back(env->StartThread(thread_options, name_prefix,
[this]() { WorkerLoop(); }));
threads_.push_back(
env->StartThread(thread_options, name, [this]() { WorkerLoop(); }));
}
}
ThreadPool::~ThreadPool() {
ThreadPool::Impl::~Impl() {
{
// Wait for all work to get done.
mutex_lock l(mu_);
......@@ -66,13 +147,7 @@ ThreadPool::~ThreadPool() {
}
}
bool ThreadPool::HasPendingClosures() const {
mutex_lock l(mu_);
return pending_.size() != 0;
}
void ThreadPool::Schedule(std::function<void()> fn) {
CHECK(fn != nullptr);
void ThreadPool::Impl::Schedule(std::function<void()> fn) {
uint64 id = 0;
if (port::Tracing::IsActive()) {
id = port::Tracing::UniqueId();
......@@ -90,7 +165,7 @@ void ThreadPool::Schedule(std::function<void()> fn) {
}
}
void ThreadPool::WorkerLoop() {
void ThreadPool::Impl::WorkerLoop() {
// Set the processor flag to flush denormals to zero
port::ScopedFlushDenormal flush;
......@@ -107,22 +182,40 @@ void ThreadPool::WorkerLoop() {
}
}
// Pick up pending work
Item item = pending_.front();
Task t = pending_.front();
pending_.pop_front();
if (item.fn == nullptr) {
if (t.fn == nullptr) {
break;
}
mu_.unlock();
if (item.id != 0) {
if (t.id != 0) {
port::Tracing::ScopedActivity region(
port::Tracing::EventCategory::kRunClosure, item.id);
item.fn();
port::Tracing::EventCategory::kRunClosure, t.id);
t.fn();
} else {
item.fn();
t.fn();
}
mu_.lock();
}
}
#endif
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
: ThreadPool(env, ThreadOptions(), name, num_threads) {}
ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
const string& name, int num_threads) {
CHECK_GE(num_threads, 1);
impl_.reset(
new ThreadPool::Impl(env, thread_options, "tf_" + name, num_threads));
}
ThreadPool::~ThreadPool() {}
void ThreadPool::Schedule(std::function<void()> fn) {
CHECK(fn != nullptr);
impl_->Schedule(std::move(fn));
}
} // namespace thread
} // namespace tensorflow
......@@ -16,13 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
#include <deque>
#include <functional>
#include <thread>
#include <vector>
#include <memory>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
......@@ -45,28 +42,15 @@ class ThreadPool {
// Wait until all scheduled work has finished and then destroy the
// set of threads.
virtual ~ThreadPool();
~ThreadPool();
// Schedule fn() for execution in the pool of threads.
virtual void Schedule(std::function<void()> fn);
void Schedule(std::function<void()> fn);
virtual bool HasPendingClosures() const;
struct Impl;
private:
struct Waiter;
struct Item {
std::function<void()> fn;
uint64 id;
};
void WorkerLoop();
const string name_;
mutable mutex mu_;
std::vector<Thread*> threads_; // All threads
std::vector<Waiter*> waiters_; // Stack of waiting threads.
std::deque<Item> pending_; // Queue of pending work
std::unique_ptr<Impl> impl_;
TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
};
......
......@@ -18,6 +18,7 @@ limitations under the License.
#include <atomic>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册