diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 21d386b728fc0f0105733ce257446e23b6031a19..98d5c5b50092196e042c51260aadc7b0432b7e52 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -217,11 +217,11 @@ class GraphIOTest(test.TestCase): parse_example_queue_name = "%s/fifo_queue" % name op_nodes = test_util.assert_ops_in_graph({ file_names_name: "Const", - file_name_queue_name: "FIFOQueue", - "%s/read/TFRecordReader" % name: "TFRecordReader", - example_queue_name: "FIFOQueue", - parse_example_queue_name: "FIFOQueue", - name: "QueueDequeueMany" + file_name_queue_name: "FIFOQueueV2", + "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2", + example_queue_name: "FIFOQueueV2", + parse_example_queue_name: "FIFOQueueV2", + name: "QueueDequeueManyV2" }, g) self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0]) self.assertEqual(queue_capacity, @@ -250,10 +250,10 @@ class GraphIOTest(test.TestCase): example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ file_names_name: "Const", - file_name_queue_name: "FIFOQueue", - "%s/read/TFRecordReader" % name: "TFRecordReader", - example_queue_name: "RandomShuffleQueue", - name: "QueueDequeueUpTo", + file_name_queue_name: "FIFOQueueV2", + "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2", + example_queue_name: "RandomShuffleQueueV2", + name: "QueueDequeueUpToV2", file_name_queue_limit_name: "VariableV2" }, g) self.assertEqual( @@ -281,10 +281,10 @@ class GraphIOTest(test.TestCase): example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ file_names_name: "Const", - file_name_queue_name: "FIFOQueue", - "%s/read/TFRecordReader" % name: "TFRecordReader", - example_queue_name: "RandomShuffleQueue", - name: "QueueDequeueMany" + file_name_queue_name: "FIFOQueueV2", + "%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2", + example_queue_name: "RandomShuffleQueueV2", + name: "QueueDequeueManyV2" }, g) self.assertEqual( set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0])) @@ -427,10 +427,10 @@ class GraphIOTest(test.TestCase): example_queue_name = "%s/fifo_queue" % name test_util.assert_ops_in_graph({ file_names_name: "Const", - file_name_queue_name: "FIFOQueue", - "%s/read/TextLineReader" % name: "TextLineReader", - example_queue_name: "FIFOQueue", - name: "QueueDequeueUpTo" + file_name_queue_name: "FIFOQueueV2", + "%s/read/TextLineReaderV2" % name: "TextLineReaderV2", + example_queue_name: "FIFOQueueV2", + name: "QueueDequeueUpToV2" }, g) self.assertAllEqual(session.run(inputs), [b"ABC"]) @@ -473,10 +473,10 @@ class GraphIOTest(test.TestCase): example_queue_name = "%s/fifo_queue" % name worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name test_util.assert_ops_in_graph({ - "%s/read/TextLineReader" % name: "TextLineReader", - example_queue_name: "FIFOQueue", - worker_file_name_queue_name: "FIFOQueue", - name: "QueueDequeueUpTo" + "%s/read/TextLineReaderV2" % name: "TextLineReaderV2", + example_queue_name: "FIFOQueueV2", + worker_file_name_queue_name: "FIFOQueueV2", + name: "QueueDequeueUpToV2" }, g) self.assertAllEqual(session.run(inputs), [b"ABC"]) diff --git a/tensorflow/contrib/training/python/training/feeder_test.py b/tensorflow/contrib/training/python/training/feeder_test.py index 96b06118dedf8783b14b12f58a33f0de282963a9..e83743e4800bcb5221b57af1ce3412c696d10ab2 100644 --- a/tensorflow/contrib/training/python/training/feeder_test.py +++ b/tensorflow/contrib/training/python/training/feeder_test.py @@ -282,7 +282,8 @@ class FeederTest(test.TestCase): op_types_by_scope_and_device[scope][dev][op.type] += 1 - expected_ops = collections.Counter({'QueueEnqueue': 1, 'FIFOQueue': 1}) + expected_ops = collections.Counter( + {'QueueEnqueueV2': 1, 'FIFOQueueV2': 1}) expected_enq_devices = [('replica_0', [ '/job:consumer/replica:0/device:cpu:0', '/job:consumer/replica:1/device:cpu:0', diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c4023d2ced91ed97fbad5649112e08e0f83d226a..4b50386e940ad20d2bd0e96cd40bb050f676b351 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -267,6 +267,24 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { return Status::OK(); } +Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { + int start, stop; + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + const TensorValue& value((*params_->inputs)[start]); + if (value.is_ref()) { + *dtype = MakeRefType(value->dtype()); + } else { + *dtype = value->dtype(); + } + return Status::OK(); +} + Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 7318a2dc7d97e63029572eee54e6658e8ed0b1f2..75ad4bb7fc55f33d4f592a89799418f3343a7b1b 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -568,6 +568,7 @@ class OpKernelContext { int num_inputs() const { return params_->inputs->size(); } DataType input_dtype(int index) const; + Status input_dtype(StringPiece name, DataType* dtype) const; int num_outputs() const { return outputs_.size(); } DataType expected_output_dtype(int index) const; diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index b4556c9272d61d5f7e3e6336f5aa5ec9f11c865d..1c561899159e42269b55c81f3838325207623f03 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -334,6 +334,36 @@ TEST_F(OpKernelTest, SaveTempTrue) { delete params.device; } +TEST_F(OpKernelTest, InputDtype) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.record_tensor_accesses = false; + params.device = new DummyDevice(env, params.record_tensor_accesses); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + Tensor a(DT_FLOAT, TensorShape({})); + Tensor b(DT_INT32, TensorShape({})); + Tensor c(DT_UINT8, TensorShape({})); + gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b), + TensorValue(&c)}; + params.inputs = &inputs; + OpKernelContext* ctx = new OpKernelContext(¶ms); + + DataType dtype; + EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok()); + ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok()); + EXPECT_EQ(dtype, DT_FLOAT); + ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok()); + EXPECT_EQ(dtype, DT_INT32); + delete ctx; + delete params.device; +} + class OpKernelBuilderTest : public ::testing::Test { protected: // Each attr is described by a "name|type|value". diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index a1053669b75254657368958079bf6cd252df9656..e8d31c6f9e647524a5614015a5f51283bf9bb9bb 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -368,6 +368,13 @@ Status ResourceMgr::Delete(const string& container, const string& name) { template Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, T** resource) { + DataType dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype)); + if (dtype == DT_RESOURCE) { + const Tensor* handle; + TF_RETURN_IF_ERROR(ctx->input(input_name, &handle)); + return LookupResource(ctx, handle->scalar()(), resource); + } string container; string shared_name; { @@ -479,7 +486,7 @@ template void ResourceHandleOp::Compute(OpKernelContext* ctx) { Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - output->flat()(0) = + output->scalar()() = MakeResourceHandle(ctx, container_, name_); } diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h index 95299c596f193956d719d3acb8b029e7e10ee84c..de65657a9e59ac5e6496fd5d29b03ca06be9738d 100644 --- a/tensorflow/core/framework/resource_op_kernel.h +++ b/tensorflow/core/framework/resource_op_kernel.h @@ -94,7 +94,15 @@ class ResourceOpKernel : public OpKernel { h(1) = cinfo_.name(); resource_ = resource; } - context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); + if (context->expected_output_dtype(0) == DT_RESOURCE) { + Tensor* handle; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &handle)); + handle->scalar()() = + MakeResourceHandle(context, cinfo_.container(), cinfo_.name()); + } else { + context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); + } } protected: diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index e61a14d246fa4f05ff8767d15e2c81873b2f7e13..9acf8cbd2db6a7191a4445fc3b8095db5a38adba 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -347,7 +347,10 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { - TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue")); + if (!MatchesNodeDefOp(node_def, "FIFOQueue").ok() && + !MatchesNodeDefOp(node_def, "FIFOQueueV2").ok()) { + return errors::InvalidArgument("Expected FIFOQueue, found ", node_def.op()); + } TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def)); diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc index 31df66425407e951829572364fea313421c3293a..b35bdbb2f01e0e02b5d81f49817f24870bc086b6 100644 --- a/tensorflow/core/kernels/fifo_queue_op.cc +++ b/tensorflow/core/kernels/fifo_queue_op.cc @@ -58,5 +58,6 @@ class FIFOQueueOp : public TypedQueueOp { }; REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp); +REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc index 23f57b0edfbcf7298ecc086142e7991fe82c6eb3..9a43dd6d08cefda57b6c0c15d084013797d87205 100644 --- a/tensorflow/core/kernels/fixed_length_record_reader_op.cc +++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc @@ -121,5 +121,7 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel { REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU), FixedLengthRecordReaderOp); +REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU), + FixedLengthRecordReaderOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/identity_reader_op.cc b/tensorflow/core/kernels/identity_reader_op.cc index a239bf5d6d34df5efa9792eef88424b93188fc3b..02e4987abb4d6935152995fb98891629df2d835d 100644 --- a/tensorflow/core/kernels/identity_reader_op.cc +++ b/tensorflow/core/kernels/identity_reader_op.cc @@ -68,5 +68,7 @@ class IdentityReaderOp : public ReaderOpKernel { REGISTER_KERNEL_BUILDER(Name("IdentityReader").Device(DEVICE_CPU), IdentityReaderOp); +REGISTER_KERNEL_BUILDER(Name("IdentityReaderV2").Device(DEVICE_CPU), + IdentityReaderOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index fd1597047cc6bb5ed46e83902069e51828bd809c..07902cd48bc4b88794ec0b949c0a919af8368f39 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -276,7 +276,11 @@ Status PaddingFIFOQueue::CompatibleNodeDefShapes( } Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) { - TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "PaddingFIFOQueue")); + if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() && + !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) { + return errors::InvalidArgument("Expected PaddingFIFOQueue, found ", + node_def.op()); + } TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def)); diff --git a/tensorflow/core/kernels/padding_fifo_queue_op.cc b/tensorflow/core/kernels/padding_fifo_queue_op.cc index b87b2b90b8ef7d94c01fa51fd9c18db605ada6a4..0c96eaa57b75a64bc6e3ac16402da22cd2954fd7 100644 --- a/tensorflow/core/kernels/padding_fifo_queue_op.cc +++ b/tensorflow/core/kernels/padding_fifo_queue_op.cc @@ -67,5 +67,7 @@ class PaddingFIFOQueueOp : public TypedQueueOp { REGISTER_KERNEL_BUILDER(Name("PaddingFIFOQueue").Device(DEVICE_CPU), PaddingFIFOQueueOp); +REGISTER_KERNEL_BUILDER(Name("PaddingFIFOQueueV2").Device(DEVICE_CPU), + PaddingFIFOQueueOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/priority_queue.cc b/tensorflow/core/kernels/priority_queue.cc index 095026be2c98702c8ecc334ab3af66a2051bb295..85749a2954600e1659df148b45e53ccfa708840d 100644 --- a/tensorflow/core/kernels/priority_queue.cc +++ b/tensorflow/core/kernels/priority_queue.cc @@ -377,7 +377,11 @@ void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) { - TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "PriorityQueue")); + if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() && + !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) { + return errors::InvalidArgument("Expected PriorityQueue, found ", + node_def.op()); + } TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def)); TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def)); diff --git a/tensorflow/core/kernels/priority_queue_op.cc b/tensorflow/core/kernels/priority_queue_op.cc index eb4b4f289f0ec2db71531a7d668905e383910167..efa275dfdda45a8c03d2b9c748a8d54f1891655d 100644 --- a/tensorflow/core/kernels/priority_queue_op.cc +++ b/tensorflow/core/kernels/priority_queue_op.cc @@ -64,5 +64,7 @@ class PriorityQueueOp : public TypedQueueOp { REGISTER_KERNEL_BUILDER(Name("PriorityQueue").Device(DEVICE_CPU), PriorityQueueOp); +REGISTER_KERNEL_BUILDER(Name("PriorityQueueV2").Device(DEVICE_CPU), + PriorityQueueOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index cca0d8c38c6bafe090f3af4b2b4289f18a4f36d1..301d1420a438ce104e6cd7095c0991b5ba3099bd 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -33,8 +33,13 @@ class QueueOpKernel : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { QueueInterface* queue; - OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), - callback); + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), + callback); + } ComputeAsync(ctx, queue, [callback, queue]() { queue->Unref(); callback(); @@ -77,7 +82,12 @@ class EnqueueOp : public QueueAccessOpKernel { protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { - DataTypeVector expected_inputs = {DT_STRING_REF}; + DataTypeVector expected_inputs; + if (ctx->input_dtype(0) == DT_RESOURCE) { + expected_inputs.push_back(DT_RESOURCE); + } else { + expected_inputs.push_back(DT_STRING_REF); + } for (DataType dt : queue->component_dtypes()) { expected_inputs.push_back(dt); } @@ -101,6 +111,7 @@ class EnqueueOp : public QueueAccessOpKernel { }; REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp); +REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp); // Defines an EnqueueManyOp, the execution of which slices each // component of a tuple of tensors along the 0th dimension, and @@ -123,7 +134,12 @@ class EnqueueManyOp : public QueueAccessOpKernel { protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { - DataTypeVector expected_inputs = {DT_STRING_REF}; + DataTypeVector expected_inputs; + if (ctx->input_dtype(0) == DT_RESOURCE) { + expected_inputs.push_back(DT_RESOURCE); + } else { + expected_inputs.push_back(DT_STRING_REF); + } for (DataType dt : queue->component_dtypes()) { expected_inputs.push_back(dt); } @@ -150,6 +166,8 @@ class EnqueueManyOp : public QueueAccessOpKernel { REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU), EnqueueManyOp); +REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU), + EnqueueManyOp); // Defines a DequeueOp, the execution of which dequeues a tuple of // tensors from the given Queue. @@ -166,9 +184,15 @@ class DequeueOp : public QueueAccessOpKernel { protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { - OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), - callback); + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), + callback); + } queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { if (!ctx->status().ok()) { @@ -192,6 +216,7 @@ class DequeueOp : public QueueAccessOpKernel { }; REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp); +REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp); // Defines a DequeueManyOp, the execution of which concatenates the // requested number of elements from the given Queue along the 0th @@ -220,9 +245,17 @@ class DequeueManyOp : public QueueAccessOpKernel { num_elements, " < 0 elements"), callback); - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - queue->component_dtypes()), - callback); + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_RESOURCE, DT_INT32}, + queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + } queue->TryDequeueMany( num_elements, ctx, false /* allow_small_batch */, @@ -250,6 +283,8 @@ class DequeueManyOp : public QueueAccessOpKernel { REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), DequeueManyOp); +REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU), + DequeueManyOp); // Defines a DequeueUpToOp, the execution of which concatenates the // requested number of elements from the given Queue along the 0th @@ -296,9 +331,17 @@ class DequeueUpToOp : public QueueAccessOpKernel { num_elements, " < 0 elements"), callback); - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - queue->component_dtypes()), - callback); + if (ctx->input_dtype(0) == DT_RESOURCE) { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_RESOURCE, DT_INT32}, + queue->component_dtypes()), + callback); + } else { + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + } queue->TryDequeueMany( num_elements, ctx, true /* allow_small_batch */, @@ -326,6 +369,8 @@ class DequeueUpToOp : public QueueAccessOpKernel { REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU), DequeueUpToOp); +REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU), + DequeueUpToOp); // Defines a QueueCloseOp, which closes the given Queue. Closing a // Queue signals that no more elements will be enqueued in it. @@ -351,6 +396,7 @@ class QueueCloseOp : public QueueOpKernel { }; REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp); +REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp); // Defines a QueueSizeOp, which computes the number of elements in the // given Queue, and emits it as an output tensor. @@ -377,5 +423,28 @@ class QueueSizeOp : public QueueOpKernel { }; REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp); +REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp); + +class FakeQueueOp : public OpKernel { + public: + explicit FakeQueueOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_STRING, TensorShape({2}), + &handle_, nullptr)); + } + + void Compute(OpKernelContext* context) { + ResourceHandle ref = context->input(0).flat()(0); + handle_.AccessTensor(context)->flat()(0) = ref.container(); + handle_.AccessTensor(context)->flat()(1) = ref.name(); + context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); + } + + private: + mutex mu_; + PersistentTensor handle_; +}; + +REGISTER_KERNEL_BUILDER(Name("FakeQueue").Device(DEVICE_CPU), FakeQueueOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index 064d8b9c748aa0e2548532b1d9792f3d3050359e..a973bc2b1c621cc2bd0a6f1fdf1162d64c396946 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -425,7 +425,11 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { - TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue")); + if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() && + !MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) { + return errors::InvalidArgument("Expected RandomShuffleQueue, found ", + node_def.op()); + } TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_)); int32 min_after_dequeue = -1; @@ -497,5 +501,7 @@ class RandomShuffleQueueOp : public TypedQueueOp { REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU), RandomShuffleQueueOp); +REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU), + RandomShuffleQueueOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc index bb8e35cc089dac25b51b8c514be53d52e5a7a348..4b949ef82b9541fba03e462d0e756ed92c1183da 100644 --- a/tensorflow/core/kernels/reader_ops.cc +++ b/tensorflow/core/kernels/reader_ops.cc @@ -97,6 +97,7 @@ class ReaderReadOp : public ReaderVerbAsyncOpKernel { }; REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp); +REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp); class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel { public: @@ -149,6 +150,8 @@ class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel { REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU), ReaderReadUpToOp); +REGISTER_KERNEL_BUILDER(Name("ReaderReadUpToV2").Device(DEVICE_CPU), + ReaderReadUpToOp); class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel { public: @@ -165,6 +168,8 @@ class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel { REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU), ReaderNumRecordsProducedOp); +REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProducedV2").Device(DEVICE_CPU), + ReaderNumRecordsProducedOp); class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel { public: @@ -181,6 +186,9 @@ class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel { REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU), ReaderNumWorkUnitsCompletedOp); +REGISTER_KERNEL_BUILDER( + Name("ReaderNumWorkUnitsCompletedV2").Device(DEVICE_CPU), + ReaderNumWorkUnitsCompletedOp); class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel { public: @@ -198,6 +206,8 @@ class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel { REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU), ReaderSerializeStateOp); +REGISTER_KERNEL_BUILDER(Name("ReaderSerializeStateV2").Device(DEVICE_CPU), + ReaderSerializeStateOp); class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel { public: @@ -217,6 +227,8 @@ class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel { REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU), ReaderRestoreStateOp); +REGISTER_KERNEL_BUILDER(Name("ReaderRestoreStateV2").Device(DEVICE_CPU), + ReaderRestoreStateOp); class ReaderResetOp : public ReaderVerbSyncOpKernel { public: @@ -229,5 +241,7 @@ class ReaderResetOp : public ReaderVerbSyncOpKernel { }; REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp); +REGISTER_KERNEL_BUILDER(Name("ReaderResetV2").Device(DEVICE_CPU), + ReaderResetOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/text_line_reader_op.cc b/tensorflow/core/kernels/text_line_reader_op.cc index 4dc70d0cce29a669f372959a485cb1c9049090aa..ffa647d8ef99040933cb481fa36d13d58c5e7e67 100644 --- a/tensorflow/core/kernels/text_line_reader_op.cc +++ b/tensorflow/core/kernels/text_line_reader_op.cc @@ -111,5 +111,7 @@ class TextLineReaderOp : public ReaderOpKernel { REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU), TextLineReaderOp); +REGISTER_KERNEL_BUILDER(Name("TextLineReaderV2").Device(DEVICE_CPU), + TextLineReaderOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc index e169498fd3d34e3f1677180b2d493f280321b9bc..efadd9b7e93b4727f6390f380edcb5290fc8b09f 100644 --- a/tensorflow/core/kernels/tf_record_reader_op.cc +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -97,5 +97,7 @@ class TFRecordReaderOp : public ReaderOpKernel { REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU), TFRecordReaderOp); +REGISTER_KERNEL_BUILDER(Name("TFRecordReaderV2").Device(DEVICE_CPU), + TFRecordReaderOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index 538e3bbc9eb72af2fb4621d2ca9a1d6459e00f06..5851fe0a12eed8a3091f1b993265aafa1f3aa6de 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -96,6 +96,8 @@ class WholeFileReaderOp : public ReaderOpKernel { REGISTER_KERNEL_BUILDER(Name("WholeFileReader").Device(DEVICE_CPU), WholeFileReaderOp); +REGISTER_KERNEL_BUILDER(Name("WholeFileReaderV2").Device(DEVICE_CPU), + WholeFileReaderOp); class ReadFileOp : public OpKernel { public: diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index c46902995beb4d483fc730a1dfa24753a8b4ac02..ea24a0a16f2c1ad924b9180867eac7e15534320a 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -251,6 +251,41 @@ shared_name: If non-empty, this queue will be shared under the given name across multiple sessions. )doc"); +REGISTER_OP("RandomShuffleQueueV2") + .Output("handle: resource") + .Attr("component_types: list(type) >= 1") + .Attr("shapes: list(shape) >= 0 = []") + .Attr("capacity: int = -1") + .Attr("min_after_dequeue: int = 0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A queue that randomizes the order of elements. + +handle: The handle to the queue. +component_types: The type of each component in a value. +shapes: The shape of each component in a value. The length of this attr must + be either 0 or the same as the length of component_types. If the length of + this attr is 0, the shapes of queue elements are not constrained, and + only one element may be dequeued at a time. +capacity: The upper bound on the number of elements in this queue. + Negative numbers mean no limit. +min_after_dequeue: Dequeue will block unless there would be this + many elements after the dequeue or the queue is closed. This + ensures a minimum level of mixing of elements. +seed: If either seed or seed2 is set to be non-zero, the random number + generator is seeded by the given seed. Otherwise, a random seed is used. +seed2: A second seed to avoid seed collision. +container: If non-empty, this queue is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this queue will be shared under the given name + across multiple sessions. +)doc"); + REGISTER_OP("FIFOQueue") .Output("handle: Ref(string)") .Attr("component_types: list(type) >= 1") @@ -277,6 +312,32 @@ shared_name: If non-empty, this queue will be shared under the given name across multiple sessions. )doc"); +REGISTER_OP("FIFOQueueV2") + .Output("handle: resource") + .Attr("component_types: list(type) >= 1") + .Attr("shapes: list(shape) >= 0 = []") + .Attr("capacity: int = -1") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A queue that produces elements in first-in first-out order. + +handle: The handle to the queue. +component_types: The type of each component in a value. +shapes: The shape of each component in a value. The length of this attr must + be either 0 or the same as the length of component_types. If the length of + this attr is 0, the shapes of queue elements are not constrained, and + only one element may be dequeued at a time. +capacity: The upper bound on the number of elements in this queue. + Negative numbers mean no limit. +container: If non-empty, this queue is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this queue will be shared under the given name + across multiple sessions. +)doc"); + REGISTER_OP("PaddingFIFOQueue") .Output("handle: Ref(string)") .Attr("component_types: list(type) >= 1") @@ -311,6 +372,40 @@ shared_name: If non-empty, this queue will be shared under the given name across multiple sessions. )doc"); +REGISTER_OP("PaddingFIFOQueueV2") + .Output("handle: resource") + .Attr("component_types: list(type) >= 1") + .Attr("shapes: list(shape) >= 0 = []") + .Attr("capacity: int = -1") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A queue that produces elements in first-in first-out order. + +Variable-size shapes are allowed by setting the corresponding shape dimensions +to 0 in the shape attr. In this case DequeueMany will pad up to the maximum +size of any given element in the minibatch. See below for details. + +handle: The handle to the queue. +component_types: The type of each component in a value. +shapes: The shape of each component in a value. The length of this attr must + be either 0 or the same as the length of component_types. + Shapes of fixed rank but variable size are allowed by setting + any shape dimension to -1. In this case, the inputs' shape may vary along + the given dimension, and DequeueMany will pad the given dimension with + zeros up to the maximum shape of all elements in the given batch. + If the length of this attr is 0, different queue elements may have + different ranks and shapes, but only one element may be dequeued at a time. +capacity: The upper bound on the number of elements in this queue. + Negative numbers mean no limit. +container: If non-empty, this queue is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this queue will be shared under the given name + across multiple sessions. +)doc"); + REGISTER_OP("PriorityQueue") .Output("handle: Ref(string)") .Attr("component_types: list(type) >= 0 = []") @@ -343,6 +438,45 @@ shared_name: If non-empty, this queue will be shared under the given name across multiple sessions. )doc"); +REGISTER_OP("PriorityQueueV2") + .Output("handle: resource") + .Attr("component_types: list(type) >= 0 = []") + .Attr("shapes: list(shape) >= 0") + .Attr("capacity: int = -1") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A queue that produces elements sorted by the first component value. + +Note that the PriorityQueue requires the first component of any element +to be a scalar int64, in addition to the other elements declared by +component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue +and DequeueMany) on a PriorityQueue will all require (resp. output) one extra +entry in their input (resp. output) lists. + +handle: The handle to the queue. +component_types: The type of each component in a value. +shapes: The shape of each component in a value. The length of this attr must + be either 0 or the same as the length of component_types. If the length of + this attr is 0, the shapes of queue elements are not constrained, and + only one element may be dequeued at a time. +capacity: The upper bound on the number of elements in this queue. + Negative numbers mean no limit. +container: If non-empty, this queue is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this queue will be shared under the given name + across multiple sessions. +)doc"); + +REGISTER_OP("FakeQueue") + .Input("resource: resource") + .Output("handle: Ref(string)") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc("Deprecated. Do not use."); + REGISTER_OP("QueueEnqueue") .Input("handle: Ref(string)") .Input("components: Tcomponents") @@ -365,6 +499,28 @@ timeout_ms: If the queue is full, this operation will block for up to Note: This option is not supported yet. )doc"); +REGISTER_OP("QueueEnqueueV2") + .Input("handle: resource") + .Input("components: Tcomponents") + .Attr("Tcomponents: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Enqueues a tuple of one or more tensors in the given queue. + +The components input has k elements, which correspond to the components of +tuples stored in the given queue. + +N.B. If the queue is full, this operation will block until the given +element has been enqueued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +components: One or more tensors from which the enqueued tensors should be taken. +timeout_ms: If the queue is full, this operation will block for up to + timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + REGISTER_OP("QueueEnqueueMany") .Input("handle: Ref(string)") .Input("components: Tcomponents") @@ -392,6 +548,33 @@ timeout_ms: If the queue is too full, this operation will block for up Note: This option is not supported yet. )doc"); +REGISTER_OP("QueueEnqueueManyV2") + .Input("handle: resource") + .Input("components: Tcomponents") + .Attr("Tcomponents: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Enqueues zero or more tuples of one or more tensors in the given queue. + +This operation slices each component tensor along the 0th dimension to +make multiple queue elements. All of the tuple components must have the +same size in the 0th dimension. + +The components input has k elements, which correspond to the components of +tuples stored in the given queue. + +N.B. If the queue is full, this operation will block until the given +elements have been enqueued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +components: One or more tensors from which the enqueued tensors should + be taken. +timeout_ms: If the queue is too full, this operation will block for up + to timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + REGISTER_OP("QueueDequeue") .Input("handle: Ref(string)") .Output("components: component_types") @@ -416,6 +599,30 @@ timeout_ms: If the queue is empty, this operation will block for up to Note: This option is not supported yet. )doc"); +REGISTER_OP("QueueDequeueV2") + .Input("handle: resource") + .Output("components: component_types") + .Attr("component_types: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Dequeues a tuple of one or more tensors from the given queue. + +This operation has k outputs, where k is the number of components +in the tuples stored in the given queue, and output i is the ith +component of the dequeued tuple. + +N.B. If the queue is empty, this operation will block until an element +has been dequeued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +components: One or more tensors that were dequeued as a tuple. +component_types: The type of each component in a tuple. +timeout_ms: If the queue is empty, this operation will block for up to + timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + REGISTER_OP("QueueDequeueMany") .Input("handle: Ref(string)") .Input("n: int32") @@ -449,6 +656,39 @@ timeout_ms: If the queue has fewer than n elements, this operation Note: This option is not supported yet. )doc"); +REGISTER_OP("QueueDequeueManyV2") + .Input("handle: resource") + .Input("n: int32") + .Output("components: component_types") + .Attr("component_types: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Dequeues n tuples of one or more tensors from the given queue. + +If the queue is closed and there are fewer than n elements, then an +OutOfRange error is returned. + +This operation concatenates queue-element component tensors along the +0th dimension to make a single component tensor. All of the components +in the dequeued tuple will have size n in the 0th dimension. + +This operation has k outputs, where k is the number of components in +the tuples stored in the given queue, and output i is the ith +component of the dequeued tuple. + +N.B. If the queue is empty, this operation will block until n elements +have been dequeued (or 'timeout_ms' elapses, if specified). + +handle: The handle to a queue. +n: The number of tuples to dequeue. +components: One or more tensors that were dequeued as a tuple. +component_types: The type of each component in a tuple. +timeout_ms: If the queue has fewer than n elements, this operation + will block for up to timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + REGISTER_OP("QueueDequeueUpTo") .Input("handle: Ref(string)") .Input("n: int32") @@ -486,6 +726,43 @@ timeout_ms: If the queue has fewer than n elements, this operation Note: This option is not supported yet. )doc"); +REGISTER_OP("QueueDequeueUpToV2") + .Input("handle: resource") + .Input("n: int32") + .Output("components: component_types") + .Attr("component_types: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Dequeues n tuples of one or more tensors from the given queue. + +This operation is not supported by all queues. If a queue does not support +DequeueUpTo, then an Unimplemented error is returned. + +If the queue is closed and there are more than 0 but less than n elements +remaining, then instead of returning an OutOfRange error like +QueueDequeueMany, less than `n` elements are returned immediately. If the queue +is closed and there are 0 elements left in the queue, then an OutOfRange +error is returned just like in QueueDequeueMany. Otherwise the behavior +is identical to QueueDequeueMany: + +This operation concatenates queue-element component tensors along the +0th dimension to make a single component tensor. All of the components +in the dequeued tuple will have size n in the 0th dimension. + +This operation has k outputs, where k is the number of components in +the tuples stored in the given queue, and output i is the ith +component of the dequeued tuple. + +handle: The handle to a queue. +n: The number of tuples to dequeue. +components: One or more tensors that were dequeued as a tuple. +component_types: The type of each component in a tuple. +timeout_ms: If the queue has fewer than n elements, this operation + will block for up to timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + REGISTER_OP("QueueClose") .Input("handle: Ref(string)") .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) @@ -504,6 +781,24 @@ cancel_pending_enqueues: If true, all pending enqueue requests that are blocked on the given queue will be cancelled. )doc"); +REGISTER_OP("QueueCloseV2") + .Input("handle: resource") + .SetShapeFn(shape_inference::NoOutputs) + .Attr("cancel_pending_enqueues: bool = false") + .Doc(R"doc( +Closes the given queue. + +This operation signals that no more elements will be enqueued in the +given queue. Subsequent Enqueue(Many) operations will fail. +Subsequent Dequeue(Many) operations will continue to succeed if +sufficient elements remain in the queue. Subsequent Dequeue(Many) +operations that would block will fail immediately. + +handle: The handle to a queue. +cancel_pending_enqueues: If true, all pending enqueue requests that are + blocked on the given queue will be cancelled. +)doc"); + REGISTER_OP("QueueSize") .Input("handle: Ref(string)") .Output("size: int32") @@ -515,6 +810,17 @@ handle: The handle to a queue. size: The number of elements in the given queue. )doc"); +REGISTER_OP("QueueSizeV2") + .Input("handle: resource") + .Output("size: int32") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Computes the number of elements in the given queue. + +handle: The handle to a queue. +size: The number of elements in the given queue. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("AccumulatorNumAccumulated") diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 1167461e9e50122aa1b06730fe6461115b203ff9..1412aeffc5da127791f79f3aa28857b86bf7ece3 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -378,6 +378,25 @@ shared_name: If non-empty, this reader is named in the given bucket with this shared_name. Otherwise, the node name is used instead. )doc"); +REGISTER_OP("WholeFileReaderV2") + .Output("reader_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A Reader that outputs the entire contents of a file as a value. + +To use, enqueue filenames in a Queue. The output of ReaderRead will +be a filename (key) and the contents of that file (value). + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + REGISTER_OP("TextLineReader") .Output("reader_handle: Ref(string)") .Attr("skip_header_lines: int = 0") @@ -396,6 +415,24 @@ shared_name: If non-empty, this reader is named in the given bucket with this shared_name. Otherwise, the node name is used instead. )doc"); +REGISTER_OP("TextLineReaderV2") + .Output("reader_handle: resource") + .Attr("skip_header_lines: int = 0") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A Reader that outputs the lines of a file delimited by '\n'. + +reader_handle: The handle to reference the Reader. +skip_header_lines: Number of lines to skip from the beginning of every file. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + REGISTER_OP("FixedLengthRecordReader") .Output("reader_handle: Ref(string)") .Attr("header_bytes: int = 0") @@ -415,6 +452,25 @@ shared_name: If non-empty, this reader is named in the given bucket with this shared_name. Otherwise, the node name is used instead. )doc"); +REGISTER_OP("FixedLengthRecordReaderV2") + .Output("reader_handle: resource") + .Attr("header_bytes: int = 0") + .Attr("record_bytes: int") + .Attr("footer_bytes: int = 0") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A Reader that outputs fixed-length records from a file. + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + REGISTER_OP("TFRecordReader") .Output("reader_handle: Ref(string)") .Attr("container: string = ''") @@ -432,6 +488,23 @@ shared_name: If non-empty, this reader is named in the given bucket with this shared_name. Otherwise, the node name is used instead. )doc"); +REGISTER_OP("TFRecordReaderV2") + .Output("reader_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("compression_type: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A Reader that outputs the records from a TensorFlow Records file. + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + REGISTER_OP("IdentityReader") .Output("reader_handle: Ref(string)") .Attr("container: string = ''") @@ -451,6 +524,25 @@ shared_name: If non-empty, this reader is named in the given bucket with this shared_name. Otherwise, the node name is used instead. )doc"); +REGISTER_OP("IdentityReaderV2") + .Output("reader_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A Reader that outputs the queued work as both the key and value. + +To use, enqueue strings in a Queue. ReaderRead will take the front +work string and output (work, work). + +reader_handle: The handle to reference the Reader. +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +)doc"); + // Ops that operate on Readers ------------------------------------------------ REGISTER_OP("ReaderRead") @@ -472,6 +564,25 @@ key: A scalar. value: A scalar. )doc"); +REGISTER_OP("ReaderReadV2") + .Input("reader_handle: resource") + .Input("queue_handle: resource") + .Output("key: string") + .Output("value: string") + .SetShapeFn(ScalarInputsAndOutputs) + .Doc(R"doc( +Returns the next record (key, value pair) produced by a Reader. + +Will dequeue from the input queue if necessary (e.g. when the +Reader needs to start reading from a new file since it has finished +with the previous file). + +reader_handle: Handle to a Reader. +queue_handle: Handle to a Queue, with string work items. +key: A scalar. +value: A scalar. +)doc"); + REGISTER_OP("ReaderReadUpTo") .Input("reader_handle: Ref(string)") .Input("queue_handle: Ref(string)") @@ -503,6 +614,37 @@ keys: A 1-D tensor. values: A 1-D tensor. )doc"); +REGISTER_OP("ReaderReadUpToV2") + .Input("reader_handle: resource") + .Input("queue_handle: resource") + .Input("num_records: int64") + .Output("keys: string") + .Output("values: string") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + ShapeHandle out = c->Vector(InferenceContext::kUnknownDim); + c->set_output(0, out); + c->set_output(1, out); + return Status::OK(); + }) + .Doc(R"doc( +Returns up to `num_records` (key, value) pairs produced by a Reader. + +Will dequeue from the input queue if necessary (e.g. when the +Reader needs to start reading from a new file since it has finished +with the previous file). +It may return less than `num_records` even before the last batch. + +reader_handle: Handle to a `Reader`. +queue_handle: Handle to a `Queue`, with string work items. +num_records: number of records to read from `Reader`. +keys: A 1-D tensor. +values: A 1-D tensor. +)doc"); + REGISTER_OP("ReaderNumRecordsProduced") .Input("reader_handle: Ref(string)") .Output("records_produced: int64") @@ -516,6 +658,19 @@ succeeded. reader_handle: Handle to a Reader. )doc"); +REGISTER_OP("ReaderNumRecordsProducedV2") + .Input("reader_handle: resource") + .Output("records_produced: int64") + .SetShapeFn(ScalarInputsAndOutputs) + .Doc(R"doc( +Returns the number of records this Reader has produced. + +This is the same as the number of ReaderRead executions that have +succeeded. + +reader_handle: Handle to a Reader. +)doc"); + REGISTER_OP("ReaderNumWorkUnitsCompleted") .Input("reader_handle: Ref(string)") .Output("units_completed: int64") @@ -526,6 +681,16 @@ Returns the number of work units this Reader has finished processing. reader_handle: Handle to a Reader. )doc"); +REGISTER_OP("ReaderNumWorkUnitsCompletedV2") + .Input("reader_handle: resource") + .Output("units_completed: int64") + .SetShapeFn(ScalarInputsAndOutputs) + .Doc(R"doc( +Returns the number of work units this Reader has finished processing. + +reader_handle: Handle to a Reader. +)doc"); + REGISTER_OP("ReaderSerializeState") .Input("reader_handle: Ref(string)") .Output("state: string") @@ -539,6 +704,19 @@ Unimplemented error. reader_handle: Handle to a Reader. )doc"); +REGISTER_OP("ReaderSerializeStateV2") + .Input("reader_handle: resource") + .Output("state: string") + .SetShapeFn(ScalarInputsAndOutputs) + .Doc(R"doc( +Produce a string tensor that encodes the state of a Reader. + +Not all Readers support being serialized, so this can produce an +Unimplemented error. + +reader_handle: Handle to a Reader. +)doc"); + REGISTER_OP("ReaderRestoreState") .Input("reader_handle: Ref(string)") .Input("state: string") @@ -563,6 +741,26 @@ state: Result of a ReaderSerializeState of a Reader with type matching reader_handle. )doc"); +REGISTER_OP("ReaderRestoreStateV2") + .Input("reader_handle: resource") + .Input("state: string") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Restore a reader to a previously saved state. + +Not all Readers support being restored, so this can produce an +Unimplemented error. + +reader_handle: Handle to a Reader. +state: Result of a ReaderSerializeState of a Reader with type + matching reader_handle. +)doc"); + REGISTER_OP("ReaderReset") .Input("reader_handle: Ref(string)") .SetShapeFn(TwoElementVectorAndScalarOutputs) @@ -572,6 +770,15 @@ Restore a Reader to its initial clean state. reader_handle: Handle to a Reader. )doc"); +REGISTER_OP("ReaderResetV2") + .Input("reader_handle: resource") + .SetShapeFn(ScalarInputsAndOutputs) + .Doc(R"doc( +Restore a Reader to its initial clean state. + +reader_handle: Handle to a Reader. +)doc"); + // Other input Ops ---------------------------------------------------------- REGISTER_OP("ReadFile") diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 78a77d38edbf4b45219f8a114cae7f673f3bd25d..7cc5854a307b9632e0a929ff1d544a4a297cfd48 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -297,6 +297,8 @@ class DType(object): @property def size(self): + if self._type_enum == types_pb2.DT_RESOURCE: + return 1 return np.dtype(self.as_numpy_dtype).itemsize # Define data type range of numpy dtype diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 2b790b4a92085694ac980793ef71a5cab56a6a81..6f0896df50ad2a184102a042311bf8f4fa4aeb83 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -45,7 +45,7 @@ class FIFOQueueTest(test.TestCase): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'FIFOQueue' + name:'Q' op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_FLOAT } } } attr { key: 'shapes' value { list {} } } attr { key: 'capacity' value { i: 10 } } @@ -61,7 +61,7 @@ class FIFOQueueTest(test.TestCase): name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'FIFOQueue' + name:'Q' op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_INT32 type : DT_FLOAT } } } @@ -80,7 +80,7 @@ class FIFOQueueTest(test.TestCase): name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'FIFOQueue' + name:'Q' op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_INT32 type : DT_FLOAT } } } @@ -1184,44 +1184,44 @@ class FIFOQueueTest(test.TestCase): with self.test_session(): q_a_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_a") q_a_2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32, shared_name="q_a") - q_a_1.queue_ref.eval() + q_a_1.queue_ref.op.run() with self.assertRaisesOpError("capacity"): - q_a_2.queue_ref.eval() + q_a_2.queue_ref.op.run() q_b_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_b") q_b_2 = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shared_name="q_b") - q_b_1.queue_ref.eval() + q_b_1.queue_ref.op.run() with self.assertRaisesOpError("component types"): - q_b_2.queue_ref.eval() + q_b_2.queue_ref.op.run() q_c_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_c") q_c_2 = data_flow_ops.FIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_c") - q_c_1.queue_ref.eval() + q_c_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_c_2.queue_ref.eval() + q_c_2.queue_ref.op.run() q_d_1 = data_flow_ops.FIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d") q_d_2 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_d") - q_d_1.queue_ref.eval() + q_d_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_d_2.queue_ref.eval() + q_d_2.queue_ref.op.run() q_e_1 = data_flow_ops.FIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e") q_e_2 = data_flow_ops.FIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 4)], shared_name="q_e") - q_e_1.queue_ref.eval() + q_e_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_e_2.queue_ref.eval() + q_e_2.queue_ref.op.run() q_f_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_f") q_f_2 = data_flow_ops.FIFOQueue( 10, (dtypes_lib.float32, dtypes_lib.int32), shared_name="q_f") - q_f_1.queue_ref.eval() + q_f_1.queue_ref.op.run() with self.assertRaisesOpError("component types"): - q_f_2.queue_ref.eval() + q_f_2.queue_ref.op.run() def testSelectQueue(self): with self.test_session(): @@ -1241,7 +1241,7 @@ class FIFOQueueTest(test.TestCase): q1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) q2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32) enq_q = data_flow_ops.FIFOQueue.from_list(3, [q1, q2]) - with self.assertRaisesOpError("Index must be in the range"): + with self.assertRaisesOpError("is not in"): enq_q.dequeue().eval() def _blockingDequeue(self, sess, dequeue_op): @@ -1389,16 +1389,6 @@ class FIFOQueueTest(test.TestCase): for (input_elem, output_elem) in zip(input_tuple, output_tuple): self.assertAllEqual(input_elem, output_elem) - def testDeviceColocation(self): - with ops.device("/job:ps"): - q = data_flow_ops.FIFOQueue(32, [dtypes_lib.int32], name="q") - - with ops.device("/job:worker/task:7"): - dequeued_t = q.dequeue() - - self.assertDeviceEqual("/job:ps", dequeued_t.device) - self.assertEqual([b"loc:@q"], dequeued_t.op.colocation_groups()) - class FIFOQueueDictTest(test.TestCase): @@ -1411,7 +1401,7 @@ class FIFOQueueDictTest(test.TestCase): name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'FIFOQueue' + name:'Q' op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_INT32 type : DT_FLOAT } } } @@ -1432,7 +1422,7 @@ class FIFOQueueDictTest(test.TestCase): name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'FIFOQueue' + name:'Q' op:'FIFOQueueV2' attr { key: 'component_types' value { list { type: DT_INT32 type : DT_FLOAT } } } diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py index 5cd757159027eb57cf9c146b8bcd4f26a2c43ac4..53b1897f488636683a5a03f0cb3b95340fa4b25c 100644 --- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py @@ -42,7 +42,7 @@ class PaddingFIFOQueueTest(test.TestCase): 10, dtypes_lib.float32, ((None,),), name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'PaddingFIFOQueue' + name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { type: DT_FLOAT } } } attr { key: 'shapes' value { list { shape { dim { size: -1 } } } } } attr { key: 'capacity' value { i: 10 } } @@ -58,7 +58,7 @@ class PaddingFIFOQueueTest(test.TestCase): name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'PaddingFIFOQueue' + name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { type: DT_INT32 type : DT_FLOAT } } } @@ -77,7 +77,7 @@ class PaddingFIFOQueueTest(test.TestCase): name="Q") self.assertTrue(isinstance(q.queue_ref, ops.Tensor)) self.assertProtoEquals(""" - name:'Q' op:'PaddingFIFOQueue' + name:'Q' op:'PaddingFIFOQueueV2' attr { key: 'component_types' value { list { type: DT_INT32 type : DT_FLOAT } } } @@ -1307,50 +1307,50 @@ class PaddingFIFOQueueTest(test.TestCase): 10, dtypes_lib.float32, ((),), shared_name="q_a") q_a_2 = data_flow_ops.PaddingFIFOQueue( 15, dtypes_lib.float32, ((),), shared_name="q_a") - q_a_1.queue_ref.eval() + q_a_1.queue_ref.op.run() with self.assertRaisesOpError("capacity"): - q_a_2.queue_ref.eval() + q_a_2.queue_ref.op.run() q_b_1 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, ((),), shared_name="q_b") q_b_2 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.int32, ((),), shared_name="q_b") - q_b_1.queue_ref.eval() + q_b_1.queue_ref.op.run() with self.assertRaisesOpError("component types"): - q_b_2.queue_ref.eval() + q_b_2.queue_ref.op.run() q_c_1 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, ((),), shared_name="q_c") q_c_2 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_c") - q_c_1.queue_ref.eval() + q_c_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_c_2.queue_ref.eval() + q_c_2.queue_ref.op.run() q_d_1 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d") q_d_2 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, ((),), shared_name="q_d") - q_d_1.queue_ref.eval() + q_d_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_d_2.queue_ref.eval() + q_d_2.queue_ref.op.run() q_e_1 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e") q_e_2 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, shapes=[(1, 1, 2, 4)], shared_name="q_e") - q_e_1.queue_ref.eval() + q_e_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_e_2.queue_ref.eval() + q_e_2.queue_ref.op.run() q_f_1 = data_flow_ops.PaddingFIFOQueue( 10, dtypes_lib.float32, ((),), shared_name="q_f") q_f_2 = data_flow_ops.PaddingFIFOQueue( 10, (dtypes_lib.float32, dtypes_lib.int32), ((), ()), shared_name="q_f") - q_f_1.queue_ref.eval() + q_f_1.queue_ref.op.run() with self.assertRaisesOpError("component types"): - q_f_2.queue_ref.eval() + q_f_2.queue_ref.op.run() def testSelectQueue(self): with self.test_session(): @@ -1371,7 +1371,7 @@ class PaddingFIFOQueueTest(test.TestCase): q1 = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) q2 = data_flow_ops.PaddingFIFOQueue(15, dtypes_lib.float32, ((),)) enq_q = data_flow_ops.PaddingFIFOQueue.from_list(3, [q1, q2]) - with self.assertRaisesOpError("Index must be in the range"): + with self.assertRaisesOpError("is not in"): enq_q.dequeue().eval() def _blockingDequeue(self, sess, dequeue_op): diff --git a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py index 8a92a4d0f0fd10420d39ba67e78c15f1c3e97f49..c9b983662da977ced06a6c487d2f0328c28c9613 100644 --- a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py +++ b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py @@ -1125,65 +1125,65 @@ class RandomShuffleQueueTest(test.TestCase): 10, 5, dtypes_lib.float32, shared_name="q_a") q_a_2 = data_flow_ops.RandomShuffleQueue( 15, 5, dtypes_lib.float32, shared_name="q_a") - q_a_1.queue_ref.eval() + q_a_1.queue_ref.op.run() with self.assertRaisesOpError("capacity"): - q_a_2.queue_ref.eval() + q_a_2.queue_ref.op.run() q_b_1 = data_flow_ops.RandomShuffleQueue( 10, 0, dtypes_lib.float32, shared_name="q_b") q_b_2 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shared_name="q_b") - q_b_1.queue_ref.eval() + q_b_1.queue_ref.op.run() with self.assertRaisesOpError("min_after_dequeue"): - q_b_2.queue_ref.eval() + q_b_2.queue_ref.op.run() q_c_1 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shared_name="q_c") q_c_2 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.int32, shared_name="q_c") - q_c_1.queue_ref.eval() + q_c_1.queue_ref.op.run() with self.assertRaisesOpError("component types"): - q_c_2.queue_ref.eval() + q_c_2.queue_ref.op.run() q_d_1 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shared_name="q_d") q_d_2 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d") - q_d_1.queue_ref.eval() + q_d_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_d_2.queue_ref.eval() + q_d_2.queue_ref.op.run() q_e_1 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e") q_e_2 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shared_name="q_e") - q_e_1.queue_ref.eval() + q_e_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_e_2.queue_ref.eval() + q_e_2.queue_ref.op.run() q_f_1 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_f") q_f_2 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 4)], shared_name="q_f") - q_f_1.queue_ref.eval() + q_f_1.queue_ref.op.run() with self.assertRaisesOpError("component shapes"): - q_f_2.queue_ref.eval() + q_f_2.queue_ref.op.run() q_g_1 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, shared_name="q_g") q_g_2 = data_flow_ops.RandomShuffleQueue( 10, 5, (dtypes_lib.float32, dtypes_lib.int32), shared_name="q_g") - q_g_1.queue_ref.eval() + q_g_1.queue_ref.op.run() with self.assertRaisesOpError("component types"): - q_g_2.queue_ref.eval() + q_g_2.queue_ref.op.run() q_h_1 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, seed=12, shared_name="q_h") q_h_2 = data_flow_ops.RandomShuffleQueue( 10, 5, dtypes_lib.float32, seed=21, shared_name="q_h") - q_h_1.queue_ref.eval() + q_h_1.queue_ref.op.run() with self.assertRaisesOpError("random seeds"): - q_h_2.queue_ref.eval() + q_h_2.queue_ref.op.run() def testSelectQueue(self): with self.test_session(): @@ -1204,7 +1204,7 @@ class RandomShuffleQueueTest(test.TestCase): q1 = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32) q2 = data_flow_ops.RandomShuffleQueue(15, 0, dtypes_lib.float32) enq_q = data_flow_ops.RandomShuffleQueue.from_list(3, [q1, q2]) - with self.assertRaisesOpError("Index must be in the range"): + with self.assertRaisesOpError("is not in"): enq_q.dequeue().eval() def _blockingDequeue(self, sess, dequeue_op): diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 718355422bdc865a0fc3bd5bff2146222222f4ef..4c8a841985641a2c81c6437c8acd109a514594f7 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -205,8 +205,8 @@ class QueueBase(object): reduced_shapes = [ six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)] - queue_refs = [x.queue_ref for x in queues] - selected_queue = control_flow_ops.ref_select(index, queue_refs) + queue_refs = array_ops.stack([x.queue_ref for x in queues]) + selected_queue = array_ops.gather(queue_refs, index) return QueueBase(dtypes=dtypes, shapes=reduced_shapes, names=names, queue_ref=selected_queue) @@ -326,7 +326,12 @@ class QueueBase(object): for val, shape in zip(vals, self._shapes): val.get_shape().assert_is_compatible_with(shape) - return gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=scope) + if self._queue_ref.dtype == _dtypes.resource: + return gen_data_flow_ops._queue_enqueue_v2( + self._queue_ref, vals, name=scope) + else: + return gen_data_flow_ops._queue_enqueue( + self._queue_ref, vals, name=scope) def enqueue_many(self, vals, name=None): """Enqueues zero or more elements to this queue. @@ -367,7 +372,7 @@ class QueueBase(object): val.get_shape().with_rank_at_least(1)[0]) val.get_shape()[1:].assert_is_compatible_with(shape) - return gen_data_flow_ops._queue_enqueue_many( + return gen_data_flow_ops._queue_enqueue_many_v2( self._queue_ref, vals, name=scope) def _dequeue_return_value(self, tensors): @@ -415,8 +420,12 @@ class QueueBase(object): """ if name is None: name = "%s_Dequeue" % self._name - ret = gen_data_flow_ops._queue_dequeue( - self._queue_ref, self._dtypes, name=name) + if self._queue_ref.dtype == _dtypes.resource: + ret = gen_data_flow_ops._queue_dequeue_v2( + self._queue_ref, self._dtypes, name=name) + else: + ret = gen_data_flow_ops._queue_dequeue( + self._queue_ref, self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to # the `QueueBase` object. @@ -454,7 +463,7 @@ class QueueBase(object): if name is None: name = "%s_DequeueMany" % self._name - ret = gen_data_flow_ops._queue_dequeue_many( + ret = gen_data_flow_ops._queue_dequeue_many_v2( self._queue_ref, n=n, component_types=self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to @@ -495,7 +504,7 @@ class QueueBase(object): if name is None: name = "%s_DequeueUpTo" % self._name - ret = gen_data_flow_ops._queue_dequeue_up_to( + ret = gen_data_flow_ops._queue_dequeue_up_to_v2( self._queue_ref, n=n, component_types=self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to @@ -529,9 +538,14 @@ class QueueBase(object): """ if name is None: name = "%s_Close" % self._name - return gen_data_flow_ops._queue_close( - self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, - name=name) + if self._queue_ref.dtype == _dtypes.resource: + return gen_data_flow_ops._queue_close_v2( + self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, + name=name) + else: + return gen_data_flow_ops._queue_close( + self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, + name=name) def size(self, name=None): """Compute the number of elements in this queue. @@ -544,7 +558,10 @@ class QueueBase(object): """ if name is None: name = "%s_Size" % self._name - return gen_data_flow_ops._queue_size(self._queue_ref, name=name) + if self._queue_ref.dtype == _dtypes.resource: + return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name) + else: + return gen_data_flow_ops._queue_size(self._queue_ref, name=name) class RandomShuffleQueue(QueueBase): @@ -614,7 +631,7 @@ class RandomShuffleQueue(QueueBase): # the id of the last op created.) string = (str(seed1) + shared_name).encode("utf-8") seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF - queue_ref = gen_data_flow_ops._random_shuffle_queue( + queue_ref = gen_data_flow_ops._random_shuffle_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2, shared_name=shared_name, name=name) @@ -665,7 +682,7 @@ class FIFOQueue(QueueBase): dtypes = _as_type_list(dtypes) shapes = _as_shape_list(shapes, dtypes) names = _as_name_list(names, dtypes) - queue_ref = gen_data_flow_ops._fifo_queue( + queue_ref = gen_data_flow_ops._fifo_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, shared_name=shared_name, name=name) @@ -732,7 +749,7 @@ class PaddingFIFOQueue(QueueBase): "but received %d dtypes and %d shapes." % (len(dtypes), len(shapes))) - queue_ref = gen_data_flow_ops._padding_fifo_queue( + queue_ref = gen_data_flow_ops._padding_fifo_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, shared_name=shared_name, name=name) @@ -788,7 +805,7 @@ class PriorityQueue(QueueBase): types = _as_type_list(types) shapes = _as_shape_list(shapes, types) - queue_ref = gen_data_flow_ops._priority_queue( + queue_ref = gen_data_flow_ops._priority_queue_v2( component_types=types, shapes=shapes, capacity=capacity, shared_name=shared_name, name=name) diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 72ba4031ae6b09c45aa75e4363fd0acc9d1100e1..abe5c538d0a681f2cff67eca8fa89e044e947895 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -58,8 +58,12 @@ BarrierIncompleteSize BarrierInsertMany BarrierReadySize BarrierTakeMany -PriorityQueue +DeleteSessionTensor +FakeQueue FIFOQueue +FIFOQueueV2 +GetSessionHandle +GetSessionTensor HashTable InitializeTable InitializeTableFromTextFile @@ -75,41 +79,52 @@ Mutex MutexAcquire MutexRelease PaddingFIFOQueue +PaddingFIFOQueueV2 +PriorityQueue +PriorityQueueV2 QueueClose +QueueCloseV2 QueueDequeue +QueueDequeueV2 QueueDequeueMany +QueueDequeueManyV2 QueueDequeueUpTo +QueueDequeueUpToV2 QueueEnqueue +QueueEnqueueV2 QueueEnqueueMany +QueueEnqueueManyV2 QueueSize +QueueSizeV2 RandomShuffleQueue +RandomShuffleQueueV2 Stack +StackClose StackPop StackPush -StackClose TensorArray TensorArrayClose -TensorArrayConcat -TensorArrayGather -TensorArrayGrad -TensorArrayRead -TensorArrayPack -TensorArrayScatter -TensorArraySize -TensorArraySplit -TensorArrayUnpack -TensorArrayWrite -TensorArrayV2 TensorArrayCloseV2 +TensorArrayConcat TensorArrayConcatV2 +TensorArrayGather TensorArrayGatherV2 +TensorArrayGrad TensorArrayGradV2 -TensorArrayReadV2 +TensorArrayPack TensorArrayPackV2 +TensorArrayRead +TensorArrayReadV2 +TensorArrayScatter TensorArrayScatterV2 +TensorArraySize TensorArraySizeV2 +TensorArraySplit TensorArraySplitV2 +TensorArrayUnpack TensorArrayUnpackV2 +TensorArrayV2 +TensorArrayWrite TensorArrayWriteV2 TensorArrayV3 TensorArrayCloseV3 @@ -123,9 +138,6 @@ TensorArraySizeV3 TensorArraySplitV3 TensorArrayUnpackV3 TensorArrayWriteV3 -GetSessionHandle -GetSessionTensor -DeleteSessionTensor # functional_ops SymbolicGradient @@ -150,6 +162,18 @@ ReaderReset ReaderRestoreState ReaderSerializeState ReaderWorkQueueLength +FixedLengthRecordReaderV2 +IdentityReaderV2 +ReaderCloseV2 +ReaderEnqueueWorkV2 +ReaderNumRecordsProducedV2 +ReaderNumWorkUnitsCompletedV2 +ReaderReadV2 +ReaderReadUpToV2 +ReaderResetV2 +ReaderRestoreStateV2 +ReaderSerializeStateV2 +ReaderWorkQueueLengthV2 Restore RestoreSlice Save @@ -159,6 +183,9 @@ ShardedFilespec TextLineReader TFRecordReader WholeFileReader +TextLineReaderV2 +TFRecordReaderV2 +WholeFileReaderV2 # linalg_ops BatchCholesky diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index e7af6bfe2d9be1317f0805e4265a0b70c8b437ea..0a099ae28c3ef3bda29c314933ca2c083a402271 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -150,6 +150,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_io_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -267,7 +268,13 @@ class ReaderBase(object): queue_ref = queue else: queue_ref = queue.queue_ref - return gen_io_ops._reader_read(self._reader_ref, queue_ref, name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_read_v2(self._reader_ref, queue_ref, name=name) + else: + # For compatibility with pre-resource queues, create a ref(string) tensor + # which can be looked up as the same queue by a resource manager. + old_queue_op = gen_data_flow_ops._fake_queue(queue_ref) + return gen_io_ops._reader_read(self._reader_ref, old_queue_op, name=name) def read_up_to(self, queue, num_records, # pylint: disable=invalid-name name=None): @@ -293,10 +300,19 @@ class ReaderBase(object): queue_ref = queue else: queue_ref = queue.queue_ref - return gen_io_ops._reader_read_up_to(self._reader_ref, - queue_ref, - num_records, - name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_read_up_to_v2(self._reader_ref, + queue_ref, + num_records, + name=name) + else: + # For compatibility with pre-resource queues, create a ref(string) tensor + # which can be looked up as the same queue by a resource manager. + old_queue_op = gen_data_flow_ops._fake_queue(queue_ref) + return gen_io_ops._reader_read_up_to_v2(self._reader_ref, + old_queue_op, + num_records, + name=name) def num_records_produced(self, name=None): """Returns the number of records this reader has produced. @@ -311,7 +327,12 @@ class ReaderBase(object): An int64 Tensor. """ - return gen_io_ops._reader_num_records_produced(self._reader_ref, name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_num_records_produced_v2(self._reader_ref, + name=name) + else: + return gen_io_ops._reader_num_records_produced(self._reader_ref, + name=name) def num_work_units_completed(self, name=None): """Returns the number of work units this reader has finished processing. @@ -322,8 +343,12 @@ class ReaderBase(object): Returns: An int64 Tensor. """ - return gen_io_ops._reader_num_work_units_completed(self._reader_ref, - name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_num_work_units_completed_v2(self._reader_ref, + name=name) + else: + return gen_io_ops._reader_num_work_units_completed(self._reader_ref, + name=name) def serialize_state(self, name=None): """Produce a string tensor that encodes the state of a reader. @@ -337,7 +362,10 @@ class ReaderBase(object): Returns: A string Tensor. """ - return gen_io_ops._reader_serialize_state(self._reader_ref, name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_serialize_state_v2(self._reader_ref, name=name) + else: + return gen_io_ops._reader_serialize_state(self._reader_ref, name=name) def restore_state(self, state, name=None): """Restore a reader to a previously saved state. @@ -353,7 +381,12 @@ class ReaderBase(object): Returns: The created Operation. """ - return gen_io_ops._reader_restore_state(self._reader_ref, state, name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_restore_state_v2( + self._reader_ref, state, name=name) + else: + return gen_io_ops._reader_restore_state( + self._reader_ref, state, name=name) @property def supports_serialize(self): @@ -369,7 +402,10 @@ class ReaderBase(object): Returns: The created Operation. """ - return gen_io_ops._reader_reset(self._reader_ref, name=name) + if self._reader_ref.dtype == dtypes.resource: + return gen_io_ops._reader_reset_v2(self._reader_ref, name=name) + else: + return gen_io_ops._reader_reset(self._reader_ref, name=name) ops.NotDifferentiable("ReaderRead") @@ -396,7 +432,7 @@ class WholeFileReader(ReaderBase): Args: name: A name for the operation (optional). """ - rr = gen_io_ops._whole_file_reader(name=name) + rr = gen_io_ops._whole_file_reader_v2(name=name) super(WholeFileReader, self).__init__(rr, supports_serialize=True) @@ -419,8 +455,8 @@ class TextLineReader(ReaderBase): to skip from the beginning of every file. name: A name for the operation (optional). """ - rr = gen_io_ops._text_line_reader(skip_header_lines=skip_header_lines, - name=name) + rr = gen_io_ops._text_line_reader_v2(skip_header_lines=skip_header_lines, + name=name) super(TextLineReader, self).__init__(rr) @@ -444,7 +480,7 @@ class FixedLengthRecordReader(ReaderBase): footer_bytes: An optional int. Defaults to 0. name: A name for the operation (optional). """ - rr = gen_io_ops._fixed_length_record_reader( + rr = gen_io_ops._fixed_length_record_reader_v2( record_bytes=record_bytes, header_bytes=header_bytes, footer_bytes=footer_bytes, name=name) super(FixedLengthRecordReader, self).__init__(rr) @@ -470,7 +506,7 @@ class TFRecordReader(ReaderBase): compression_type = python_io.TFRecordOptions.get_compression_type_string( options) - rr = gen_io_ops._tf_record_reader( + rr = gen_io_ops._tf_record_reader_v2( name=name, compression_type=compression_type) super(TFRecordReader, self).__init__(rr) @@ -493,7 +529,7 @@ class IdentityReader(ReaderBase): Args: name: A name for the operation (optional). """ - rr = gen_io_ops._identity_reader(name=name) + rr = gen_io_ops._identity_reader_v2(name=name) super(IdentityReader, self).__init__(rr, supports_serialize=True)