提交 9158b1b8 编写于 作者: D Derek Murray 提交者: TensorFlower Gardener

[tf.data] Move captured function instantiation to iterator initialization time.

Previously, a function instantiation error (e.g. in `Dataset.map()`) would lead
to an error in each GetNext() call that attempted to use the function. Moving this
to iterator instantiation time has the benefit that the error will be reported
once when the initialization op is executed, which has a more helpful stack
trace, since it should not be conflated with other potential op failures.

PiperOrigin-RevId: 209633511
上级 e28f9da8
......@@ -555,6 +555,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
next_handle_++;
}
}
if (options.create_kernels_eagerly) {
Item* item;
TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
}
return Status::OK();
}
......
......@@ -490,6 +490,11 @@ class FunctionLibraryRuntime {
// Instantiates the function using an executor of the given type. If empty,
// the default TensorFlow executor will be used.
string executor_type;
// If true, the runtime will attempt to create kernels for the function at
// instantiation time, rather than on the first run. This can be used to
// surface errors earlier.
bool create_kernels_eagerly = false;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
......
......@@ -172,31 +172,17 @@ class BorrowedArgsCallFrame : public CallFrameBase {
} // namespace
Status CapturedFunction::MaybeInstantiate(
IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
mutex_lock l(mu_);
Status CapturedFunction::GetHandle(IteratorContext* ctx,
FunctionLibraryRuntime::Handle* out_handle) {
tf_shared_lock l(mu_);
if (lib_ == nullptr) {
// The context's runtime will be used for all subsequent calls.
lib_ = ctx->lib();
DCHECK(f_handle_ == kInvalidHandle);
FunctionLibraryRuntime::InstantiateOptions inst_opts;
inst_opts.overlay_lib = ctx->function_library().get();
inst_opts.state_handle = std::to_string(random::New64());
TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
inst_opts, &f_handle_));
const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
if (fbody == nullptr) {
return errors::Internal("Failed to instantiate function body.");
}
ret_types_ = fbody->ret_types;
} else {
// TODO(mrry): Consider moving this under a shared lock, as it is
// the common case.
if (ctx->lib() != lib_) {
return errors::Internal(
"Captured function was called with a different "
"FunctionLibraryRuntime*, which is not permitted.");
}
return errors::Internal("Captured function \"", func_.name(),
"\" was called before it was instantiated.");
}
if (ctx->lib() != lib_) {
return errors::Internal("Captured function \"", func_.name(),
"\" was called with a different "
"FunctionLibraryRuntime*, which is not permitted.");
}
*out_handle = f_handle_;
return Status::OK();
......@@ -205,7 +191,7 @@ Status CapturedFunction::MaybeInstantiate(
Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
......@@ -242,7 +228,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
......@@ -277,9 +263,30 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
}
Status CapturedFunction::Instantiate(IteratorContext* ctx) {
FunctionLibraryRuntime::Handle unused_handle;
TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle));
mutex_lock l(mu_);
if (lib_ == nullptr) {
// The context's runtime will be used for all subsequent calls.
lib_ = ctx->lib();
DCHECK(f_handle_ == kInvalidHandle);
FunctionLibraryRuntime::InstantiateOptions inst_opts;
inst_opts.overlay_lib = ctx->function_library().get();
inst_opts.state_handle = std::to_string(random::New64());
inst_opts.create_kernels_eagerly = true;
Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
inst_opts, &f_handle_));
TF_RETURN_IF_ERROR(s);
const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
if (fbody == nullptr) {
return errors::Internal("Failed to instantiate function body.");
}
ret_types_ = fbody->ret_types;
} else {
if (ctx->lib() != lib_) {
return errors::Internal(
"Captured function was called with a different "
"FunctionLibraryRuntime*, which is not permitted.");
}
}
if (captured_runner_ == nullptr) {
captured_runner_ = *ctx->runner();
}
......@@ -343,7 +350,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
FunctionLibraryRuntime::Handle handle;
Status s = MaybeInstantiate(ctx, &handle);
Status s = GetHandle(ctx, &handle);
if (!s.ok()) {
done(s);
return;
......
......@@ -116,8 +116,8 @@ class CapturedFunction {
CapturedFunction(const NameAttrList& func,
std::vector<Tensor> captured_inputs);
Status MaybeInstantiate(IteratorContext* ctx,
FunctionLibraryRuntime::Handle* out_handle);
Status GetHandle(IteratorContext* ctx,
FunctionLibraryRuntime::Handle* out_handle);
mutex mu_;
const NameAttrList func_;
......
......@@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<FilterDatasetBase>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
......@@ -80,20 +81,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
}
}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(
dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (!initialized_) {
TF_RETURN_IF_ERROR(
dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
// Explicitly instantiate the finalize function here so that
// we can invoke it in the destructor.
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
initialized_ = true;
}
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
......@@ -121,7 +122,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
......
......@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/captured_function.h"
namespace tensorflow {
......
......@@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(
dataset()->captured_finalize_func_->Instantiate(ctx));
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(
dataset()->captured_window_size_func_->Instantiate(ctx));
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
args_list_(params.dataset->cycle_length_) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
......
......@@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase {
bool* end_of_sequence) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
if (lib_ != nullptr) {
ctx->set_lib(lib_);
}
CHECK_NOTNULL(lib_);
ctx->set_lib(lib_);
return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
} else {
return errors::FailedPrecondition(
......@@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase {
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
std::unique_ptr<IteratorBase> iterator;
IteratorContext iter_ctx(ctx);
iter_ctx.set_lib(lib);
TF_RETURN_IF_ERROR(
dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
......@@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase {
return lib_def_;
}
FunctionLibraryRuntime* function_library_runtime() { return lib_; }
// Transfers ownership of iterator to this. This method is thread-safe.
Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
if (iterator) {
......@@ -612,8 +615,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
core::ScopedUnref unref(iterator_resource);
std::unique_ptr<IteratorBase> iterator;
IteratorContext iter_ctx(ctx);
iter_ctx.set_lib(iterator_resource->function_library_runtime());
OP_REQUIRES_OK(
ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
......@@ -837,8 +842,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
std::unique_ptr<IteratorBase> iter;
IteratorContext iter_ctx(ctx);
iter_ctx.set_lib(lib);
TF_RETURN_IF_ERROR(
dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
......
......@@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -127,7 +127,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -142,8 +142,15 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
&input_impl_);
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
params.stats_aggregator_getter = ctx->stats_aggregator_getter();
params.lib = ctx->lib();
params.function_library = dataset()->flib_def_;
params.allocator_getter = ctx->allocator_getter();
return dataset()->optimized_input_->MakeIterator(
IteratorContext(params), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -251,7 +251,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
// It is implemented so that it matches the deterministic interleave
......
......@@ -88,6 +88,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
auto init_func = [this](IteratorContext* ctx) {
return captured_func_->Instantiate(ctx);
};
auto map_func = [this](IteratorContext* ctx,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
......@@ -97,7 +101,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return NewParallelMapIterator(
{this, strings::StrCat(prefix, "::ParallelMap")}, input_,
std::move(map_func), num_parallel_calls_);
std::move(init_func), std::move(map_func), num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
......
......@@ -26,10 +26,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
const typename DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls)
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
: DatasetBaseIterator(params),
input_dataset_(input_dataset),
init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
num_parallel_calls_(num_parallel_calls) {}
......@@ -50,7 +52,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
TF_RETURN_IF_ERROR(init_func_(ctx));
}
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
......@@ -285,6 +292,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
......@@ -311,8 +319,18 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls) {
return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
params, input_dataset, std::move(map_func), num_parallel_calls));
return NewParallelMapIterator(params, input_dataset, nullptr,
std::move(map_func), num_parallel_calls);
}
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
return std::unique_ptr<IteratorBase>(
new ParallelMapIterator(params, input_dataset, std::move(init_func),
std::move(map_func), num_parallel_calls));
}
} // namespace tensorflow
......@@ -33,7 +33,15 @@ using ParallelMapIteratorFunction =
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
// `input_dataset` using the given degree of parallelism.
// `input_dataset` using the given degree of parallelism. `init_func` (if
// specified) will be executed when the iterator is initialized (see
// `IteratorBase::Initialize()`) and enables the user to specify error checking
// logic that can fail early.
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls);
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
......
......@@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
class ForeverIterator : public DatasetIterator<Dataset> {
public:
explicit ForeverIterator(const Params& params)
: DatasetIterator<Dataset>(params), input_impl_(nullptr) {}
: DatasetIterator<Dataset>(params),
input_impl_(nullptr),
first_call_(true) {}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
do {
bool first_call = false;
if (!input_impl_) {
first_call = true;
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
if (first_call_ && *end_of_sequence) {
// If the first call to GetNext() fails because the end
// of sequence has been reached, we terminate the
// iteration immediately. (Otherwise, this iterator
// would loop infinitely and never produce a value.)
input_impl_.reset();
return Status::OK();
}
first_call_ = false;
if (!*end_of_sequence) {
return s;
} else {
input_impl_.reset();
if (first_call) {
// If the first call to GetNext() fails because the end
// of sequence has been reached, we terminate the
// iteration immediately. (Otherwise, this iterator
// would loop infinitely and never produce a value.)
return Status::OK();
}
first_call_ = true;
}
} while (true);
}
......@@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (input_impl_)
if (!first_call_)
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
else
TF_RETURN_IF_ERROR(
......@@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
if (reader->Contains(full_name("uninitialized"))) {
input_impl_.reset();
first_call_ = true;
} else {
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
first_call_ = false;
}
return Status::OK();
}
......@@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
bool first_call_ GUARDED_BY(mu_);
};
const int64 count_;
......
......@@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_(params.dataset->initial_state_) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
......
......@@ -24,6 +24,7 @@ import warnings
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -31,6 +32,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
......@@ -673,6 +675,36 @@ class MapDatasetTest(test.TestCase):
r"Dataset.map\(\): None."):
_ = dataset.map(lambda x: None)
def testBrokenFunctionErrorOnInitialization(self):
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
def broken_function(_):
"""A function deliberately designed to fail on instantiation."""
value = []
tensor_value = attr_value_pb2.AttrValue()
tensor_value.tensor.CopyFrom(
tensor_util.make_tensor_proto(
value, dtype=dtypes.float32, shape=[0], verify_shape=False))
dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)
# Create a "Const" op with a `tf.float32` value and a `tf.int32` type
# attr.
const_tensor = ops.get_default_graph().create_op(
"Const", [], [dtypes.int32],
attrs={
"value": tensor_value,
"dtype": dtype_value
},
name="BrokenConst").outputs[0]
return const_tensor
dataset = dataset.map(broken_function)
iterator = dataset.make_initializable_iterator()
with self.test_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
sess.run(iterator.initializer)
class MapDatasetBenchmark(test.Benchmark):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册