From 24b2e8e6c84ec6e75f561c51f170faf76ec70374 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 7 Feb 2022 17:09:53 +0800 Subject: [PATCH] add _ptr for tensor (#39357) --- paddle/fluid/pybind/pybind.cc | 4 ++++ python/paddle/fluid/tests/unittests/test_tensor.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 9a535f5fb04..a5c4bb1a804 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -799,6 +799,10 @@ PYBIND11_MODULE(core_noavx, m) { framework_tensor .def("__array__", [](framework::Tensor &self) { return TensorToPyArray(self); }) + .def("_ptr", + [](const framework::Tensor &self) { + return reinterpret_cast(self.data()); + }) .def("_is_initialized", [](const framework::Tensor &self) { return self.IsInitialized(); }) .def("_get_dims", diff --git a/python/paddle/fluid/tests/unittests/test_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor.py index f8f3eea78a6..da792903b7d 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor.py @@ -21,6 +21,14 @@ import numpy import numbers +class TestTensorPtr(unittest.TestCase): + def test_tensor_ptr(self): + t = core.Tensor() + np_arr = numpy.zeros([2, 3]) + t.set(np_arr, core.CPUPlace()) + self.assertGreater(t._ptr(), 0) + + class TestTensor(unittest.TestCase): def setUp(self): self.support_dtypes = [ -- GitLab