提交 24825648 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Adds V2 versions of Queue and Reader ops using ResourceHandles.

Change: 143570831
上级 354972df
......@@ -217,11 +217,11 @@ class GraphIOTest(test.TestCase):
parse_example_queue_name = "%s/fifo_queue" % name
op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "FIFOQueue",
parse_example_queue_name: "FIFOQueue",
name: "QueueDequeueMany"
file_name_queue_name: "FIFOQueueV2",
"%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
example_queue_name: "FIFOQueueV2",
parse_example_queue_name: "FIFOQueueV2",
name: "QueueDequeueManyV2"
}, g)
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
self.assertEqual(queue_capacity,
......@@ -250,10 +250,10 @@ class GraphIOTest(test.TestCase):
example_queue_name = "%s/random_shuffle_queue" % name
op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "RandomShuffleQueue",
name: "QueueDequeueUpTo",
file_name_queue_name: "FIFOQueueV2",
"%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
example_queue_name: "RandomShuffleQueueV2",
name: "QueueDequeueUpToV2",
file_name_queue_limit_name: "VariableV2"
}, g)
self.assertEqual(
......@@ -281,10 +281,10 @@ class GraphIOTest(test.TestCase):
example_queue_name = "%s/random_shuffle_queue" % name
op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "RandomShuffleQueue",
name: "QueueDequeueMany"
file_name_queue_name: "FIFOQueueV2",
"%s/read/TFRecordReaderV2" % name: "TFRecordReaderV2",
example_queue_name: "RandomShuffleQueueV2",
name: "QueueDequeueManyV2"
}, g)
self.assertEqual(
set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0]))
......@@ -427,10 +427,10 @@ class GraphIOTest(test.TestCase):
example_queue_name = "%s/fifo_queue" % name
test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TextLineReader" % name: "TextLineReader",
example_queue_name: "FIFOQueue",
name: "QueueDequeueUpTo"
file_name_queue_name: "FIFOQueueV2",
"%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
example_queue_name: "FIFOQueueV2",
name: "QueueDequeueUpToV2"
}, g)
self.assertAllEqual(session.run(inputs), [b"ABC"])
......@@ -473,10 +473,10 @@ class GraphIOTest(test.TestCase):
example_queue_name = "%s/fifo_queue" % name
worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
test_util.assert_ops_in_graph({
"%s/read/TextLineReader" % name: "TextLineReader",
example_queue_name: "FIFOQueue",
worker_file_name_queue_name: "FIFOQueue",
name: "QueueDequeueUpTo"
"%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
example_queue_name: "FIFOQueueV2",
worker_file_name_queue_name: "FIFOQueueV2",
name: "QueueDequeueUpToV2"
}, g)
self.assertAllEqual(session.run(inputs), [b"ABC"])
......
......@@ -282,7 +282,8 @@ class FeederTest(test.TestCase):
op_types_by_scope_and_device[scope][dev][op.type] += 1
expected_ops = collections.Counter({'QueueEnqueue': 1, 'FIFOQueue': 1})
expected_ops = collections.Counter(
{'QueueEnqueueV2': 1, 'FIFOQueueV2': 1})
expected_enq_devices = [('replica_0', [
'/job:consumer/replica:0/device:cpu:0',
'/job:consumer/replica:1/device:cpu:0',
......
......@@ -267,6 +267,24 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
return Status::OK();
}
// Looks up the dtype of the single input registered under `name`.
// Fails with InvalidArgument if `name` names a list-valued input, since a
// single DataType cannot describe a list of inputs.
Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
// A single-valued input occupies exactly one slot in the input range.
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
const TensorValue& value((*params_->inputs)[start]);
// Reference inputs (e.g. legacy string queue handles) are reported with the
// corresponding *_REF dtype so callers can tell them apart from plain values.
if (value.is_ref()) {
*dtype = MakeRefType(value->dtype());
} else {
*dtype = value->dtype();
}
return Status::OK();
}
Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
......
......@@ -568,6 +568,7 @@ class OpKernelContext {
int num_inputs() const { return params_->inputs->size(); }
DataType input_dtype(int index) const;
Status input_dtype(StringPiece name, DataType* dtype) const;
int num_outputs() const { return outputs_.size(); }
DataType expected_output_dtype(int index) const;
......
......@@ -334,6 +334,36 @@ TEST_F(OpKernelTest, SaveTempTrue) {
delete params.device;
}
// Verifies OpKernelContext::input_dtype(name, dtype): an unknown input name
// fails, and declared inputs report their registered dtypes.
TEST_F(OpKernelTest, InputDtype) {
Env* env = Env::Default();
OpKernelContext::Params params;
params.record_tensor_accesses = false;
// Own the device via unique_ptr: ASSERT_TRUE returns early from the test
// body on failure, which would leak a raw `new`-ed device/context.
std::unique_ptr<DummyDevice> device(
new DummyDevice(env, params.record_tensor_accesses));
params.device = device.get();
Status status;
std::unique_ptr<OpKernel> op(
CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
TF_GRAPH_DEF_VERSION, &status));
EXPECT_TRUE(status.ok());
params.op_kernel = op.get();
Tensor a(DT_FLOAT, TensorShape({}));
Tensor b(DT_INT32, TensorShape({}));
Tensor c(DT_UINT8, TensorShape({}));
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
TensorValue(&c)};
params.inputs = &inputs;
// Declared after `device` so it is destroyed first (reverse declaration
// order), matching the original manual delete order.
std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));
DataType dtype;
EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
EXPECT_EQ(dtype, DT_FLOAT);
ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
EXPECT_EQ(dtype, DT_INT32);
}
class OpKernelBuilderTest : public ::testing::Test {
protected:
// Each attr is described by a "name|type|value".
......
......@@ -368,6 +368,13 @@ Status ResourceMgr::Delete(const string& container, const string& name) {
template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
T** resource) {
DataType dtype;
TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
if (dtype == DT_RESOURCE) {
const Tensor* handle;
TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
}
string container;
string shared_name;
{
......@@ -479,7 +486,7 @@ template <typename T>
void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
output->flat<ResourceHandle>()(0) =
output->scalar<ResourceHandle>()() =
MakeResourceHandle<T>(ctx, container_, name_);
}
......
......@@ -94,7 +94,15 @@ class ResourceOpKernel : public OpKernel {
h(1) = cinfo_.name();
resource_ = resource;
}
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
if (context->expected_output_dtype(0) == DT_RESOURCE) {
Tensor* handle;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() =
MakeResourceHandle<T>(context, cinfo_.container(), cinfo_.name());
} else {
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
}
}
protected:
......
......@@ -347,7 +347,10 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
}
Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue"));
if (!MatchesNodeDefOp(node_def, "FIFOQueue").ok() &&
!MatchesNodeDefOp(node_def, "FIFOQueueV2").ok()) {
return errors::InvalidArgument("Expected FIFOQueue, found ", node_def.op());
}
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));
......
......@@ -58,5 +58,6 @@ class FIFOQueueOp : public TypedQueueOp {
};
REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
} // namespace tensorflow
......@@ -121,5 +121,7 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU),
FixedLengthRecordReaderOp);
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU),
FixedLengthRecordReaderOp);
} // namespace tensorflow
......@@ -68,5 +68,7 @@ class IdentityReaderOp : public ReaderOpKernel {
REGISTER_KERNEL_BUILDER(Name("IdentityReader").Device(DEVICE_CPU),
IdentityReaderOp);
REGISTER_KERNEL_BUILDER(Name("IdentityReaderV2").Device(DEVICE_CPU),
IdentityReaderOp);
} // namespace tensorflow
......@@ -276,7 +276,11 @@ Status PaddingFIFOQueue::CompatibleNodeDefShapes(
}
Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "PaddingFIFOQueue"));
if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
!MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
node_def.op());
}
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
......
......@@ -67,5 +67,7 @@ class PaddingFIFOQueueOp : public TypedQueueOp {
REGISTER_KERNEL_BUILDER(Name("PaddingFIFOQueue").Device(DEVICE_CPU),
PaddingFIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("PaddingFIFOQueueV2").Device(DEVICE_CPU),
PaddingFIFOQueueOp);
} // namespace tensorflow
......@@ -377,7 +377,11 @@ void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
}
Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "PriorityQueue"));
if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
!MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
return errors::InvalidArgument("Expected PriorityQueue, found ",
node_def.op());
}
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
......
......@@ -64,5 +64,7 @@ class PriorityQueueOp : public TypedQueueOp {
REGISTER_KERNEL_BUILDER(Name("PriorityQueue").Device(DEVICE_CPU),
PriorityQueueOp);
REGISTER_KERNEL_BUILDER(Name("PriorityQueueV2").Device(DEVICE_CPU),
PriorityQueueOp);
} // namespace tensorflow
......@@ -33,8 +33,13 @@ class QueueOpKernel : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
QueueInterface* queue;
OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
callback);
}
ComputeAsync(ctx, queue, [callback, queue]() {
queue->Unref();
callback();
......@@ -77,7 +82,12 @@ class EnqueueOp : public QueueAccessOpKernel {
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
DataTypeVector expected_inputs = {DT_STRING_REF};
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
......@@ -101,6 +111,7 @@ class EnqueueOp : public QueueAccessOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);
// Defines an EnqueueManyOp, the execution of which slices each
// component of a tuple of tensors along the 0th dimension, and
......@@ -123,7 +134,12 @@ class EnqueueManyOp : public QueueAccessOpKernel {
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
DataTypeVector expected_inputs = {DT_STRING_REF};
DataTypeVector expected_inputs;
if (ctx->input_dtype(0) == DT_RESOURCE) {
expected_inputs.push_back(DT_RESOURCE);
} else {
expected_inputs.push_back(DT_STRING_REF);
}
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
......@@ -150,6 +166,8 @@ class EnqueueManyOp : public QueueAccessOpKernel {
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
EnqueueManyOp);
// Defines a DequeueOp, the execution of which dequeues a tuple of
// tensors from the given Queue.
......@@ -166,9 +184,15 @@ class DequeueOp : public QueueAccessOpKernel {
protected:
void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
DoneCallback callback) override {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(
ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
callback);
}
queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
......@@ -192,6 +216,7 @@ class DequeueOp : public QueueAccessOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
// Defines a DequeueManyOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
......@@ -220,9 +245,17 @@ class DequeueManyOp : public QueueAccessOpKernel {
num_elements, " < 0 elements"),
callback);
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32},
queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}
queue->TryDequeueMany(
num_elements, ctx, false /* allow_small_batch */,
......@@ -250,6 +283,8 @@ class DequeueManyOp : public QueueAccessOpKernel {
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
DequeueManyOp);
// Defines a DequeueUpToOp, the execution of which concatenates the
// requested number of elements from the given Queue along the 0th
......@@ -296,9 +331,17 @@ class DequeueUpToOp : public QueueAccessOpKernel {
num_elements, " < 0 elements"),
callback);
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
if (ctx->input_dtype(0) == DT_RESOURCE) {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_RESOURCE, DT_INT32},
queue->component_dtypes()),
callback);
} else {
OP_REQUIRES_OK_ASYNC(ctx,
ctx->MatchSignature({DT_STRING_REF, DT_INT32},
queue->component_dtypes()),
callback);
}
queue->TryDequeueMany(
num_elements, ctx, true /* allow_small_batch */,
......@@ -326,6 +369,8 @@ class DequeueUpToOp : public QueueAccessOpKernel {
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
DequeueUpToOp);
// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
......@@ -351,6 +396,7 @@ class QueueCloseOp : public QueueOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
// Defines a QueueSizeOp, which computes the number of elements in the
// given Queue, and emits it as an output tensor.
......@@ -377,5 +423,28 @@ class QueueSizeOp : public QueueOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);
class FakeQueueOp : public OpKernel {
public:
explicit FakeQueueOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_STRING, TensorShape({2}),
&handle_, nullptr));
}
void Compute(OpKernelContext* context) {
ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
}
private:
mutex mu_;
PersistentTensor handle_;
};
REGISTER_KERNEL_BUILDER(Name("FakeQueue").Device(DEVICE_CPU), FakeQueueOp);
} // namespace tensorflow
......@@ -425,7 +425,11 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
}
Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue"));
if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() &&
!MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) {
return errors::InvalidArgument("Expected RandomShuffleQueue, found ",
node_def.op());
}
TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
int32 min_after_dequeue = -1;
......@@ -497,5 +501,7 @@ class RandomShuffleQueueOp : public TypedQueueOp {
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
RandomShuffleQueueOp);
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU),
RandomShuffleQueueOp);
} // namespace tensorflow
......@@ -97,6 +97,7 @@ class ReaderReadOp : public ReaderVerbAsyncOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp);
class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
public:
......@@ -149,6 +150,8 @@ class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU),
ReaderReadUpToOp);
REGISTER_KERNEL_BUILDER(Name("ReaderReadUpToV2").Device(DEVICE_CPU),
ReaderReadUpToOp);
class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
public:
......@@ -165,6 +168,8 @@ class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU),
ReaderNumRecordsProducedOp);
REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProducedV2").Device(DEVICE_CPU),
ReaderNumRecordsProducedOp);
class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel {
public:
......@@ -181,6 +186,9 @@ class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU),
ReaderNumWorkUnitsCompletedOp);
REGISTER_KERNEL_BUILDER(
Name("ReaderNumWorkUnitsCompletedV2").Device(DEVICE_CPU),
ReaderNumWorkUnitsCompletedOp);
class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
public:
......@@ -198,6 +206,8 @@ class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU),
ReaderSerializeStateOp);
REGISTER_KERNEL_BUILDER(Name("ReaderSerializeStateV2").Device(DEVICE_CPU),
ReaderSerializeStateOp);
class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel {
public:
......@@ -217,6 +227,8 @@ class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU),
ReaderRestoreStateOp);
REGISTER_KERNEL_BUILDER(Name("ReaderRestoreStateV2").Device(DEVICE_CPU),
ReaderRestoreStateOp);
class ReaderResetOp : public ReaderVerbSyncOpKernel {
public:
......@@ -229,5 +241,7 @@ class ReaderResetOp : public ReaderVerbSyncOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp);
REGISTER_KERNEL_BUILDER(Name("ReaderResetV2").Device(DEVICE_CPU),
ReaderResetOp);
} // namespace tensorflow
......@@ -111,5 +111,7 @@ class TextLineReaderOp : public ReaderOpKernel {
REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
TextLineReaderOp);
REGISTER_KERNEL_BUILDER(Name("TextLineReaderV2").Device(DEVICE_CPU),
TextLineReaderOp);
} // namespace tensorflow
......@@ -97,5 +97,7 @@ class TFRecordReaderOp : public ReaderOpKernel {
REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
TFRecordReaderOp);
REGISTER_KERNEL_BUILDER(Name("TFRecordReaderV2").Device(DEVICE_CPU),
TFRecordReaderOp);
} // namespace tensorflow
......@@ -96,6 +96,8 @@ class WholeFileReaderOp : public ReaderOpKernel {
REGISTER_KERNEL_BUILDER(Name("WholeFileReader").Device(DEVICE_CPU),
WholeFileReaderOp);
REGISTER_KERNEL_BUILDER(Name("WholeFileReaderV2").Device(DEVICE_CPU),
WholeFileReaderOp);
class ReadFileOp : public OpKernel {
public:
......
......@@ -251,6 +251,41 @@ shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("RandomShuffleQueueV2")
.Output("handle: resource")
.Attr("component_types: list(type) >= 1")
.Attr("shapes: list(shape) >= 0 = []")
.Attr("capacity: int = -1")
.Attr("min_after_dequeue: int = 0")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
A queue that randomizes the order of elements.
handle: The handle to the queue.
component_types: The type of each component in a value.
shapes: The shape of each component in a value. The length of this attr must
be either 0 or the same as the length of component_types. If the length of
this attr is 0, the shapes of queue elements are not constrained, and
only one element may be dequeued at a time.
capacity: The upper bound on the number of elements in this queue.
Negative numbers mean no limit.
min_after_dequeue: Dequeue will block unless there would be this
many elements after the dequeue or the queue is closed. This
ensures a minimum level of mixing of elements.
seed: If either seed or seed2 is set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, a random seed is used.
seed2: A second seed to avoid seed collision.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("FIFOQueue")
.Output("handle: Ref(string)")
.Attr("component_types: list(type) >= 1")
......@@ -277,6 +312,32 @@ shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("FIFOQueueV2")
.Output("handle: resource")
.Attr("component_types: list(type) >= 1")
.Attr("shapes: list(shape) >= 0 = []")
.Attr("capacity: int = -1")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
A queue that produces elements in first-in first-out order.
handle: The handle to the queue.
component_types: The type of each component in a value.
shapes: The shape of each component in a value. The length of this attr must
be either 0 or the same as the length of component_types. If the length of
this attr is 0, the shapes of queue elements are not constrained, and
only one element may be dequeued at a time.
capacity: The upper bound on the number of elements in this queue.
Negative numbers mean no limit.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("PaddingFIFOQueue")
.Output("handle: Ref(string)")
.Attr("component_types: list(type) >= 1")
......@@ -311,6 +372,40 @@ shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("PaddingFIFOQueueV2")
.Output("handle: resource")
.Attr("component_types: list(type) >= 1")
.Attr("shapes: list(shape) >= 0 = []")
.Attr("capacity: int = -1")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
A queue that produces elements in first-in first-out order.
Variable-size shapes are allowed by setting the corresponding shape dimensions
to 0 in the shape attr. In this case DequeueMany will pad up to the maximum
size of any given element in the minibatch. See below for details.
handle: The handle to the queue.
component_types: The type of each component in a value.
shapes: The shape of each component in a value. The length of this attr must
be either 0 or the same as the length of component_types.
Shapes of fixed rank but variable size are allowed by setting
any shape dimension to -1. In this case, the inputs' shape may vary along
the given dimension, and DequeueMany will pad the given dimension with
zeros up to the maximum shape of all elements in the given batch.
If the length of this attr is 0, different queue elements may have
different ranks and shapes, but only one element may be dequeued at a time.
capacity: The upper bound on the number of elements in this queue.
Negative numbers mean no limit.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("PriorityQueue")
.Output("handle: Ref(string)")
.Attr("component_types: list(type) >= 0 = []")
......@@ -343,6 +438,45 @@ shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
REGISTER_OP("PriorityQueueV2")
.Output("handle: resource")
.Attr("component_types: list(type) >= 0 = []")
.Attr("shapes: list(shape) >= 0")
.Attr("capacity: int = -1")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
A queue that produces elements sorted by the first component value.
Note that the PriorityQueue requires the first component of any element
to be a scalar int64, in addition to the other elements declared by
component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue
and DequeueMany) on a PriorityQueue will all require (resp. output) one extra
entry in their input (resp. output) lists.
handle: The handle to the queue.
component_types: The type of each component in a value.
shapes: The shape of each component in a value. The length of this attr must
be either 0 or the same as the length of component_types. If the length of
this attr is 0, the shapes of queue elements are not constrained, and
only one element may be dequeued at a time.
capacity: The upper bound on the number of elements in this queue.
Negative numbers mean no limit.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
// Bridges a V2 resource handle back to the legacy Ref(string) queue handle
// (see FakeQueueOp); registered only for backward compatibility.
REGISTER_OP("FakeQueue")
.Input("resource: resource")
.Output("handle: Ref(string)")
.SetIsStateful()
.SetShapeFn(TwoElementOutput)
.Doc("Deprecated. Do not use.");
REGISTER_OP("QueueEnqueue")
.Input("handle: Ref(string)")
.Input("components: Tcomponents")
......@@ -365,6 +499,28 @@ timeout_ms: If the queue is full, this operation will block for up to
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueEnqueueV2")
.Input("handle: resource")
.Input("components: Tcomponents")
.Attr("Tcomponents: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Enqueues a tuple of one or more tensors in the given queue.
The components input has k elements, which correspond to the components of
tuples stored in the given queue.
N.B. If the queue is full, this operation will block until the given
element has been enqueued (or 'timeout_ms' elapses, if specified).
handle: The handle to a queue.
components: One or more tensors from which the enqueued tensors should be taken.
timeout_ms: If the queue is full, this operation will block for up to
timeout_ms milliseconds.
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueEnqueueMany")
.Input("handle: Ref(string)")
.Input("components: Tcomponents")
......@@ -392,6 +548,33 @@ timeout_ms: If the queue is too full, this operation will block for up
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueEnqueueManyV2")
.Input("handle: resource")
.Input("components: Tcomponents")
.Attr("Tcomponents: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Enqueues zero or more tuples of one or more tensors in the given queue.
This operation slices each component tensor along the 0th dimension to
make multiple queue elements. All of the tuple components must have the
same size in the 0th dimension.
The components input has k elements, which correspond to the components of
tuples stored in the given queue.
N.B. If the queue is full, this operation will block until the given
elements have been enqueued (or 'timeout_ms' elapses, if specified).
handle: The handle to a queue.
components: One or more tensors from which the enqueued tensors should
be taken.
timeout_ms: If the queue is too full, this operation will block for up
to timeout_ms milliseconds.
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueDequeue")
.Input("handle: Ref(string)")
.Output("components: component_types")
......@@ -416,6 +599,30 @@ timeout_ms: If the queue is empty, this operation will block for up to
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueDequeueV2")
.Input("handle: resource")
.Output("components: component_types")
.Attr("component_types: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Dequeues a tuple of one or more tensors from the given queue.
This operation has k outputs, where k is the number of components
in the tuples stored in the given queue, and output i is the ith
component of the dequeued tuple.
N.B. If the queue is empty, this operation will block until an element
has been dequeued (or 'timeout_ms' elapses, if specified).
handle: The handle to a queue.
components: One or more tensors that were dequeued as a tuple.
component_types: The type of each component in a tuple.
timeout_ms: If the queue is empty, this operation will block for up to
timeout_ms milliseconds.
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueDequeueMany")
.Input("handle: Ref(string)")
.Input("n: int32")
......@@ -449,6 +656,39 @@ timeout_ms: If the queue has fewer than n elements, this operation
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueDequeueManyV2")
.Input("handle: resource")
.Input("n: int32")
.Output("components: component_types")
.Attr("component_types: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Dequeues n tuples of one or more tensors from the given queue.
If the queue is closed and there are fewer than n elements, then an
OutOfRange error is returned.
This operation concatenates queue-element component tensors along the
0th dimension to make a single component tensor. All of the components
in the dequeued tuple will have size n in the 0th dimension.
This operation has k outputs, where k is the number of components in
the tuples stored in the given queue, and output i is the ith
component of the dequeued tuple.
N.B. If the queue is empty, this operation will block until n elements
have been dequeued (or 'timeout_ms' elapses, if specified).
handle: The handle to a queue.
n: The number of tuples to dequeue.
components: One or more tensors that were dequeued as a tuple.
component_types: The type of each component in a tuple.
timeout_ms: If the queue has fewer than n elements, this operation
will block for up to timeout_ms milliseconds.
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueDequeueUpTo")
.Input("handle: Ref(string)")
.Input("n: int32")
......@@ -486,6 +726,43 @@ timeout_ms: If the queue has fewer than n elements, this operation
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueDequeueUpToV2")
.Input("handle: resource")
.Input("n: int32")
.Output("components: component_types")
.Attr("component_types: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Dequeues n tuples of one or more tensors from the given queue.
This operation is not supported by all queues. If a queue does not support
DequeueUpTo, then an Unimplemented error is returned.
If the queue is closed and there are more than 0 but less than n elements
remaining, then instead of returning an OutOfRange error like
QueueDequeueMany, less than `n` elements are returned immediately. If the queue
is closed and there are 0 elements left in the queue, then an OutOfRange
error is returned just like in QueueDequeueMany. Otherwise the behavior
is identical to QueueDequeueMany:
This operation concatenates queue-element component tensors along the
0th dimension to make a single component tensor. All of the components
in the dequeued tuple will have size n in the 0th dimension.
This operation has k outputs, where k is the number of components in
the tuples stored in the given queue, and output i is the ith
component of the dequeued tuple.
handle: The handle to a queue.
n: The number of tuples to dequeue.
components: One or more tensors that were dequeued as a tuple.
component_types: The type of each component in a tuple.
timeout_ms: If the queue has fewer than n elements, this operation
will block for up to timeout_ms milliseconds.
Note: This option is not supported yet.
)doc");
REGISTER_OP("QueueClose")
.Input("handle: Ref(string)")
.SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
......@@ -504,6 +781,24 @@ cancel_pending_enqueues: If true, all pending enqueue requests that are
blocked on the given queue will be cancelled.
)doc");
REGISTER_OP("QueueCloseV2")
.Input("handle: resource")
.SetShapeFn(shape_inference::NoOutputs)
.Attr("cancel_pending_enqueues: bool = false")
.Doc(R"doc(
Closes the given queue.
This operation signals that no more elements will be enqueued in the
given queue. Subsequent Enqueue(Many) operations will fail.
Subsequent Dequeue(Many) operations will continue to succeed if
sufficient elements remain in the queue. Subsequent Dequeue(Many)
operations that would block will fail immediately.
handle: The handle to a queue.
cancel_pending_enqueues: If true, all pending enqueue requests that are
blocked on the given queue will be cancelled.
)doc");
REGISTER_OP("QueueSize")
.Input("handle: Ref(string)")
.Output("size: int32")
......@@ -515,6 +810,17 @@ handle: The handle to a queue.
size: The number of elements in the given queue.
)doc");
// V2 (resource-handle) variant of QueueSize.
REGISTER_OP("QueueSizeV2")
.Input("handle: resource")
.Output("size: int32")
// The size output is always an int32 scalar regardless of the handle's
// inferred shape, so declare it explicitly instead of echoing the input
// shape (UnchangedShape would propagate an unknown input shape).
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Computes the number of elements in the given queue.
handle: The handle to a queue.
size: The number of elements in the given queue.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("AccumulatorNumAccumulated")
......
......@@ -378,6 +378,25 @@ shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
// V2 reader: emits a DT_RESOURCE handle instead of the legacy Ref(string)
// handle emitted by WholeFileReader.
REGISTER_OP("WholeFileReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    // Stateful: prevents constant folding / CSE from merging or caching the
    // reader-creation op.
    .SetIsStateful()
    // The resource handle itself is a scalar tensor.
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
A Reader that outputs the entire contents of a file as a value.
To use, enqueue filenames in a Queue. The output of ReaderRead will
be a filename (key) and the contents of that file (value).
reader_handle: The handle to reference the Reader.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
REGISTER_OP("TextLineReader")
.Output("reader_handle: Ref(string)")
.Attr("skip_header_lines: int = 0")
......@@ -396,6 +415,24 @@ shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
// V2 (resource-handle) variant of TextLineReader.
REGISTER_OP("TextLineReaderV2")
    .Output("reader_handle: resource")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    // Stateful: prevents constant folding / CSE from merging or caching the
    // reader-creation op.
    .SetIsStateful()
    // The resource handle itself is a scalar tensor.
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
reader_handle: The handle to reference the Reader.
skip_header_lines: Number of lines to skip from the beginning of every file.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
REGISTER_OP("FixedLengthRecordReader")
.Output("reader_handle: Ref(string)")
.Attr("header_bytes: int = 0")
......@@ -415,6 +452,25 @@ shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
// V2 (resource-handle) variant of FixedLengthRecordReader.
// Fix: the header_bytes/record_bytes/footer_bytes attrs were declared but
// missing from the .Doc text, unlike sibling readers (e.g. TextLineReaderV2
// documents skip_header_lines); their doc lines are added below.
REGISTER_OP("FixedLengthRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    // Stateful: prevents constant folding / CSE from merging or caching the
    // reader-creation op.
    .SetIsStateful()
    // The resource handle itself is a scalar tensor.
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
A Reader that outputs fixed-length records from a file.
reader_handle: The handle to reference the Reader.
header_bytes: Number of bytes in the header, defaults to 0.
record_bytes: Number of bytes in the record.
footer_bytes: Number of bytes in the footer, defaults to 0.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
REGISTER_OP("TFRecordReader")
.Output("reader_handle: Ref(string)")
.Attr("container: string = ''")
......@@ -432,6 +488,23 @@ shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
// V2 (resource-handle) variant of TFRecordReader.
// Fix: the compression_type attr was declared but missing from the .Doc
// text, while container and shared_name were documented; its doc line is
// added below.
REGISTER_OP("TFRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    // Stateful: prevents constant folding / CSE from merging or caching the
    // reader-creation op.
    .SetIsStateful()
    // The resource handle itself is a scalar tensor.
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
A Reader that outputs the records from a TensorFlow Records file.
reader_handle: The handle to reference the Reader.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
compression_type: The type of compression for the file. Defaults to ''
(no compression).
)doc");
REGISTER_OP("IdentityReader")
.Output("reader_handle: Ref(string)")
.Attr("container: string = ''")
......@@ -451,6 +524,25 @@ shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
// V2 (resource-handle) variant of IdentityReader.
REGISTER_OP("IdentityReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    // Stateful: prevents constant folding / CSE from merging or caching the
    // reader-creation op.
    .SetIsStateful()
    // The resource handle itself is a scalar tensor.
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
A Reader that outputs the queued work as both the key and value.
To use, enqueue strings in a Queue. ReaderRead will take the front
work string and output (work, work).
reader_handle: The handle to reference the Reader.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc");
// Ops that operate on Readers ------------------------------------------------
REGISTER_OP("ReaderRead")
......@@ -472,6 +564,25 @@ key: A scalar.
value: A scalar.
)doc");
// V2 variant of ReaderRead: both the reader and the work queue are passed
// as DT_RESOURCE handles rather than Ref(string) tensors.
REGISTER_OP("ReaderReadV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Output("key: string")
    .Output("value: string")
    // Both handle inputs and both string outputs are scalars.
    .SetShapeFn(ScalarInputsAndOutputs)
    .Doc(R"doc(
Returns the next record (key, value pair) produced by a Reader.
Will dequeue from the input queue if necessary (e.g. when the
Reader needs to start reading from a new file since it has finished
with the previous file).
reader_handle: Handle to a Reader.
queue_handle: Handle to a Queue, with string work items.
key: A scalar.
value: A scalar.
)doc");
REGISTER_OP("ReaderReadUpTo")
.Input("reader_handle: Ref(string)")
.Input("queue_handle: Ref(string)")
......@@ -503,6 +614,37 @@ keys: A 1-D tensor.
values: A 1-D tensor.
)doc");
// V2 variant of ReaderReadUpTo: reader and queue are DT_RESOURCE handles.
REGISTER_OP("ReaderReadUpToV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Input("num_records: int64")
    .Output("keys: string")
    .Output("values: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      // All three inputs — the two handles and the record count — must be
      // scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      // Outputs are 1-D vectors of statically unknown length: the reader may
      // produce fewer than num_records records (see the op doc below).
      ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
      c->set_output(0, out);
      c->set_output(1, out);
      return Status::OK();
    })
    .Doc(R"doc(
Returns up to `num_records` (key, value) pairs produced by a Reader.
Will dequeue from the input queue if necessary (e.g. when the
Reader needs to start reading from a new file since it has finished
with the previous file).
It may return less than `num_records` even before the last batch.
reader_handle: Handle to a `Reader`.
queue_handle: Handle to a `Queue`, with string work items.
num_records: number of records to read from `Reader`.
keys: A 1-D tensor.
values: A 1-D tensor.
)doc");
REGISTER_OP("ReaderNumRecordsProduced")
.Input("reader_handle: Ref(string)")
.Output("records_produced: int64")
......@@ -516,6 +658,19 @@ succeeded.
reader_handle: Handle to a Reader.
)doc");
// V2 (resource-handle) variant of ReaderNumRecordsProduced.
REGISTER_OP("ReaderNumRecordsProducedV2")
    .Input("reader_handle: resource")
    .Output("records_produced: int64")
    // Scalar handle in, scalar count out.
    .SetShapeFn(ScalarInputsAndOutputs)
    .Doc(R"doc(
Returns the number of records this Reader has produced.
This is the same as the number of ReaderRead executions that have
succeeded.
reader_handle: Handle to a Reader.
)doc");
REGISTER_OP("ReaderNumWorkUnitsCompleted")
.Input("reader_handle: Ref(string)")
.Output("units_completed: int64")
......@@ -526,6 +681,16 @@ Returns the number of work units this Reader has finished processing.
reader_handle: Handle to a Reader.
)doc");
// V2 (resource-handle) variant of ReaderNumWorkUnitsCompleted.
REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
    .Input("reader_handle: resource")
    .Output("units_completed: int64")
    // Scalar handle in, scalar count out.
    .SetShapeFn(ScalarInputsAndOutputs)
    .Doc(R"doc(
Returns the number of work units this Reader has finished processing.
reader_handle: Handle to a Reader.
)doc");
REGISTER_OP("ReaderSerializeState")
.Input("reader_handle: Ref(string)")
.Output("state: string")
......@@ -539,6 +704,19 @@ Unimplemented error.
reader_handle: Handle to a Reader.
)doc");
// V2 (resource-handle) variant of ReaderSerializeState.
REGISTER_OP("ReaderSerializeStateV2")
    .Input("reader_handle: resource")
    .Output("state: string")
    // Scalar handle in, scalar serialized-state string out.
    .SetShapeFn(ScalarInputsAndOutputs)
    .Doc(R"doc(
Produce a string tensor that encodes the state of a Reader.
Not all Readers support being serialized, so this can produce an
Unimplemented error.
reader_handle: Handle to a Reader.
)doc");
REGISTER_OP("ReaderRestoreState")
.Input("reader_handle: Ref(string)")
.Input("state: string")
......@@ -563,6 +741,26 @@ state: Result of a ReaderSerializeState of a Reader with type
matching reader_handle.
)doc");
// V2 (resource-handle) variant of ReaderRestoreState.
REGISTER_OP("ReaderRestoreStateV2")
    .Input("reader_handle: resource")
    .Input("state: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      // Both the reader handle and the serialized state must be scalars;
      // the op has no outputs.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return Status::OK();
    })
    .Doc(R"doc(
Restore a reader to a previously saved state.
Not all Readers support being restored, so this can produce an
Unimplemented error.
reader_handle: Handle to a Reader.
state: Result of a ReaderSerializeState of a Reader with type
matching reader_handle.
)doc");
REGISTER_OP("ReaderReset")
.Input("reader_handle: Ref(string)")
.SetShapeFn(TwoElementVectorAndScalarOutputs)
......@@ -572,6 +770,15 @@ Restore a Reader to its initial clean state.
reader_handle: Handle to a Reader.
)doc");
// V2 (resource-handle) variant of ReaderReset.
REGISTER_OP("ReaderResetV2")
    .Input("reader_handle: resource")
    // Scalar handle in; no outputs beyond the side effect of resetting.
    .SetShapeFn(ScalarInputsAndOutputs)
    .Doc(R"doc(
Restore a Reader to its initial clean state.
reader_handle: Handle to a Reader.
)doc");
// Other input Ops ----------------------------------------------------------
REGISTER_OP("ReadFile")
......
......@@ -297,6 +297,8 @@ class DType(object):
@property
def size(self):
  """Size in bytes of one element of this dtype.

  DT_RESOURCE handles have no numpy counterpart, so they are reported
  as occupying a single byte rather than going through np.dtype.
  """
  if self._type_enum != types_pb2.DT_RESOURCE:
    return np.dtype(self.as_numpy_dtype).itemsize
  return 1
# Define data type range of numpy dtype
......
......@@ -45,7 +45,7 @@ class FIFOQueueTest(test.TestCase):
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'FIFOQueue'
name:'Q' op:'FIFOQueueV2'
attr { key: 'component_types' value { list { type: DT_FLOAT } } }
attr { key: 'shapes' value { list {} } }
attr { key: 'capacity' value { i: 10 } }
......@@ -61,7 +61,7 @@ class FIFOQueueTest(test.TestCase):
name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'FIFOQueue'
name:'Q' op:'FIFOQueueV2'
attr { key: 'component_types' value { list {
type: DT_INT32 type : DT_FLOAT
} } }
......@@ -80,7 +80,7 @@ class FIFOQueueTest(test.TestCase):
name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'FIFOQueue'
name:'Q' op:'FIFOQueueV2'
attr { key: 'component_types' value { list {
type: DT_INT32 type : DT_FLOAT
} } }
......@@ -1184,44 +1184,44 @@ class FIFOQueueTest(test.TestCase):
with self.test_session():
q_a_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_a")
q_a_2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32, shared_name="q_a")
q_a_1.queue_ref.eval()
q_a_1.queue_ref.op.run()
with self.assertRaisesOpError("capacity"):
q_a_2.queue_ref.eval()
q_a_2.queue_ref.op.run()
q_b_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_b")
q_b_2 = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shared_name="q_b")
q_b_1.queue_ref.eval()
q_b_1.queue_ref.op.run()
with self.assertRaisesOpError("component types"):
q_b_2.queue_ref.eval()
q_b_2.queue_ref.op.run()
q_c_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_c")
q_c_2 = data_flow_ops.FIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_c")
q_c_1.queue_ref.eval()
q_c_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_c_2.queue_ref.eval()
q_c_2.queue_ref.op.run()
q_d_1 = data_flow_ops.FIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d")
q_d_2 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_d")
q_d_1.queue_ref.eval()
q_d_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_d_2.queue_ref.eval()
q_d_2.queue_ref.op.run()
q_e_1 = data_flow_ops.FIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e")
q_e_2 = data_flow_ops.FIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 4)], shared_name="q_e")
q_e_1.queue_ref.eval()
q_e_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_e_2.queue_ref.eval()
q_e_2.queue_ref.op.run()
q_f_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_f")
q_f_2 = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shared_name="q_f")
q_f_1.queue_ref.eval()
q_f_1.queue_ref.op.run()
with self.assertRaisesOpError("component types"):
q_f_2.queue_ref.eval()
q_f_2.queue_ref.op.run()
def testSelectQueue(self):
with self.test_session():
......@@ -1241,7 +1241,7 @@ class FIFOQueueTest(test.TestCase):
q1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
q2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32)
enq_q = data_flow_ops.FIFOQueue.from_list(3, [q1, q2])
with self.assertRaisesOpError("Index must be in the range"):
with self.assertRaisesOpError("is not in"):
enq_q.dequeue().eval()
def _blockingDequeue(self, sess, dequeue_op):
......@@ -1389,16 +1389,6 @@ class FIFOQueueTest(test.TestCase):
for (input_elem, output_elem) in zip(input_tuple, output_tuple):
self.assertAllEqual(input_elem, output_elem)
def testDeviceColocation(self):
with ops.device("/job:ps"):
q = data_flow_ops.FIFOQueue(32, [dtypes_lib.int32], name="q")
with ops.device("/job:worker/task:7"):
dequeued_t = q.dequeue()
self.assertDeviceEqual("/job:ps", dequeued_t.device)
self.assertEqual([b"loc:@q"], dequeued_t.op.colocation_groups())
class FIFOQueueDictTest(test.TestCase):
......@@ -1411,7 +1401,7 @@ class FIFOQueueDictTest(test.TestCase):
name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'FIFOQueue'
name:'Q' op:'FIFOQueueV2'
attr { key: 'component_types' value { list {
type: DT_INT32 type : DT_FLOAT
} } }
......@@ -1432,7 +1422,7 @@ class FIFOQueueDictTest(test.TestCase):
name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'FIFOQueue'
name:'Q' op:'FIFOQueueV2'
attr { key: 'component_types' value { list {
type: DT_INT32 type : DT_FLOAT
} } }
......
......@@ -42,7 +42,7 @@ class PaddingFIFOQueueTest(test.TestCase):
10, dtypes_lib.float32, ((None,),), name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'PaddingFIFOQueue'
name:'Q' op:'PaddingFIFOQueueV2'
attr { key: 'component_types' value { list { type: DT_FLOAT } } }
attr { key: 'shapes' value { list { shape { dim { size: -1 } } } } }
attr { key: 'capacity' value { i: 10 } }
......@@ -58,7 +58,7 @@ class PaddingFIFOQueueTest(test.TestCase):
name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'PaddingFIFOQueue'
name:'Q' op:'PaddingFIFOQueueV2'
attr { key: 'component_types' value { list {
type: DT_INT32 type : DT_FLOAT
} } }
......@@ -77,7 +77,7 @@ class PaddingFIFOQueueTest(test.TestCase):
name="Q")
self.assertTrue(isinstance(q.queue_ref, ops.Tensor))
self.assertProtoEquals("""
name:'Q' op:'PaddingFIFOQueue'
name:'Q' op:'PaddingFIFOQueueV2'
attr { key: 'component_types' value { list {
type: DT_INT32 type : DT_FLOAT
} } }
......@@ -1307,50 +1307,50 @@ class PaddingFIFOQueueTest(test.TestCase):
10, dtypes_lib.float32, ((),), shared_name="q_a")
q_a_2 = data_flow_ops.PaddingFIFOQueue(
15, dtypes_lib.float32, ((),), shared_name="q_a")
q_a_1.queue_ref.eval()
q_a_1.queue_ref.op.run()
with self.assertRaisesOpError("capacity"):
q_a_2.queue_ref.eval()
q_a_2.queue_ref.op.run()
q_b_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, ((),), shared_name="q_b")
q_b_2 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.int32, ((),), shared_name="q_b")
q_b_1.queue_ref.eval()
q_b_1.queue_ref.op.run()
with self.assertRaisesOpError("component types"):
q_b_2.queue_ref.eval()
q_b_2.queue_ref.op.run()
q_c_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, ((),), shared_name="q_c")
q_c_2 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_c")
q_c_1.queue_ref.eval()
q_c_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_c_2.queue_ref.eval()
q_c_2.queue_ref.op.run()
q_d_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d")
q_d_2 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, ((),), shared_name="q_d")
q_d_1.queue_ref.eval()
q_d_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_d_2.queue_ref.eval()
q_d_2.queue_ref.op.run()
q_e_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e")
q_e_2 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=[(1, 1, 2, 4)], shared_name="q_e")
q_e_1.queue_ref.eval()
q_e_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_e_2.queue_ref.eval()
q_e_2.queue_ref.op.run()
q_f_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, ((),), shared_name="q_f")
q_f_2 = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), ((), ()),
shared_name="q_f")
q_f_1.queue_ref.eval()
q_f_1.queue_ref.op.run()
with self.assertRaisesOpError("component types"):
q_f_2.queue_ref.eval()
q_f_2.queue_ref.op.run()
def testSelectQueue(self):
with self.test_session():
......@@ -1371,7 +1371,7 @@ class PaddingFIFOQueueTest(test.TestCase):
q1 = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
q2 = data_flow_ops.PaddingFIFOQueue(15, dtypes_lib.float32, ((),))
enq_q = data_flow_ops.PaddingFIFOQueue.from_list(3, [q1, q2])
with self.assertRaisesOpError("Index must be in the range"):
with self.assertRaisesOpError("is not in"):
enq_q.dequeue().eval()
def _blockingDequeue(self, sess, dequeue_op):
......
......@@ -1125,65 +1125,65 @@ class RandomShuffleQueueTest(test.TestCase):
10, 5, dtypes_lib.float32, shared_name="q_a")
q_a_2 = data_flow_ops.RandomShuffleQueue(
15, 5, dtypes_lib.float32, shared_name="q_a")
q_a_1.queue_ref.eval()
q_a_1.queue_ref.op.run()
with self.assertRaisesOpError("capacity"):
q_a_2.queue_ref.eval()
q_a_2.queue_ref.op.run()
q_b_1 = data_flow_ops.RandomShuffleQueue(
10, 0, dtypes_lib.float32, shared_name="q_b")
q_b_2 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shared_name="q_b")
q_b_1.queue_ref.eval()
q_b_1.queue_ref.op.run()
with self.assertRaisesOpError("min_after_dequeue"):
q_b_2.queue_ref.eval()
q_b_2.queue_ref.op.run()
q_c_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shared_name="q_c")
q_c_2 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.int32, shared_name="q_c")
q_c_1.queue_ref.eval()
q_c_1.queue_ref.op.run()
with self.assertRaisesOpError("component types"):
q_c_2.queue_ref.eval()
q_c_2.queue_ref.op.run()
q_d_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shared_name="q_d")
q_d_2 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d")
q_d_1.queue_ref.eval()
q_d_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_d_2.queue_ref.eval()
q_d_2.queue_ref.op.run()
q_e_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e")
q_e_2 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shared_name="q_e")
q_e_1.queue_ref.eval()
q_e_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_e_2.queue_ref.eval()
q_e_2.queue_ref.op.run()
q_f_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 3)], shared_name="q_f")
q_f_2 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shapes=[(1, 1, 2, 4)], shared_name="q_f")
q_f_1.queue_ref.eval()
q_f_1.queue_ref.op.run()
with self.assertRaisesOpError("component shapes"):
q_f_2.queue_ref.eval()
q_f_2.queue_ref.op.run()
q_g_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shared_name="q_g")
q_g_2 = data_flow_ops.RandomShuffleQueue(
10, 5, (dtypes_lib.float32, dtypes_lib.int32), shared_name="q_g")
q_g_1.queue_ref.eval()
q_g_1.queue_ref.op.run()
with self.assertRaisesOpError("component types"):
q_g_2.queue_ref.eval()
q_g_2.queue_ref.op.run()
q_h_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, seed=12, shared_name="q_h")
q_h_2 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, seed=21, shared_name="q_h")
q_h_1.queue_ref.eval()
q_h_1.queue_ref.op.run()
with self.assertRaisesOpError("random seeds"):
q_h_2.queue_ref.eval()
q_h_2.queue_ref.op.run()
def testSelectQueue(self):
with self.test_session():
......@@ -1204,7 +1204,7 @@ class RandomShuffleQueueTest(test.TestCase):
q1 = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
q2 = data_flow_ops.RandomShuffleQueue(15, 0, dtypes_lib.float32)
enq_q = data_flow_ops.RandomShuffleQueue.from_list(3, [q1, q2])
with self.assertRaisesOpError("Index must be in the range"):
with self.assertRaisesOpError("is not in"):
enq_q.dequeue().eval()
def _blockingDequeue(self, sess, dequeue_op):
......
......@@ -205,8 +205,8 @@ class QueueBase(object):
reduced_shapes = [
six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)]
queue_refs = [x.queue_ref for x in queues]
selected_queue = control_flow_ops.ref_select(index, queue_refs)
queue_refs = array_ops.stack([x.queue_ref for x in queues])
selected_queue = array_ops.gather(queue_refs, index)
return QueueBase(dtypes=dtypes, shapes=reduced_shapes, names=names,
queue_ref=selected_queue)
......@@ -326,7 +326,12 @@ class QueueBase(object):
for val, shape in zip(vals, self._shapes):
val.get_shape().assert_is_compatible_with(shape)
return gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=scope)
if self._queue_ref.dtype == _dtypes.resource:
return gen_data_flow_ops._queue_enqueue_v2(
self._queue_ref, vals, name=scope)
else:
return gen_data_flow_ops._queue_enqueue(
self._queue_ref, vals, name=scope)
def enqueue_many(self, vals, name=None):
"""Enqueues zero or more elements to this queue.
......@@ -367,7 +372,7 @@ class QueueBase(object):
val.get_shape().with_rank_at_least(1)[0])
val.get_shape()[1:].assert_is_compatible_with(shape)
return gen_data_flow_ops._queue_enqueue_many(
return gen_data_flow_ops._queue_enqueue_many_v2(
self._queue_ref, vals, name=scope)
def _dequeue_return_value(self, tensors):
......@@ -415,8 +420,12 @@ class QueueBase(object):
"""
if name is None:
name = "%s_Dequeue" % self._name
ret = gen_data_flow_ops._queue_dequeue(
self._queue_ref, self._dtypes, name=name)
if self._queue_ref.dtype == _dtypes.resource:
ret = gen_data_flow_ops._queue_dequeue_v2(
self._queue_ref, self._dtypes, name=name)
else:
ret = gen_data_flow_ops._queue_dequeue(
self._queue_ref, self._dtypes, name=name)
# NOTE(mrry): Not using a shape function because we need access to
# the `QueueBase` object.
......@@ -454,7 +463,7 @@ class QueueBase(object):
if name is None:
name = "%s_DequeueMany" % self._name
ret = gen_data_flow_ops._queue_dequeue_many(
ret = gen_data_flow_ops._queue_dequeue_many_v2(
self._queue_ref, n=n, component_types=self._dtypes, name=name)
# NOTE(mrry): Not using a shape function because we need access to
......@@ -495,7 +504,7 @@ class QueueBase(object):
if name is None:
name = "%s_DequeueUpTo" % self._name
ret = gen_data_flow_ops._queue_dequeue_up_to(
ret = gen_data_flow_ops._queue_dequeue_up_to_v2(
self._queue_ref, n=n, component_types=self._dtypes, name=name)
# NOTE(mrry): Not using a shape function because we need access to
......@@ -529,9 +538,14 @@ class QueueBase(object):
"""
if name is None:
name = "%s_Close" % self._name
return gen_data_flow_ops._queue_close(
self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
if self._queue_ref.dtype == _dtypes.resource:
return gen_data_flow_ops._queue_close_v2(
self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
else:
return gen_data_flow_ops._queue_close(
self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
def size(self, name=None):
"""Compute the number of elements in this queue.
......@@ -544,7 +558,10 @@ class QueueBase(object):
"""
if name is None:
name = "%s_Size" % self._name
return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
if self._queue_ref.dtype == _dtypes.resource:
return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name)
else:
return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
class RandomShuffleQueue(QueueBase):
......@@ -614,7 +631,7 @@ class RandomShuffleQueue(QueueBase):
# the id of the last op created.)
string = (str(seed1) + shared_name).encode("utf-8")
seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
queue_ref = gen_data_flow_ops._random_shuffle_queue(
queue_ref = gen_data_flow_ops._random_shuffle_queue_v2(
component_types=dtypes, shapes=shapes, capacity=capacity,
min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
shared_name=shared_name, name=name)
......@@ -665,7 +682,7 @@ class FIFOQueue(QueueBase):
dtypes = _as_type_list(dtypes)
shapes = _as_shape_list(shapes, dtypes)
names = _as_name_list(names, dtypes)
queue_ref = gen_data_flow_ops._fifo_queue(
queue_ref = gen_data_flow_ops._fifo_queue_v2(
component_types=dtypes, shapes=shapes, capacity=capacity,
shared_name=shared_name, name=name)
......@@ -732,7 +749,7 @@ class PaddingFIFOQueue(QueueBase):
"but received %d dtypes and %d shapes."
% (len(dtypes), len(shapes)))
queue_ref = gen_data_flow_ops._padding_fifo_queue(
queue_ref = gen_data_flow_ops._padding_fifo_queue_v2(
component_types=dtypes, shapes=shapes, capacity=capacity,
shared_name=shared_name, name=name)
......@@ -788,7 +805,7 @@ class PriorityQueue(QueueBase):
types = _as_type_list(types)
shapes = _as_shape_list(shapes, types)
queue_ref = gen_data_flow_ops._priority_queue(
queue_ref = gen_data_flow_ops._priority_queue_v2(
component_types=types, shapes=shapes, capacity=capacity,
shared_name=shared_name, name=name)
......
......@@ -58,8 +58,12 @@ BarrierIncompleteSize
BarrierInsertMany
BarrierReadySize
BarrierTakeMany
PriorityQueue
DeleteSessionTensor
FakeQueue
FIFOQueue
FIFOQueueV2
GetSessionHandle
GetSessionTensor
HashTable
InitializeTable
InitializeTableFromTextFile
......@@ -75,41 +79,52 @@ Mutex
MutexAcquire
MutexRelease
PaddingFIFOQueue
PaddingFIFOQueueV2
PriorityQueue
PriorityQueueV2
QueueClose
QueueCloseV2
QueueDequeue
QueueDequeueV2
QueueDequeueMany
QueueDequeueManyV2
QueueDequeueUpTo
QueueDequeueUpToV2
QueueEnqueue
QueueEnqueueV2
QueueEnqueueMany
QueueEnqueueManyV2
QueueSize
QueueSizeV2
RandomShuffleQueue
RandomShuffleQueueV2
Stack
StackClose
StackPop
StackPush
StackClose
TensorArray
TensorArrayClose
TensorArrayConcat
TensorArrayGather
TensorArrayGrad
TensorArrayRead
TensorArrayPack
TensorArrayScatter
TensorArraySize
TensorArraySplit
TensorArrayUnpack
TensorArrayWrite
TensorArrayV2
TensorArrayCloseV2
TensorArrayConcat
TensorArrayConcatV2
TensorArrayGather
TensorArrayGatherV2
TensorArrayGrad
TensorArrayGradV2
TensorArrayReadV2
TensorArrayPack
TensorArrayPackV2
TensorArrayRead
TensorArrayReadV2
TensorArrayScatter
TensorArrayScatterV2
TensorArraySize
TensorArraySizeV2
TensorArraySplit
TensorArraySplitV2
TensorArrayUnpack
TensorArrayUnpackV2
TensorArrayV2
TensorArrayWrite
TensorArrayWriteV2
TensorArrayV3
TensorArrayCloseV3
......@@ -123,9 +138,6 @@ TensorArraySizeV3
TensorArraySplitV3
TensorArrayUnpackV3
TensorArrayWriteV3
GetSessionHandle
GetSessionTensor
DeleteSessionTensor
# functional_ops
SymbolicGradient
......@@ -150,6 +162,18 @@ ReaderReset
ReaderRestoreState
ReaderSerializeState
ReaderWorkQueueLength
FixedLengthRecordReaderV2
IdentityReaderV2
ReaderCloseV2
ReaderEnqueueWorkV2
ReaderNumRecordsProducedV2
ReaderNumWorkUnitsCompletedV2
ReaderReadV2
ReaderReadUpToV2
ReaderResetV2
ReaderRestoreStateV2
ReaderSerializeStateV2
ReaderWorkQueueLengthV2
Restore
RestoreSlice
Save
......@@ -159,6 +183,9 @@ ShardedFilespec
TextLineReader
TFRecordReader
WholeFileReader
TextLineReaderV2
TFRecordReaderV2
WholeFileReaderV2
# linalg_ops
BatchCholesky
......
......@@ -150,6 +150,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_io_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
......@@ -267,7 +268,13 @@ class ReaderBase(object):
queue_ref = queue
else:
queue_ref = queue.queue_ref
return gen_io_ops._reader_read(self._reader_ref, queue_ref, name=name)
if self._reader_ref.dtype == dtypes.resource:
return gen_io_ops._reader_read_v2(self._reader_ref, queue_ref, name=name)
else:
# For compatibility with pre-resource queues, create a ref(string) tensor
# which can be looked up as the same queue by a resource manager.
old_queue_op = gen_data_flow_ops._fake_queue(queue_ref)
return gen_io_ops._reader_read(self._reader_ref, old_queue_op, name=name)
def read_up_to(self, queue, num_records,  # pylint: disable=invalid-name
               name=None):
  """Returns up to num_records (key, value) pairs produced by a reader.

  Will dequeue a work unit from queue if necessary (e.g., when the
  Reader needs to start reading from a new file since it has
  finished with the previous file).
  It may return less than num_records even before the last batch.

  Args:
    queue: A Queue or a mutable string Tensor representing a handle
      to a Queue, with string work items.
    num_records: Number of records to read.
    name: A name for the operation (optional).

  Returns:
    A tuple of Tensors (keys, values).
    keys: A 1-D string Tensor.
    values: A 1-D string Tensor.
  """
  if isinstance(queue, ops.Tensor):
    queue_ref = queue
  else:
    queue_ref = queue.queue_ref
  if self._reader_ref.dtype == dtypes.resource:
    return gen_io_ops._reader_read_up_to_v2(self._reader_ref,
                                            queue_ref,
                                            num_records,
                                            name=name)
  else:
    # For compatibility with pre-resource queues, create a ref(string) tensor
    # which can be looked up as the same queue by a resource manager.
    old_queue_op = gen_data_flow_ops._fake_queue(queue_ref)
    # Bug fix: a ref(string) reader handle must go to the V1 kernel; the V2
    # op declares `reader_handle: resource` and would reject this handle
    # (mirrors the else branch of `read` above).
    return gen_io_ops._reader_read_up_to(self._reader_ref,
                                         old_queue_op,
                                         num_records,
                                         name=name)
def num_records_produced(self, name=None):
"""Returns the number of records this reader has produced.
......@@ -311,7 +327,12 @@ class ReaderBase(object):
An int64 Tensor.
"""
return gen_io_ops._reader_num_records_produced(self._reader_ref, name=name)
if self._reader_ref.dtype == dtypes.resource:
return gen_io_ops._reader_num_records_produced_v2(self._reader_ref,
name=name)
else:
return gen_io_ops._reader_num_records_produced(self._reader_ref,
name=name)
def num_work_units_completed(self, name=None):
"""Returns the number of work units this reader has finished processing.
......@@ -322,8 +343,12 @@ class ReaderBase(object):
Returns:
An int64 Tensor.
"""
return gen_io_ops._reader_num_work_units_completed(self._reader_ref,
name=name)
if self._reader_ref.dtype == dtypes.resource:
return gen_io_ops._reader_num_work_units_completed_v2(self._reader_ref,
name=name)
else:
return gen_io_ops._reader_num_work_units_completed(self._reader_ref,
name=name)
def serialize_state(self, name=None):
"""Produce a string tensor that encodes the state of a reader.
......@@ -337,7 +362,10 @@ class ReaderBase(object):
Returns:
A string Tensor.
"""
return gen_io_ops._reader_serialize_state(self._reader_ref, name=name)
if self._reader_ref.dtype == dtypes.resource:
return gen_io_ops._reader_serialize_state_v2(self._reader_ref, name=name)
else:
return gen_io_ops._reader_serialize_state(self._reader_ref, name=name)
def restore_state(self, state, name=None):
"""Restore a reader to a previously saved state.
......@@ -353,7 +381,12 @@ class ReaderBase(object):
Returns:
The created Operation.
"""
return gen_io_ops._reader_restore_state(self._reader_ref, state, name=name)
if self._reader_ref.dtype == dtypes.resource:
return gen_io_ops._reader_restore_state_v2(
self._reader_ref, state, name=name)
else:
return gen_io_ops._reader_restore_state(
self._reader_ref, state, name=name)
@property
def supports_serialize(self):
......@@ -369,7 +402,10 @@ class ReaderBase(object):
Returns:
The created Operation.
"""
return gen_io_ops._reader_reset(self._reader_ref, name=name)
if self._reader_ref.dtype == dtypes.resource:
return gen_io_ops._reader_reset_v2(self._reader_ref, name=name)
else:
return gen_io_ops._reader_reset(self._reader_ref, name=name)
ops.NotDifferentiable("ReaderRead")
......@@ -396,7 +432,7 @@ class WholeFileReader(ReaderBase):
Args:
name: A name for the operation (optional).
"""
rr = gen_io_ops._whole_file_reader(name=name)
rr = gen_io_ops._whole_file_reader_v2(name=name)
super(WholeFileReader, self).__init__(rr, supports_serialize=True)
......@@ -419,8 +455,8 @@ class TextLineReader(ReaderBase):
to skip from the beginning of every file.
name: A name for the operation (optional).
"""
rr = gen_io_ops._text_line_reader(skip_header_lines=skip_header_lines,
name=name)
rr = gen_io_ops._text_line_reader_v2(skip_header_lines=skip_header_lines,
name=name)
super(TextLineReader, self).__init__(rr)
......@@ -444,7 +480,7 @@ class FixedLengthRecordReader(ReaderBase):
footer_bytes: An optional int. Defaults to 0.
name: A name for the operation (optional).
"""
rr = gen_io_ops._fixed_length_record_reader(
rr = gen_io_ops._fixed_length_record_reader_v2(
record_bytes=record_bytes, header_bytes=header_bytes,
footer_bytes=footer_bytes, name=name)
super(FixedLengthRecordReader, self).__init__(rr)
......@@ -470,7 +506,7 @@ class TFRecordReader(ReaderBase):
compression_type = python_io.TFRecordOptions.get_compression_type_string(
options)
rr = gen_io_ops._tf_record_reader(
rr = gen_io_ops._tf_record_reader_v2(
name=name, compression_type=compression_type)
super(TFRecordReader, self).__init__(rr)
......@@ -493,7 +529,7 @@ class IdentityReader(ReaderBase):
Args:
name: A name for the operation (optional).
"""
rr = gen_io_ops._identity_reader(name=name)
rr = gen_io_ops._identity_reader_v2(name=name)
super(IdentityReader, self).__init__(rr, supports_serialize=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册