提交 c9bdd393 编写于 作者: Derek Murray 提交者: TensorFlower Gardener

[tf.data] Switch background threads to use `BackgroundWorker`.

PiperOrigin-RevId: 215579950
上级 6b0d1ec9
...@@ -16,10 +16,8 @@ limitations under the License. ...@@ -16,10 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_op_registry.h"
...@@ -27,13 +25,11 @@ limitations under the License. ...@@ -27,13 +25,11 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
......
...@@ -29,6 +29,7 @@ limitations under the License. ...@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -405,9 +406,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -405,9 +406,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) { EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread( runner_thread_ =
{}, "runner_thread", MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
std::bind(&Iterator::RunnerThread, this, ctx_copy))); runner_thread_->Schedule(
std::bind(&Iterator::RunnerThread, this, ctx_copy));
} }
} }
...@@ -660,7 +662,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -660,7 +662,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_; std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results. // Buffer for storing the (intermediate) batch results.
std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_); std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false; bool cancelled_ GUARDED_BY(*mu_) = false;
}; };
......
...@@ -18,6 +18,7 @@ limitations under the License. ...@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { ...@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!optimize_thread_) { if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
optimize_thread_.reset(ctx->env()->StartThread( optimize_thread_ =
{}, "optimize_thread", MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread");
[this, new_ctx]() { OptimizeThread(new_ctx); })); optimize_thread_->Schedule(
[this, new_ctx]() { OptimizeThread(new_ctx); });
} }
return Status::OK(); return Status::OK();
} }
...@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { ...@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
mutex mu_; mutex mu_;
condition_variable cond_var_; condition_variable cond_var_;
std::shared_ptr<model::Model> model_; std::shared_ptr<model::Model> model_;
std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_); std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
}; };
......
...@@ -26,6 +26,7 @@ limitations under the License. ...@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ...@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads()); worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) { for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread( worker_threads_.emplace_back(
{}, "worker_thread", MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
[this, new_ctx, i]() { WorkerThread(new_ctx, i); })); worker_threads_.back()->Schedule(
[this, new_ctx, i]() { WorkerThread(new_ctx, i); });
} }
} }
return Status::OK(); return Status::OK();
...@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ...@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
} }
workers_[i].SetInputs(s, std::move(args)); workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread( worker_threads_.emplace_back(
{}, "worker_thread", MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
[this, new_ctx, i]() { WorkerThread(new_ctx, i); })); worker_threads_.back()->Schedule(
[this, new_ctx, i]() { WorkerThread(new_ctx, i); });
if (i < dataset()->cycle_length_) { if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i); interleave_indices_.push_back(i);
} else { } else {
...@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ...@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// The worker threads. This must be last to ensure the // The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated. // threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads. // TODO(b/65178177): Avoid allocating additional threads.
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_); std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
GUARDED_BY(mu_);
}; };
const DatasetBase* const input_; const DatasetBase* const input_;
...@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { ...@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) { EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread( runner_thread_ =
{}, "runner_thread", MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
[this, new_ctx]() { RunnerThread(new_ctx); })); runner_thread_->Schedule(
[this, new_ctx]() { RunnerThread(new_ctx); });
} }
} }
...@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { ...@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(*mu_) = 0; int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_; std::unique_ptr<thread::ThreadPool> thread_pool_;
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled. // Identifies whether background activity should be cancelled.
bool cancelled_ GUARDED_BY(*mu_) = false; bool cancelled_ GUARDED_BY(*mu_) = false;
......
...@@ -22,6 +22,7 @@ limitations under the License. ...@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -180,9 +181,10 @@ class ParallelMapIterator : public DatasetBaseIterator { ...@@ -180,9 +181,10 @@ class ParallelMapIterator : public DatasetBaseIterator {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) { EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread( runner_thread_ =
{}, "runner_thread", MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); runner_thread_->Schedule(
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
} }
} }
...@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator { ...@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Buffer for storing the invocation results. // Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_); GUARDED_BY(*mu_);
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false; bool cancelled_ GUARDED_BY(*mu_) = false;
}; };
......
...@@ -22,6 +22,7 @@ limitations under the License. ...@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { ...@@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx) Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) { if (!prefetch_thread_) {
prefetch_thread_ =
MakeUnique<BackgroundWorker>(ctx->env(), "prefetch_thread");
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
prefetch_thread_.reset(ctx->env()->StartThread( prefetch_thread_->Schedule(
{}, "prefetch_thread", [this, new_ctx]() { PrefetchThread(new_ctx); });
[this, new_ctx]() { PrefetchThread(new_ctx); }));
} }
return Status::OK(); return Status::OK();
} }
...@@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { ...@@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string prefix_end_; string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_); std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); std::unique_ptr<BackgroundWorker> prefetch_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false;
bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
}; };
......
...@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel { ...@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
public: public:
explicit ToTFRecordOp(OpKernelConstruction* ctx) explicit ToTFRecordOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx), : AsyncOpKernel(ctx),
thread_pool_(new thread::ThreadPool( background_worker_(
ctx->env(), ThreadOptions(), ctx->env(),
strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())), strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) {
1 /* num_threads */, false /* low_latency_hint */)) {} }
template <typename T> template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx, Status ParseScalarArgument(OpKernelContext* ctx,
...@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel { ...@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an // The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the // inter-op thread pool thread, so we issue the call from the
// owned thread pool. // owned background worker.
thread_pool_->Schedule([this, ctx, done]() { background_worker_.Schedule([this, ctx, done]() {
string filename; string filename;
OP_REQUIRES_OK_ASYNC( OP_REQUIRES_OK_ASYNC(
ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done); ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done);
...@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel { ...@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
} }
private: private:
std::unique_ptr<thread::ThreadPool> thread_pool_; BackgroundWorker background_worker_;
}; };
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册