提交 d1a7879a 编写于 作者: E Edward Loper 提交者: TensorFlower Gardener

In NestedStructureCoder: fixed bug where a TensorSpec with name=None would get...

In NestedStructureCoder: fixed bug where a TensorSpec with name=None would get deserialized with name=''.

PiperOrigin-RevId: 257610414
上级 1fd772b3
......@@ -423,6 +423,7 @@ class _TensorSpecCodec(object):
return value.HasField("tensor_spec_value")
def do_decode(self, value, decode_fn):
name = value.tensor_spec_value.name
return tensor_spec.TensorSpec(
shape=decode_fn(
struct_pb2.StructuredValue(
......@@ -430,7 +431,7 @@ class _TensorSpecCodec(object):
dtype=decode_fn(
struct_pb2.StructuredValue(
tensor_dtype_value=value.tensor_spec_value.dtype)),
name=value.tensor_spec_value.name)
name=(name if name else None))
StructureCoder.register_codec(_TensorSpecCodec())
......@@ -171,6 +171,22 @@ class NestedStructureTest(test.TestCase):
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)
def testEncodeDecodeTensorSpecWithNoName(self):
structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)]
self.assertTrue(self._coder.can_encode(structure))
encoded = self._coder.encode_structure(structure)
expected = struct_pb2.StructuredValue()
expected_list = expected.list_value
expected_tensor_spec = expected_list.values.add().tensor_spec_value
expected_tensor_spec.shape.dim.add().size = 1
expected_tensor_spec.shape.dim.add().size = 2
expected_tensor_spec.shape.dim.add().size = 3
expected_tensor_spec.name = ""
expected_tensor_spec.dtype = dtypes.int64.as_datatype_enum
self.assertEqual(expected, encoded)
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)
def testNotEncodable(self):
class NotEncodable(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册