未验证 提交 ae35f502 编写于 作者: J JYChen 提交者: GitHub

fix device changed in setitem-numpy case (#53987)

上级 89b73ef1
......@@ -1313,7 +1313,7 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
}
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
auto self_numpy = TensorToPyArray(*self_tensor, true);
VLOG(4) << "parse_index is false";
if (PyCheckTensor(_index)) {
VLOG(4) << "index is tensor";
......
......@@ -1635,6 +1635,20 @@ class TestSetValueInplaceLeafVar(unittest.TestCase):
paddle.enable_static()
class TestSetValueIsSamePlace(unittest.TestCase):
def test_is_same_place(self):
paddle.disable_static()
paddle.seed(100)
paddle.set_device('cpu')
a = paddle.rand(shape=[2, 3, 4])
origin_place = a.place
a[[0, 1], 1] = 10
self.assertEqual(origin_place._type(), a.place._type())
if paddle.is_compiled_with_cuda():
paddle.set_device('gpu')
paddle.enable_static()
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册