提交 5e176998 编写于 作者: V Vijay Vasudevan 提交者: TensorFlower Gardener

Add PlaceholderV2 and VariableV2 ops that can properly specify

scalar shapes during construction.  Placholder/Variable currently
treat [] as "unknown", because it predated scalar shapes.  So
we introduce new versions that are the same implementation underneath,
but can differentiate between unknown and scalar.
Change: 136659818
上级 860a4e2b
......@@ -96,7 +96,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
if (e->src_output() == id.second) {
to_remove.emplace_back(e);
} else if (e->src_output() == Graph::kControlSlot &&
n->def().op() == "Placeholder") {
(n->def().op() == "Placeholder" ||
n->def().op() == "PlaceholderV2")) {
// When feeding a Placeholder node, any outgoing control edges
// will be replaced with a control edge from the replacement
// recv_node.
......
......@@ -246,10 +246,14 @@ class PlaceholderOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_CPU), PlaceholderOp);
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_CPU),
PlaceholderOp);
// The following GPU kernel registration is used to address the situation that
// a placeholder is added in a GPU device context and soft placement is false.
// Since a placeholder should never be executed, adding these GPU kernels has
// no effect on graph execution.
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_GPU), PlaceholderOp);
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE_GPU),
PlaceholderOp);
} // namespace tensorflow
......@@ -2409,6 +2409,36 @@ shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
shape is unconstrained.
)doc");
// This version fixes an issue with the original version of Placeholder
// where the empty shape attribute "[]" was used to denote
// an unknown shape. This meant that scalars (added later) could
// not be represented natively. This new version fixes that
// limitation.
REGISTER_OP("PlaceholderV2")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("shape: shape")
.SetShapeFn([](InferenceContext* c) {
TensorShapeProto shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
ShapeHandle output;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape, &output));
c->set_output(0, output);
return Status::OK();
})
.Doc(R"doc(
A placeholder op for a value that will be fed into the computation.
N.B. This operation will fail with an error if it is executed. It is
intended as a way to represent a value that will always be fed, and to
provide attrs that enable the fed value to be checked at runtime.
output: A placeholder tensor that must be replaced using the feed mechanism.
dtype: The type of elements in the tensor.
shape: The shape of the tensor. The shape can be any partially-specified
shape. To be unconstrained, pass in a shape with unknown rank.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("PlaceholderWithDefault")
.Input("input: dtype")
......
......@@ -670,6 +670,54 @@ TEST(ArrayOpsTest, Placeholder_ShapeFn) {
}
}
TEST(ArrayOpsTest, PlaceholderV2_ShapeFn) {
{
// 2D shape
ShapeInferenceTestOp op("PlaceholderV2");
TensorShape shape({1, 2});
TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2")
.Attr("shape", shape)
.Attr("dtype", DT_FLOAT)
.Finalize(&op.node_def));
INFER_OK(op, "", "[1,2]");
}
{
// Scalar shapes are supported in V2.
ShapeInferenceTestOp op("PlaceholderV2");
TensorShape shape({});
TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2")
.Attr("shape", shape)
.Attr("dtype", DT_FLOAT)
.Finalize(&op.node_def));
INFER_OK(op, "", "[]");
}
{
// Partial shape
ShapeInferenceTestOp op("PlaceholderV2");
const int64 dims[2] = {1, -1};
PartialTensorShape shape;
TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 2, &shape));
TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2")
.Attr("shape", shape)
.Attr("dtype", DT_FLOAT)
.Finalize(&op.node_def));
INFER_OK(op, "", "[1,?]");
}
{
// Unknown shape
ShapeInferenceTestOp op("PlaceholderV2");
PartialTensorShape shape;
TF_ASSERT_OK(NodeDefBuilder("test", "PlaceholderV2")
.Attr("shape", shape)
.Attr("dtype", DT_FLOAT)
.Finalize(&op.node_def));
INFER_OK(op, "", "?");
}
}
TEST(ArrayOpsTest, Transpose_ShapeFn) {
ShapeInferenceTestOp op("Transpose");
op.input_tensors.resize(2);
......
......@@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import array_ops
class ConstantTest(tf.test.TestCase):
......@@ -607,6 +609,101 @@ class PlaceholderTest(tf.test.TestCase):
repr(c))
class PlaceholderV2Test(tf.test.TestCase):
def testDtype(self):
with self.test_session():
p = array_ops.placeholder_v2(tf.float32, shape=None, name="p")
p_identity = tf.identity(p)
feed_array = np.random.rand(10, 10)
self.assertAllClose(
p_identity.eval(feed_dict={
p: feed_array
}), feed_array)
with self.assertRaisesOpError(
"must feed a value for placeholder tensor 'p' with dtype float"):
p_identity.eval()
def testShape(self):
with self.test_session():
p = array_ops.placeholder_v2(tf.float32, shape=(10, 10), name="p")
p_identity = tf.identity(p)
feed_array = np.random.rand(10, 10)
self.assertAllClose(
p_identity.eval(feed_dict={
p: feed_array
}), feed_array)
with self.assertRaisesOpError(
"must feed a value for placeholder tensor 'p' with dtype float and "
r"shape \[10,10\]"):
p_identity.eval()
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "Cannot feed value of shape" in str(e)):
p_identity.eval(feed_dict={p: feed_array[:5, :5]})
def testUnknownShape(self):
with self.test_session():
p = array_ops.placeholder_v2(tf.float32, shape=None, name="p")
p_identity = tf.identity(p)
# can feed anything
feed_array = np.random.rand(10, 3)
self.assertAllClose(
p_identity.eval(feed_dict={
p: feed_array
}), feed_array)
feed_array = np.random.rand(4, 2, 5)
self.assertAllClose(
p_identity.eval(feed_dict={
p: feed_array
}), feed_array)
def testScalarShape(self):
with self.test_session():
p = array_ops.placeholder_v2(tf.float32, shape=[], name="p")
p_identity = tf.identity(p)
self.assertAllClose(p_identity.eval(feed_dict={p: 5}), 5)
def testPartialShape(self):
with self.test_session():
p = array_ops.placeholder_v2(tf.float32, shape=[None, 3], name="p")
p_identity = tf.identity(p)
feed_array = np.random.rand(10, 3)
self.assertAllClose(
p_identity.eval(feed_dict={
p: feed_array
}), feed_array)
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "Cannot feed value of shape" in str(e)):
p_identity.eval(feed_dict={p: feed_array[:5, :2]})
def testControlDependency(self):
with self.test_session():
p = array_ops.placeholder_v2(tf.int32, shape=[], name="p")
with tf.control_dependencies([p]):
c = tf.constant(5, tf.int32)
d = tf.mul(p, c)
val = np.array(2).astype(np.int)
self.assertEqual(10, d.eval(feed_dict={p: val}))
def testBadShape(self):
with self.assertRaises(ValueError):
array_ops.placeholder_v2(tf.float32, shape=(-1, 10))
def testTensorStr(self):
a = array_ops.placeholder_v2(tf.float32, shape=None, name="a")
self.assertEqual("<tf.Tensor 'a:0' shape=<unknown> dtype=float32>", repr(a))
b = array_ops.placeholder_v2(tf.int32, shape=(32, 40), name="b")
self.assertEqual("<tf.Tensor 'b:0' shape=(32, 40) dtype=int32>", repr(b))
c = array_ops.placeholder_v2(tf.qint32, shape=(32, None, 2), name="c")
self.assertEqual("<tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32>", repr(c))
class PlaceholderWithDefaultTest(tf.test.TestCase):
def testFullShape(self):
......
......@@ -1534,6 +1534,7 @@ def meshgrid(*args, **kwargs):
ops.RegisterShape("Placeholder")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("PlaceholderV2")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CheckNumerics")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Identity")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("RefIdentity")(common_shapes.call_cpp_shape_fn)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册