From 7e2b60a4a5cdc4f022226e01ce6acdfbc83807f8 Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Fri, 30 Apr 2021 13:52:22 +0800 Subject: [PATCH] add API Tensor.item() to convert Tensor element to a Python scalar (#32561) --- paddle/fluid/pybind/imperative.cc | 64 +++++++++++++++++ .../fluid/dygraph/varbase_patch_methods.py | 70 ++++++++++++++++++- .../fluid/tests/unittests/test_var_base.py | 68 ++++++++++++++++++ 3 files changed, 200 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 93441eb52fe..450c992d411 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -784,6 +784,70 @@ void BindImperative(py::module *m_ptr) { return out; } }) + .def( + "_getitem_from_offset", + [](std::shared_ptr &self, const py::args &args) { + const auto &tensor = self->Var().Get(); + PADDLE_ENFORCE_EQ( + tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "Tensor of %s is Empty, please check if it has no data.", + self->Name())); + + const auto &tensor_dims = tensor.dims(); + + std::vector dims(tensor_dims.size()); + std::vector strides(tensor_dims.size()); + + size_t numel = 1; + for (int i = tensor_dims.size() - 1; i >= 0; --i) { + strides[i] = numel; + dims[i] = static_cast(tensor_dims[i]); + numel *= dims[i]; + } + size_t offset = 0; + if (args.empty()) { + PADDLE_ENFORCE_EQ( + numel, 1, + platform::errors::InvalidArgument( + "only one element tensors can be converted to Python " + "scalars when no input coordinates")); + } else if (args.size() == 1) { + offset = args[0].cast(); + PADDLE_ENFORCE_LT( + offset, numel, + platform::errors::InvalidArgument( + "index %d is out of bounds for size %d", offset, numel)); + } else { + PADDLE_ENFORCE_EQ(args.size(), dims.size(), + platform::errors::InvalidArgument( + "incorrect number of indices for Tensor")); + + for (size_t i = 0; i < args.size(); ++i) { + size_t index = args[i].cast(); + PADDLE_ENFORCE_LT( + index, dims[i], + platform::errors::InvalidArgument( + "index %d is out fo bounds for axis %d with size %d", + index, i, dims[i])); + offset += index * strides[i]; + } + } +#define TENSOR_TO_PY_SCALAR(T, proto_type) \ + if (tensor.type() == proto_type) { \ + std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(proto_type); \ + T b = TensorGetElement(tensor, offset); \ + return py::array(py::dtype(py_dtype_str.c_str()), {}, {}, \ + static_cast(&b)); \ + } + + _ForEachDataType_(TENSOR_TO_PY_SCALAR); +#undef TENSOR_TO_PY_SCALAR + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported tensor data type: %s", + framework::DataTypeToString(tensor.type()))); + }, + py::return_value_policy::copy) .def("_inplace_version", [](imperative::VarBase &self) -> uint32_t { const auto &var = self.MutableVar(); diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index dbc2b24aeea..bb84b2ca970 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -375,6 +375,49 @@ def monkey_patch_varbase(): """ self.clear_gradient() + def item(self, *args): + """ + Convert one element Tensor to a Python scalar. + + Args: + *args(int): The input coordinates. If it's single int, the data in the corresponding order of flattened Tensor will be returned. + Default: None, and it must be in the case where Tensor has only one element. + + Returns(Python scalar): A Python scalar, whose dtype is corresponds to the dtype of Tensor. + + Raises: + ValueError: If the Tensor has more than one element, there must be coordinates. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor(1) + print(x.item()) #1 + print(type(x.item())) # + + x = paddle.to_tensor(1.0) + print(x.item()) #1.0 + print(type(x.item())) # + + x = paddle.to_tensor(True) + print(x.item()) #True + print(type(x.item())) # + + x = paddle.to_tensor(1+1j) + print(x.item()) #(1+1j) + print(type(x.item())) # + + x = paddle.to_tensor([[1.1, 2.2, 3.3]]) + print(x.item(2)) #3.3 + print(x.item(0, 2)) #3.3 + + x = paddle.to_tensor([1, 2]) + x.item() #ValueError: only one element tensor can be converted to Python scalar when no input coordinates. + """ + return self._getitem_from_offset(*args).item() + @property def inplace_version(self): """ @@ -462,7 +505,30 @@ def monkey_patch_varbase(): return self.__nonzero__() def __array__(self, dtype=None): - return self.numpy().astype(dtype) + """ + Returns a numpy array shows the value of current Tensor. + + Returns: + ndarray: The numpy value of current Tensor. + + Returns type: + ndarray: dtype is same as current Tensor + + Examples: + .. code-block:: python + + import paddle + import numpy as np + x = paddle.randn([2, 2]) + x_array = np.array(x) + + print(type(x_array)) # + print(x_array.shape) #(2, 2) + """ + array = self.numpy() + if dtype: + array = array.astype(dtype) + return array def __getitem__(self, item): def contain_tensor(item): @@ -498,7 +564,7 @@ def monkey_patch_varbase(): ("__str__", __str__), ("__repr__", __str__), ("__deepcopy__", __deepcopy__), ("__module__", "paddle"), ("__name__", "Tensor"), ("__array__", __array__), - ("__getitem__", __getitem__)): + ("__getitem__", __getitem__), ("item", item)): setattr(core.VarBase, method_name, method) # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class. diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 8bf42390d1e..83f02b629d7 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -143,6 +143,74 @@ class TestVarBase(unittest.TestCase): self.assertEqual(y.dtype, core.VarDesc.VarType.COMPLEX64) self.assertEqual(y.shape, [2]) + paddle.set_default_dtype('float32') + x = paddle.randn([3, 4]) + x_array = np.array(x) + self.assertEqual(x_array.shape, x.numpy().shape) + self.assertEqual(x_array.dtype, x.numpy().dtype) + self.assertTrue(np.array_equal(x_array, x.numpy())) + + x = paddle.to_tensor(1.0) + self.assertEqual(x.item(), 1.0) + self.assertTrue(isinstance(x.item(), float)) + + x = paddle.randn([3, 2, 2]) + self.assertTrue(isinstance(x.item(5), float)) + self.assertTrue(isinstance(x.item(1, 0, 1), float)) + self.assertEqual(x.item(5), x.item(1, 0, 1)) + self.assertTrue( + np.array_equal(x.item(1, 0, 1), x.numpy().item(1, 0, 1))) + + x = paddle.to_tensor([[1.111111, 2.222222, 3.333333]]) + self.assertEqual(x.item(0, 2), x.item(2)) + self.assertAlmostEqual(x.item(2), 3.333333) + self.assertTrue(isinstance(x.item(0, 2), float)) + + x = paddle.to_tensor(1.0, dtype='float64') + self.assertEqual(x.item(), 1.0) + self.assertTrue(isinstance(x.item(), float)) + + x = paddle.to_tensor(1.0, dtype='float16') + self.assertEqual(x.item(), 1.0) + self.assertTrue(isinstance(x.item(), float)) + + x = paddle.to_tensor(1, dtype='uint8') + self.assertEqual(x.item(), 1) + print(type(x.item())) + self.assertTrue(isinstance(x.item(), int)) + + x = paddle.to_tensor(1, dtype='int8') + self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), int)) + + x = paddle.to_tensor(1, dtype='int16') + self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), int)) + + x = paddle.to_tensor(1, dtype='int32') + self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), int)) + + x = paddle.to_tensor(1, dtype='int64') + self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), long if six.PY2 else int)) + + x = paddle.to_tensor(True) + self.assertEqual(x.item(), True) + self.assertTrue(isinstance(x.item(), bool)) + + x = paddle.to_tensor(1 + 1j) + self.assertEqual(x.item(), 1 + 1j) + self.assertTrue(isinstance(x.item(), complex)) + + with self.assertRaises(ValueError): + paddle.randn([3, 2, 2]).item() + with self.assertRaises(ValueError): + paddle.randn([3, 2, 2]).item(18) + with self.assertRaises(ValueError): + paddle.randn([3, 2, 2]).item(1, 2) + with self.assertRaises(ValueError): + paddle.randn([3, 2, 2]).item(2, 1, 2) with self.assertRaises(TypeError): paddle.to_tensor('test') with self.assertRaises(TypeError): -- GitLab