提交 c692c45d 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

tf numpy: some changes to ndarray constructor logic.

PiperOrigin-RevId: 317765968
Change-Id: Iea4338ad18707ff36fc49b450d0defad5c13a6a2
上级 1a3b7af3
......@@ -141,13 +141,12 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
' Tensor or np.ndarray.'.format(type(buffer)))
if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access
# TODO(srbs): NumPy allows this. Investigate if/how to support this.
raise ValueError('shape arg must match buffer.shape.')
if shape is not None:
buffer.set_shape(shape)
assert isinstance(buffer, ops.Tensor)
if dtype and dtype != buffer.dtype:
buffer = array_ops.bitcast(buffer, dtype)
buffer = math_ops.cast(buffer, dtype)
self._data = buffer
self._type_spec_internal = None
......
......@@ -22,6 +22,7 @@ import collections
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
......@@ -51,6 +52,19 @@ class ArrayTest(test.TestCase):
self.assertIs(a.dtype.type, np.bool_)
self.assertAllEqual([False, True], a)
def testConstructor(self):
t = constant_op.constant([[1], [1]])
a = np_arrays.ndarray(shape=(2, 1), buffer=t)
self.assertAllEqual(t, a)
self.assertEqual(dtypes.float64, a.dtype)
a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t)
self.assertAllEqual(t, a)
self.assertEqual(dtypes.int32, a.dtype)
with self.assertRaises(ValueError): # bad shape
_ = np_arrays.ndarray((2, 2), buffer=t)
def testNeg(self):
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
self.assertAllEqual([-1.0, -2.0], -a)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册