未验证 提交 b7a54fc1 编写于 作者: Z Zhou Wei 提交者: GitHub

support convert core.Tensor to paddle.Tensor (#33430)

上级 e47c3f04
......@@ -245,7 +245,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
}
static void InitVarBaseFromTensorWithArgDefault(
imperative::VarBase *self, const framework::LoDTensor &tensor) {
imperative::VarBase *self, const framework::Tensor &tensor) {
VLOG(4) << "Init VarBase";
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
new (self) imperative::VarBase(
......
......@@ -176,7 +176,6 @@ class TestVarBase(unittest.TestCase):
x = paddle.to_tensor(1, dtype='uint8')
self.assertEqual(x.item(), 1)
print(type(x.item()))
self.assertTrue(isinstance(x.item(), int))
x = paddle.to_tensor(1, dtype='int8')
......@@ -203,6 +202,24 @@ class TestVarBase(unittest.TestCase):
self.assertEqual(x.item(), 1 + 1j)
self.assertTrue(isinstance(x.item(), complex))
numpy_array = np.random.randn(3, 4)
# covert core.LoDTensor to paddle.Tensor
lod_tensor = paddle.fluid.core.LoDTensor()
place = paddle.fluid.framework._current_expected_place()
lod_tensor.set(numpy_array, place)
x = paddle.to_tensor(lod_tensor)
self.assertTrue(np.array_equal(x.numpy(), numpy_array))
self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR)
self.assertEqual(str(x.place), str(place))
# covert core.Tensor to paddle.Tensor
x = paddle.to_tensor(numpy_array)
dlpack = x.value().get_tensor()._to_dlpack()
tensor_from_dlpack = paddle.fluid.core.from_dlpack(dlpack)
x = paddle.to_tensor(tensor_from_dlpack)
self.assertTrue(np.array_equal(x.numpy(), numpy_array))
self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR)
with self.assertRaises(ValueError):
paddle.randn([3, 2, 2]).item()
with self.assertRaises(ValueError):
......
......@@ -136,9 +136,10 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
data = data._copy_to(place, False)
ata = _handle_dtype(data, dtype)
data.stop_gradient = stop_gradient
elif isinstance(data, core.LoDTensor):
# convert LoDTensor to VarBase first
# Currenly, LoDTensor does no copy when places are same
elif isinstance(data, (core.LoDTensor, core.Tensor)):
# Note(zhouwei25): should't expose it to users, just for internal use.
# convert core.Tensor/core.LoDTensor to VarBase first
# Currenly, there is no copy when places are same
data = paddle.Tensor(data)
if not data.place._equals(place):
data = data._copy_to(place, False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册