From 5e176998d92a64d78df57e9fb78582e5e7e4ebb6 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Wed, 19 Oct 2016 15:49:24 -0800 Subject: [PATCH] Add PlaceholderV2 and VariableV2 ops that can properly specify scalar shapes during construction. Placeholder/Variable currently treat [] as "unknown", because they predate scalar shapes. So we introduce new versions that are the same implementation underneath, but can differentiate between unknown and scalar. Change: 136659818 --- tensorflow/core/graph/subgraph.cc | 3 +- tensorflow/core/kernels/constant_op.cc | 4 + tensorflow/core/ops/array_ops.cc | 30 ++++++ tensorflow/core/ops/array_ops_test.cc | 48 +++++++++ .../python/kernel_tests/constant_op_test.py | 97 +++++++++++++++++++ tensorflow/python/ops/array_ops.py | 1 + 6 files changed, 182 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index c2978bbcf4a..58199140d28 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -96,7 +96,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, if (e->src_output() == id.second) { to_remove.emplace_back(e); } else if (e->src_output() == Graph::kControlSlot && - n->def().op() == "Placeholder") { + (n->def().op() == "Placeholder" || + n->def().op() == "PlaceholderV2")) { // When feeding a Placeholder node, any outgoing control edges // will be replaced with a control edge from the replacement // recv_node. 
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index be7a7a41a41..e066cacfc2e 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -246,10 +246,14 @@ class PlaceholderOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_CPU), PlaceholderOp); +REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_CPU), + PlaceholderOp); // The following GPU kernel registration is used to address the situation that // a placeholder is added in a GPU device context and soft placement is false. // Since a placeholder should never be executed, adding these GPU kernels has // no effect on graph execution. REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_GPU), PlaceholderOp); +REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_GPU), + PlaceholderOp); } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 12a2fb23fb8..c47638ca8d8 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2409,6 +2409,36 @@ shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the shape is unconstrained. )doc"); +// This version fixes an issue with the original version of Placeholder +// where the empty shape attribute "[]" was used to denote +// an unknown shape. This meant that scalars (added later) could +// not be represented natively. This new version fixes that +// limitation. +REGISTER_OP("PlaceholderV2") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("shape: shape") + .SetShapeFn([](InferenceContext* c) { + TensorShapeProto shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + ShapeHandle output; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape, &output)); + c->set_output(0, output); + return Status::OK(); + }) + .Doc(R"doc( +A placeholder op for a value that will be fed into the computation. + +N.B. 
This operation will fail with an error if it is executed. It is +intended as a way to represent a value that will always be fed, and to +provide attrs that enable the fed value to be checked at runtime. + +output: A placeholder tensor that must be replaced using the feed mechanism. +dtype: The type of elements in the tensor. +shape: The shape of the tensor. The shape can be any partially-specified + shape. To be unconstrained, pass in a shape with unknown rank. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("PlaceholderWithDefault") .Input("input: dtype") diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 7024ad03ccb..189c1f42e57 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -670,6 +670,54 @@ TEST(ArrayOpsTest, Placeholder_ShapeFn) { } } +TEST(ArrayOpsTest, PlaceholderV2_ShapeFn) { + { + // 2D shape + ShapeInferenceTestOp op("PlaceholderV2"); + TensorShape shape({1, 2}); + TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2") + .Attr("shape", shape) + .Attr("dtype", DT_FLOAT) + .Finalize(&op.node_def)); + INFER_OK(op, "", "[1,2]"); + } + + { + // Scalar shapes are supported in V2. 
+ ShapeInferenceTestOp op("PlaceholderV2"); + TensorShape shape({}); + TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2") + .Attr("shape", shape) + .Attr("dtype", DT_FLOAT) + .Finalize(&op.node_def)); + INFER_OK(op, "", "[]"); + } + + { + // Partial shape + ShapeInferenceTestOp op("PlaceholderV2"); + const int64 dims[2] = {1, -1}; + PartialTensorShape shape; + TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 2, &shape)); + TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2") + .Attr("shape", shape) + .Attr("dtype", DT_FLOAT) + .Finalize(&op.node_def)); + INFER_OK(op, "", "[1,?]"); + } + + { + // Unknown shape + ShapeInferenceTestOp op("PlaceholderV2"); + PartialTensorShape shape; + TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2") + .Attr("shape", shape) + .Attr("dtype", DT_FLOAT) + .Finalize(&op.node_def)); + INFER_OK(op, "", "?"); + } +} + TEST(ArrayOpsTest, Transpose_ShapeFn) { ShapeInferenceTestOp op("Transpose"); op.input_tensors.resize(2); diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 71ffe8c61df..14fe95dea66 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np import tensorflow as tf +from tensorflow.python.ops import array_ops + class ConstantTest(tf.test.TestCase): @@ -607,6 +609,101 @@ class PlaceholderTest(tf.test.TestCase): repr(c)) +class PlaceholderV2Test(tf.test.TestCase): + + def testDtype(self): + with self.test_session(): + p = array_ops.placeholder_v2(tf.float32, shape=None, name="p") + p_identity = tf.identity(p) + feed_array = np.random.rand(10, 10) + self.assertAllClose( + p_identity.eval(feed_dict={ + p: feed_array + }), feed_array) + + with self.assertRaisesOpError( + "must feed a value for placeholder tensor 'p' with dtype float"): + p_identity.eval() + + def testShape(self): + with self.test_session(): + p = 
array_ops.placeholder_v2(tf.float32, shape=(10, 10), name="p") + p_identity = tf.identity(p) + feed_array = np.random.rand(10, 10) + self.assertAllClose( + p_identity.eval(feed_dict={ + p: feed_array + }), feed_array) + + with self.assertRaisesOpError( + "must feed a value for placeholder tensor 'p' with dtype float and " + r"shape \[10,10\]"): + p_identity.eval() + + with self.assertRaisesWithPredicateMatch( + ValueError, lambda e: "Cannot feed value of shape" in str(e)): + p_identity.eval(feed_dict={p: feed_array[:5, :5]}) + + def testUnknownShape(self): + with self.test_session(): + p = array_ops.placeholder_v2(tf.float32, shape=None, name="p") + p_identity = tf.identity(p) + # can feed anything + feed_array = np.random.rand(10, 3) + self.assertAllClose( + p_identity.eval(feed_dict={ + p: feed_array + }), feed_array) + feed_array = np.random.rand(4, 2, 5) + self.assertAllClose( + p_identity.eval(feed_dict={ + p: feed_array + }), feed_array) + + def testScalarShape(self): + with self.test_session(): + p = array_ops.placeholder_v2(tf.float32, shape=[], name="p") + p_identity = tf.identity(p) + self.assertAllClose(p_identity.eval(feed_dict={p: 5}), 5) + + def testPartialShape(self): + with self.test_session(): + p = array_ops.placeholder_v2(tf.float32, shape=[None, 3], name="p") + p_identity = tf.identity(p) + feed_array = np.random.rand(10, 3) + self.assertAllClose( + p_identity.eval(feed_dict={ + p: feed_array + }), feed_array) + + with self.assertRaisesWithPredicateMatch( + ValueError, lambda e: "Cannot feed value of shape" in str(e)): + p_identity.eval(feed_dict={p: feed_array[:5, :2]}) + + def testControlDependency(self): + with self.test_session(): + p = array_ops.placeholder_v2(tf.int32, shape=[], name="p") + with tf.control_dependencies([p]): + c = tf.constant(5, tf.int32) + d = tf.mul(p, c) + val = np.array(2).astype(np.int) + self.assertEqual(10, d.eval(feed_dict={p: val})) + + def testBadShape(self): + with self.assertRaises(ValueError): + 
array_ops.placeholder_v2(tf.float32, shape=(-1, 10)) + + def testTensorStr(self): + a = array_ops.placeholder_v2(tf.float32, shape=None, name="a") + self.assertEqual("<tf.Tensor 'a:0' shape=<unknown> dtype=float32>", repr(a)) + + b = array_ops.placeholder_v2(tf.int32, shape=(32, 40), name="b") + self.assertEqual("<tf.Tensor 'b:0' shape=(32, 40) dtype=int32>", repr(b)) + + c = array_ops.placeholder_v2(tf.qint32, shape=(32, None, 2), name="c") + self.assertEqual("<tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32>", repr(c)) + + class PlaceholderWithDefaultTest(tf.test.TestCase): def testFullShape(self): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 6b4f371a8c9..725032fe592 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1534,6 +1534,7 @@ def meshgrid(*args, **kwargs): ops.RegisterShape("Placeholder")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("PlaceholderV2")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CheckNumerics")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("Identity")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("RefIdentity")(common_shapes.call_cpp_shape_fn) -- GitLab