未验证 提交 cde73a7b 编写于 作者: Y Yang Zhang 提交者: GitHub

Expose `mutable_data` as python binding (#19932)

* Expose `mutable_data` as python binding

test=develop

* Add test for device pointer binding

test=develop

* Make test compatible with python 2
上级 137e6336
......@@ -239,6 +239,21 @@ PYBIND11_MODULE(core_noavx, m) {
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
self.mutable_data<float>(place);
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::CPUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::CUDAPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_mutable_data",
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_clear", &Tensor::clear)
.def("set", PyCPUTensorSetFromArray<float>)
.def("set", PyCPUTensorSetFromArray<int>)
......
......@@ -18,6 +18,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
import unittest
import numpy
import numbers
class TestTensor(unittest.TestCase):
......@@ -171,7 +172,6 @@ class TestTensor(unittest.TestCase):
var = scope.var("test_tensor")
tensor = var.get_tensor()
tensor._set_dims([0, 1])
tensor._alloc_float(place)
......@@ -256,6 +256,26 @@ class TestTensor(unittest.TestCase):
print(tensor)
self.assertTrue(isinstance(str(tensor), str))
def test_tensor_poiter(self):
place = core.CPUPlace()
scope = core.Scope()
var = scope.var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor()
dtype = core.VarDesc.VarType.FP32
self.assertTrue(
isinstance(tensor._mutable_data(place, dtype), numbers.Integral))
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.assertTrue(
isinstance(
tensor._mutable_data(place, dtype), numbers.Integral))
place = core.CUDAPinnedPlace()
self.assertTrue(
isinstance(
tensor._mutable_data(place, dtype), numbers.Integral))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册