未验证 提交 24b2e8e6 编写于 作者: S sneaxiy 提交者: GitHub

add _ptr for tensor (#39357)

上级 ed0990e7
......@@ -799,6 +799,10 @@ PYBIND11_MODULE(core_noavx, m) {
framework_tensor
.def("__array__",
[](framework::Tensor &self) { return TensorToPyArray(self); })
.def("_ptr",
[](const framework::Tensor &self) {
return reinterpret_cast<uintptr_t>(self.data());
})
.def("_is_initialized",
[](const framework::Tensor &self) { return self.IsInitialized(); })
.def("_get_dims",
......
......@@ -21,6 +21,14 @@ import numpy
import numbers
class TestTensorPtr(unittest.TestCase):
def test_tensor_ptr(self):
t = core.Tensor()
np_arr = numpy.zeros([2, 3])
t.set(np_arr, core.CPUPlace())
self.assertGreater(t._ptr(), 0)
class TestTensor(unittest.TestCase):
def setUp(self):
self.support_dtypes = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册