未验证 提交 035c7425 编写于 作者: Z Zhou Wei 提交者: GitHub

add API Tensor.item() to convert Tensor element to a Python scalar (#32634)

cherry-pick #32561
上级 cdfc34d2
......@@ -784,6 +784,70 @@ void BindImperative(py::module *m_ptr) {
return out;
}
})
.def(
"_getitem_from_offset",
[](std::shared_ptr<imperative::VarBase> &self, const py::args &args) {
const auto &tensor = self->Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor of %s is Empty, please check if it has no data.",
self->Name()));
const auto &tensor_dims = tensor.dims();
std::vector<size_t> dims(tensor_dims.size());
std::vector<size_t> strides(tensor_dims.size());
size_t numel = 1;
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
strides[i] = numel;
dims[i] = static_cast<size_t>(tensor_dims[i]);
numel *= dims[i];
}
size_t offset = 0;
if (args.empty()) {
PADDLE_ENFORCE_EQ(
numel, 1,
platform::errors::InvalidArgument(
"only one element tensors can be converted to Python "
"scalars when no input coordinates"));
} else if (args.size() == 1) {
offset = args[0].cast<size_t>();
PADDLE_ENFORCE_LT(
offset, numel,
platform::errors::InvalidArgument(
"index %d is out of bounds for size %d", offset, numel));
} else {
PADDLE_ENFORCE_EQ(args.size(), dims.size(),
platform::errors::InvalidArgument(
"incorrect number of indices for Tensor"));
for (size_t i = 0; i < args.size(); ++i) {
size_t index = args[i].cast<size_t>();
PADDLE_ENFORCE_LT(
index, dims[i],
platform::errors::InvalidArgument(
"index %d is out fo bounds for axis %d with size %d",
index, i, dims[i]));
offset += index * strides[i];
}
}
#define TENSOR_TO_PY_SCALAR(T, proto_type) \
if (tensor.type() == proto_type) { \
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(proto_type); \
T b = TensorGetElement<T>(tensor, offset); \
return py::array(py::dtype(py_dtype_str.c_str()), {}, {}, \
static_cast<void *>(&b)); \
}
_ForEachDataType_(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported tensor data type: %s",
framework::DataTypeToString(tensor.type())));
},
py::return_value_policy::copy)
.def("_inplace_version",
[](imperative::VarBase &self) -> uint32_t {
const auto &var = self.MutableVar();
......
......@@ -375,6 +375,49 @@ def monkey_patch_varbase():
"""
self.clear_gradient()
def item(self, *args):
"""
Convert one element Tensor to a Python scalar.
Args:
*args(int): The input coordinates. If it's single int, the data in the corresponding order of flattened Tensor will be returned.
Default: None, and it must be in the case where Tensor has only one element.
Returns(Python scalar): A Python scalar, whose dtype is corresponds to the dtype of Tensor.
Raises:
ValueError: If the Tensor has more than one element, there must be coordinates.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1)
print(x.item()) #1
print(type(x.item())) #<class 'int'>
x = paddle.to_tensor(1.0)
print(x.item()) #1.0
print(type(x.item())) #<class 'float'>
x = paddle.to_tensor(True)
print(x.item()) #True
print(type(x.item())) #<class 'bool'>
x = paddle.to_tensor(1+1j)
print(x.item()) #(1+1j)
print(type(x.item())) #<class 'complex'>
x = paddle.to_tensor([[1.1, 2.2, 3.3]])
print(x.item(2)) #3.3
print(x.item(0, 2)) #3.3
x = paddle.to_tensor([1, 2])
x.item() #ValueError: only one element tensor can be converted to Python scalar when no input coordinates.
"""
return self._getitem_from_offset(*args).item()
@property
def inplace_version(self):
"""
......@@ -462,7 +505,30 @@ def monkey_patch_varbase():
return self.__nonzero__()
def __array__(self, dtype=None):
return self.numpy().astype(dtype)
"""
Returns a numpy array shows the value of current Tensor.
Returns:
ndarray: The numpy value of current Tensor.
Returns type:
ndarray: dtype is same as current Tensor
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.randn([2, 2])
x_array = np.array(x)
print(type(x_array)) #<class 'numpy.ndarray'>
print(x_array.shape) #(2, 2)
"""
array = self.numpy()
if dtype:
array = array.astype(dtype)
return array
def __getitem__(self, item):
def contain_tensor(item):
......@@ -498,7 +564,7 @@ def monkey_patch_varbase():
("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor"), ("__array__", __array__),
("__getitem__", __getitem__)):
("__getitem__", __getitem__), ("item", item)):
setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
......
......@@ -143,6 +143,74 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(y.dtype, core.VarDesc.VarType.COMPLEX64)
self.assertEqual(y.shape, [2])
paddle.set_default_dtype('float32')
x = paddle.randn([3, 4])
x_array = np.array(x)
self.assertEqual(x_array.shape, x.numpy().shape)
self.assertEqual(x_array.dtype, x.numpy().dtype)
self.assertTrue(np.array_equal(x_array, x.numpy()))
x = paddle.to_tensor(1.0)
self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float))
x = paddle.randn([3, 2, 2])
self.assertTrue(isinstance(x.item(5), float))
self.assertTrue(isinstance(x.item(1, 0, 1), float))
self.assertEqual(x.item(5), x.item(1, 0, 1))
self.assertTrue(
np.array_equal(x.item(1, 0, 1), x.numpy().item(1, 0, 1)))
x = paddle.to_tensor([[1.111111, 2.222222, 3.333333]])
self.assertEqual(x.item(0, 2), x.item(2))
self.assertAlmostEqual(x.item(2), 3.333333)
self.assertTrue(isinstance(x.item(0, 2), float))
x = paddle.to_tensor(1.0, dtype='float64')
self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float))
x = paddle.to_tensor(1.0, dtype='float16')
self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float))
x = paddle.to_tensor(1, dtype='uint8')
self.assertEqual(x.item(), 1)
print(type(x.item()))
self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(1, dtype='int8')
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(1, dtype='int16')
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(1, dtype='int32')
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(1, dtype='int64')
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), long if six.PY2 else int))
x = paddle.to_tensor(True)
self.assertEqual(x.item(), True)
self.assertTrue(isinstance(x.item(), bool))
x = paddle.to_tensor(1 + 1j)
self.assertEqual(x.item(), 1 + 1j)
self.assertTrue(isinstance(x.item(), complex))
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item()
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item(18)
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item(1, 2)
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item(2, 1, 2)
with self.assertRaises(TypeError):
paddle.to_tensor('test')
with self.assertRaises(TypeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册