diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index f790b4bf07fe3870ab946d49c9586092f4f24939..fdf8aebc3abcc56cc58865699e0977c8850f8229 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -140,7 +140,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
       IteratorContext::Params params(ctx);
       {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
+        TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
         params.model = model_;
         int64 now_nanos = EnvTime::NowNanos();
         RecordInput(now_nanos);
@@ -175,18 +175,16 @@ class ModelDatasetOp::Dataset : public DatasetBase {
     }

    private:
-    Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+    Status EnsureModelThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (!model_thread_) {
-        std::shared_ptr<IteratorContext> new_ctx =
-            std::make_shared<IteratorContext>(*ctx);
-        model_thread_ = ctx->StartThread(
-            "tf_data_model", [this, new_ctx]() { ModelThread(new_ctx); });
+        model_thread_ =
+            ctx->StartThread("tf_data_model", [this]() { ModelThread(); });
       }
       return Status::OK();
     }

-    void ModelThread(const std::shared_ptr<IteratorContext>& ctx) {
+    void ModelThread() {
       int64 last_optimization_ms = 0;
       int64 optimization_period_ms = 10;
       int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py
index 568c01646de4f2867e2b3896d5f828842886dd1e..a65a9d79340d626e56aff1e3b61e30732afdce98 100644
--- a/tensorflow/python/data/experimental/ops/distribute.py
+++ b/tensorflow/python/data/experimental/ops/distribute.py
@@ -330,6 +330,14 @@ def replicate(dataset, devices):
     return datasets

   with ops.colocate_with(dataset._variant_tensor):
+    # We apply options before replicating the dataset because options are
+    # currently not automatically preserved through dataset serialization and
+    # thus an explicit application of options here is needed to avoid losing
+    # `dataset` options.
+    #
+    # TODO(b/147325552): Propagating options to C++ upon their setting would
+    # allow us to preserve the options across both variant and GraphDef based
+    # serialization, avoiding the need to explicitly apply options here.
     dataset = dataset._apply_options()
     policy = dataset.options().experimental_external_state_policy
     if policy is None: