diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 1913973b087c571eafbe9c5f8e4c65e66050d8de..64276c82acb737ded774c4aed700ddd2836be43a 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -980,7 +980,14 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - auto ptr = static_cast(self->tensor.impl().get()); + phi::DenseTensor* ptr = nullptr; + if (self->tensor.is_selected_rows()) { + auto* selected_rows = + static_cast(self->tensor.impl().get()); + ptr = static_cast(selected_rows->mutable_value()); + } else { + ptr = static_cast(self->tensor.impl().get()); + } PADDLE_ENFORCE_NOT_NULL(ptr, platform::errors::InvalidArgument( "%s is not a DenseTensor.", self->tensor.name())); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index ce264bdf5f4d87c076faec41d8d0dbf4b2017c3e..e050fc7c7d5446fc724471667f81103cdd4a393d 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -276,7 +276,8 @@ T TensorGetElement(const phi::DenseTensor &self, size_t offset) { "The offset exceeds the size of tensor.")); T b = static_cast(0); - if (platform::is_cpu_place(self.place())) { + if (platform::is_cpu_place(self.place()) || + platform::is_cuda_pinned_place(self.place())) { b = self.data()[offset]; } else if (platform::is_xpu_place(self.place())) { #ifdef PADDLE_WITH_XPU diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index a0e6253fb3f0cc9abc199c6310bd158396092145..17098ef9425a97ed2e6ea665aa6aea101f950761 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -168,7 +168,7 @@ class TestVarBase(unittest.TestCase): self.assertEqual(x_array.dtype, x.numpy().dtype) np.testing.assert_array_equal(x_array, x.numpy()) - x = paddle.to_tensor(1.0) + x = paddle.to_tensor(1.0, place=place) self.assertEqual(x.item(), 1.0) self.assertTrue(isinstance(x.item(), float))