未验证 提交 3b0dd5f6 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug that to_tensor not support paddle.Place (#28717)

上级 e1c8d6bc
......@@ -69,9 +69,12 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
return place_obj.cast<platform::XPUPlace>();
} else if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
return place_obj.cast<platform::CUDAPinnedPlace>();
} else if (py::isinstance<platform::Place>(place_obj)) {
return place_obj.cast<platform::Place>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
"Place should be one of "
"Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
}
}
......
......@@ -40,6 +40,9 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(np.array_equal(x.numpy(), [1]))
self.assertNotEqual(x.dtype, core.VarDesc.VarType.FP32)
y = paddle.to_tensor(2, place=x.place)
self.assertEqual(str(x.place), str(y.place))
# set_default_dtype should not take effect on numpy
x = paddle.to_tensor(
np.array([1.2]).astype('float16'),
......
......@@ -126,10 +126,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
if place is None:
place = _current_expected_place()
elif not isinstance(place,
(core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace)):
elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
core.CUDAPlace)):
raise ValueError(
"'place' must be any of paddle.Place, paddle.CUDAPinnedPlace, paddle.CUDAPlace"
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace"
)
#Todo(zhouwei): Support allocate tensor on any other specified card
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册