未验证 提交 3b0dd5f6 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug that to_tensor not support paddle.Place (#28717)

上级 e1c8d6bc
...@@ -69,9 +69,12 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) { ...@@ -69,9 +69,12 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
return place_obj.cast<platform::XPUPlace>(); return place_obj.cast<platform::XPUPlace>();
} else if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) { } else if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
return place_obj.cast<platform::CUDAPinnedPlace>(); return place_obj.cast<platform::CUDAPinnedPlace>();
} else if (py::isinstance<platform::Place>(place_obj)) {
return place_obj.cast<platform::Place>();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace")); "Place should be one of "
"Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
} }
} }
......
...@@ -40,6 +40,9 @@ class TestVarBase(unittest.TestCase): ...@@ -40,6 +40,9 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(np.array_equal(x.numpy(), [1])) self.assertTrue(np.array_equal(x.numpy(), [1]))
self.assertNotEqual(x.dtype, core.VarDesc.VarType.FP32) self.assertNotEqual(x.dtype, core.VarDesc.VarType.FP32)
y = paddle.to_tensor(2, place=x.place)
self.assertEqual(str(x.place), str(y.place))
# set_default_dtype should not take effect on numpy # set_default_dtype should not take effect on numpy
x = paddle.to_tensor( x = paddle.to_tensor(
np.array([1.2]).astype('float16'), np.array([1.2]).astype('float16'),
......
...@@ -126,10 +126,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ...@@ -126,10 +126,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
if place is None: if place is None:
place = _current_expected_place() place = _current_expected_place()
elif not isinstance(place, elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
(core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace)): core.CUDAPlace)):
raise ValueError( raise ValueError(
"'place' must be any of paddle.Place, paddle.CUDAPinnedPlace, paddle.CUDAPlace" "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace"
) )
#Todo(zhouwei): Support allocate tensor on any other specified card #Todo(zhouwei): Support allocate tensor on any other specified card
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册