提交 295b3c80 编写于 作者: D Derek Murray 提交者: TensorFlower Gardener

Automated rollback of commit c9bdd393

PiperOrigin-RevId: 215607038
上级 3d76a830
...@@ -16,8 +16,10 @@ limitations under the License. ...@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/framework/variant_op_registry.h"
...@@ -25,11 +27,13 @@ limitations under the License. ...@@ -25,11 +27,13 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
......
...@@ -29,7 +29,6 @@ limitations under the License. ...@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -406,10 +405,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -406,10 +405,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) { EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_ = runner_thread_.reset(ctx->env()->StartThread(
MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); {}, "runner_thread",
runner_thread_->Schedule( std::bind(&Iterator::RunnerThread, this, ctx_copy)));
std::bind(&Iterator::RunnerThread, this, ctx_copy));
} }
} }
...@@ -662,7 +660,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -662,7 +660,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_; std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results. // Buffer for storing the (intermediate) batch results.
std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_); std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false; bool cancelled_ GUARDED_BY(*mu_) = false;
}; };
......
...@@ -18,7 +18,6 @@ limitations under the License. ...@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -127,10 +126,9 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { ...@@ -127,10 +126,9 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!optimize_thread_) { if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
optimize_thread_ = optimize_thread_.reset(ctx->env()->StartThread(
MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread"); {}, "optimize_thread",
optimize_thread_->Schedule( [this, new_ctx]() { OptimizeThread(new_ctx); }));
[this, new_ctx]() { OptimizeThread(new_ctx); });
} }
return Status::OK(); return Status::OK();
} }
...@@ -169,7 +167,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { ...@@ -169,7 +167,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
mutex mu_; mutex mu_;
condition_variable cond_var_; condition_variable cond_var_;
std::shared_ptr<model::Model> model_; std::shared_ptr<model::Model> model_;
std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_); std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
}; };
......
...@@ -26,7 +26,6 @@ limitations under the License. ...@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -482,10 +481,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ...@@ -482,10 +481,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads()); worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) { for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back( worker_threads_.emplace_back(ctx->env()->StartThread(
MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread")); {}, "worker_thread",
worker_threads_.back()->Schedule( [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
[this, new_ctx, i]() { WorkerThread(new_ctx, i); });
} }
} }
return Status::OK(); return Status::OK();
...@@ -582,10 +580,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ...@@ -582,10 +580,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
} }
workers_[i].SetInputs(s, std::move(args)); workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back( worker_threads_.emplace_back(ctx->env()->StartThread(
MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread")); {}, "worker_thread",
worker_threads_.back()->Schedule( [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
[this, new_ctx, i]() { WorkerThread(new_ctx, i); });
if (i < dataset()->cycle_length_) { if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i); interleave_indices_.push_back(i);
} else { } else {
...@@ -1050,8 +1047,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ...@@ -1050,8 +1047,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// The worker threads. This must be last to ensure the // The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated. // threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads. // TODO(b/65178177): Avoid allocating additional threads.
std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_ std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
GUARDED_BY(mu_);
}; };
const DatasetBase* const input_; const DatasetBase* const input_;
...@@ -1393,10 +1389,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { ...@@ -1393,10 +1389,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) { EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_ = runner_thread_.reset(ctx->env()->StartThread(
MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); {}, "runner_thread",
runner_thread_->Schedule( [this, new_ctx]() { RunnerThread(new_ctx); }));
[this, new_ctx]() { RunnerThread(new_ctx); });
} }
} }
...@@ -1650,7 +1645,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { ...@@ -1650,7 +1645,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(*mu_) = 0; int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_; std::unique_ptr<thread::ThreadPool> thread_pool_;
std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled. // Identifies whether background activity should be cancelled.
bool cancelled_ GUARDED_BY(*mu_) = false; bool cancelled_ GUARDED_BY(*mu_) = false;
......
...@@ -22,7 +22,6 @@ limitations under the License. ...@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -181,10 +180,9 @@ class ParallelMapIterator : public DatasetBaseIterator { ...@@ -181,10 +180,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) { EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_ = runner_thread_.reset(ctx->env()->StartThread(
MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread"); {}, "runner_thread",
runner_thread_->Schedule( std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
} }
} }
...@@ -332,7 +330,7 @@ class ParallelMapIterator : public DatasetBaseIterator { ...@@ -332,7 +330,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Buffer for storing the invocation results. // Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_); GUARDED_BY(*mu_);
std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_); std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false; bool cancelled_ GUARDED_BY(*mu_) = false;
}; };
......
...@@ -22,7 +22,6 @@ limitations under the License. ...@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
...@@ -257,11 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { ...@@ -257,11 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx) Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) { if (!prefetch_thread_) {
prefetch_thread_ =
MakeUnique<BackgroundWorker>(ctx->env(), "prefetch_thread");
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
prefetch_thread_->Schedule( prefetch_thread_.reset(ctx->env()->StartThread(
[this, new_ctx]() { PrefetchThread(new_ctx); }); {}, "prefetch_thread",
[this, new_ctx]() { PrefetchThread(new_ctx); }));
} }
return Status::OK(); return Status::OK();
} }
...@@ -365,7 +363,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { ...@@ -365,7 +363,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string prefix_end_; string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_); std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<BackgroundWorker> prefetch_thread_ GUARDED_BY(mu_); std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false;
bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
}; };
......
...@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel { ...@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
public: public:
explicit ToTFRecordOp(OpKernelConstruction* ctx) explicit ToTFRecordOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx), : AsyncOpKernel(ctx),
background_worker_( thread_pool_(new thread::ThreadPool(
ctx->env(), ctx->env(), ThreadOptions(),
strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) { strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())),
} 1 /* num_threads */, false /* low_latency_hint */)) {}
template <typename T> template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx, Status ParseScalarArgument(OpKernelContext* ctx,
...@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel { ...@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an // The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the // inter-op thread pool thread, so we issue the call from the
// owned thread pool. // owned thread pool.
background_worker_.Schedule([this, ctx, done]() { thread_pool_->Schedule([this, ctx, done]() {
string filename; string filename;
OP_REQUIRES_OK_ASYNC( OP_REQUIRES_OK_ASYNC(
ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done); ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done);
...@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel { ...@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
} }
private: private:
BackgroundWorker background_worker_; std::unique_ptr<thread::ThreadPool> thread_pool_;
}; };
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册