未验证 提交 a85edddb 编写于 作者: L Leo Chen 提交者: GitHub

paddle.to_tensor supports LoDTensor (#33027)

上级 44668a7a
......@@ -248,6 +248,21 @@ class TestVarBase(unittest.TestCase):
a = paddle.to_tensor(a, place=paddle.CUDAPinnedPlace())
self.assertEqual(a.place.__repr__(), "CUDAPinnedPlace")
def test_to_tensor_with_lodtensor(self):
if core.is_compiled_with_cuda():
a_np = np.random.rand(1024, 1024)
with paddle.fluid.dygraph.guard(core.CPUPlace()):
lod_tensor = core.LoDTensor()
lod_tensor.set(a_np, core.CPUPlace())
a = paddle.to_tensor(lod_tensor)
self.assertTrue(np.array_equal(a_np, a.numpy()))
with paddle.fluid.dygraph.guard(core.CUDAPlace(0)):
lod_tensor = core.LoDTensor()
lod_tensor.set(a_np, core.CUDAPlace(0))
a = paddle.to_tensor(lod_tensor)
self.assertTrue(np.array_equal(a_np, a.numpy()))
def test_to_variable(self):
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array, name="abc")
......
......@@ -118,6 +118,16 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
place = _current_expected_place()
if not isinstance(data, np.ndarray):
def _handle_diff_place_dtype(data, dtype, place, stop_gradient):
data.stop_gradient = stop_gradient
if not data.place._equals(place):
data = data._copy_to(place, False)
if dtype:
if convert_dtype(dtype) != convert_dtype(data.dtype):
return data.astype(convert_dtype(dtype))
return data
if np.isscalar(data) and not isinstance(data, str):
data = np.array([data])
elif isinstance(data, (list, tuple)):
......@@ -128,13 +138,11 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
"this means the input data contains nested lists with different lengths. "
)
elif isinstance(data, paddle.Tensor):
data.stop_gradient = stop_gradient
if not data.place._equals(place):
data = data._copy_to(place, False)
if dtype:
if convert_dtype(dtype) != convert_dtype(data.dtype):
return data.astype(convert_dtype(dtype))
return data
return _handle_diff_place_dtype(data, dtype, place, stop_gradient)
elif isinstance(data, (core.Tensor, core.LoDTensor)):
# convert LoDTensor to VarBase first, and then process it as input VarBase
data = paddle.Tensor(data)
return _handle_diff_place_dtype(data, dtype, place, stop_gradient)
else:
raise TypeError(
"Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor".
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册