diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 64413685c6c673708b8246c63beb39acf2bf0f69..4c6174f25ba4570abb3ca8dfc73957ea4fbead8c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -239,6 +239,21 @@ PYBIND11_MODULE(core_noavx, m) { [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) { self.mutable_data(place); }) + .def("_mutable_data", + [](Tensor &self, paddle::platform::CPUPlace &place, + paddle::framework::proto::VarType::Type type) { + return reinterpret_cast(self.mutable_data(place, type)); + }) + .def("_mutable_data", + [](Tensor &self, paddle::platform::CUDAPlace &place, + paddle::framework::proto::VarType::Type type) { + return reinterpret_cast(self.mutable_data(place, type)); + }) + .def("_mutable_data", + [](Tensor &self, paddle::platform::CUDAPinnedPlace &place, + paddle::framework::proto::VarType::Type type) { + return reinterpret_cast(self.mutable_data(place, type)); + }) .def("_clear", &Tensor::clear) .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) diff --git a/python/paddle/fluid/tests/unittests/test_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor.py index 4615511ed85441551ed3a5071a8cf1d0dfe32984..ec180456acf00c70a9593ea12b728fe1b65335c1 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor.py @@ -18,6 +18,7 @@ import paddle.fluid as fluid import paddle.fluid.core as core import unittest import numpy +import numbers class TestTensor(unittest.TestCase): @@ -171,7 +172,6 @@ class TestTensor(unittest.TestCase): var = scope.var("test_tensor") tensor = var.get_tensor() - tensor._set_dims([0, 1]) tensor._alloc_float(place) @@ -256,6 +256,26 @@ class TestTensor(unittest.TestCase): print(tensor) self.assertTrue(isinstance(str(tensor), str)) + def test_tensor_poiter(self): + place = core.CPUPlace() + scope = core.Scope() + var = scope.var("test_tensor") + place = core.CPUPlace() + tensor = var.get_tensor() + dtype = core.VarDesc.VarType.FP32 + self.assertTrue( + isinstance(tensor._mutable_data(place, dtype), numbers.Integral)) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.assertTrue( + isinstance( + tensor._mutable_data(place, dtype), numbers.Integral)) + place = core.CUDAPinnedPlace() + self.assertTrue( + isinstance( + tensor._mutable_data(place, dtype), numbers.Integral)) + if __name__ == '__main__': unittest.main()