未验证 提交 0f9ec013 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Bug-fix] fix bug of Tensor.item() when CUDAPinnedPlace (#52322)

上级 f1cdd654
...@@ -980,7 +980,14 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self, ...@@ -980,7 +980,14 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
auto ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get()); phi::DenseTensor* ptr = nullptr;
if (self->tensor.is_selected_rows()) {
auto* selected_rows =
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
} else {
ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
}
PADDLE_ENFORCE_NOT_NULL(ptr, PADDLE_ENFORCE_NOT_NULL(ptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"%s is not a DenseTensor.", self->tensor.name())); "%s is not a DenseTensor.", self->tensor.name()));
......
...@@ -276,7 +276,8 @@ T TensorGetElement(const phi::DenseTensor &self, size_t offset) { ...@@ -276,7 +276,8 @@ T TensorGetElement(const phi::DenseTensor &self, size_t offset) {
"The offset exceeds the size of tensor.")); "The offset exceeds the size of tensor."));
T b = static_cast<T>(0); T b = static_cast<T>(0);
if (platform::is_cpu_place(self.place())) { if (platform::is_cpu_place(self.place()) ||
platform::is_cuda_pinned_place(self.place())) {
b = self.data<T>()[offset]; b = self.data<T>()[offset];
} else if (platform::is_xpu_place(self.place())) { } else if (platform::is_xpu_place(self.place())) {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
......
...@@ -168,7 +168,7 @@ class TestVarBase(unittest.TestCase): ...@@ -168,7 +168,7 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(x_array.dtype, x.numpy().dtype) self.assertEqual(x_array.dtype, x.numpy().dtype)
np.testing.assert_array_equal(x_array, x.numpy()) np.testing.assert_array_equal(x_array, x.numpy())
x = paddle.to_tensor(1.0) x = paddle.to_tensor(1.0, place=place)
self.assertEqual(x.item(), 1.0) self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float)) self.assertTrue(isinstance(x.item(), float))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册