From 0f9ec0133ac23d91fb7f2010d703dfa685ff82b1 Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Thu, 30 Mar 2023 21:06:56 +0800 Subject: [PATCH] [Bug-fix] fix bug of Tensor.item() when CUDAPinnedPlace (#52322) --- paddle/fluid/pybind/eager_method.cc | 9 ++++++++- paddle/fluid/pybind/tensor_py.h | 3 ++- python/paddle/fluid/tests/unittests/test_var_base.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 1913973b087..64276c82acb 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -980,7 +980,14 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - auto ptr = static_cast(self->tensor.impl().get()); + phi::DenseTensor* ptr = nullptr; + if (self->tensor.is_selected_rows()) { + auto* selected_rows = + static_cast(self->tensor.impl().get()); + ptr = static_cast(selected_rows->mutable_value()); + } else { + ptr = static_cast(self->tensor.impl().get()); + } PADDLE_ENFORCE_NOT_NULL(ptr, platform::errors::InvalidArgument( "%s is not a DenseTensor.", self->tensor.name())); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index ce264bdf5f4..e050fc7c7d5 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -276,7 +276,8 @@ T TensorGetElement(const phi::DenseTensor &self, size_t offset) { "The offset exceeds the size of tensor.")); T b = static_cast(0); - if (platform::is_cpu_place(self.place())) { + if (platform::is_cpu_place(self.place()) || + platform::is_cuda_pinned_place(self.place())) { b = self.data()[offset]; } else if (platform::is_xpu_place(self.place())) { #ifdef PADDLE_WITH_XPU diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index a0e6253fb3f..17098ef9425 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -168,7 +168,7 @@ class TestVarBase(unittest.TestCase): self.assertEqual(x_array.dtype, x.numpy().dtype) np.testing.assert_array_equal(x_array, x.numpy()) - x = paddle.to_tensor(1.0) + x = paddle.to_tensor(1.0, place=place) self.assertEqual(x.item(), 1.0) self.assertTrue(isinstance(x.item(), float)) -- GitLab