未验证 提交 0cd577cf 编写于 作者: R ronnywang 提交者: GitHub

pybind support CustomPlace (#41136)

上级 bc88fbb5
......@@ -2182,6 +2182,7 @@ void BindImperative(py::module *m_ptr) {
m.def("varbase_copy", &VarBaseCopy<platform::XPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CUDAPinnedPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::NPUPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::CustomPlace>);
m.def("varbase_copy", &VarBaseCopy<platform::MLUPlace>);
m.def(
......@@ -2341,6 +2342,11 @@ void BindImperative(py::module *m_ptr) {
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::CustomPlace &place, const py::object &cls,
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
#if defined(PADDLE_WITH_CUDA)
m.def("to_uva_tensor",
......
......@@ -845,6 +845,10 @@ PYBIND11_MODULE(core_noavx, m) {
[](framework::Tensor &self, const std::string &layout) {
self.set_layout(StringToDataLayout(layout));
})
.def("_alloc_float",
[](framework::Tensor &self, paddle::platform::CustomPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_float",
[](framework::Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<float>(place);
......@@ -873,6 +877,10 @@ PYBIND11_MODULE(core_noavx, m) {
[](framework::Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](framework::Tensor &self, paddle::platform::CustomPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](framework::Tensor &self, paddle::platform::XPUPlace &place) {
self.mutable_data<int>(place);
......@@ -901,6 +909,12 @@ PYBIND11_MODULE(core_noavx, m) {
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::CustomPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::XPUPlace &place,
paddle::framework::proto::VarType::Type type) {
......@@ -934,6 +948,8 @@ PYBIND11_MODULE(core_noavx, m) {
})
.def("_copy_from", &TensorCopyFrom<paddle::platform::CPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CustomPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::XPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPlace>,
......@@ -948,6 +964,8 @@ PYBIND11_MODULE(core_noavx, m) {
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::CustomPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::XPUPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
.def("set", SetTensorFromPyArray<paddle::platform::CUDAPlace>,
......@@ -1985,6 +2003,19 @@ All parameter, weight, gradient are variables in Paddle.
"Please recompile or reinstall Paddle with NPU support."));
#else
return new paddle::platform::NPUDeviceContext(place);
#endif
})
.def_static("create",
[](paddle::platform::CustomPlace& place)
-> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(
platform::errors::PermissionDenied(
"Cannot use CustomPlace in CPU/GPU/XPU version, "
"Please recompile or reinstall Paddle with "
"CustomDevice support."));
#else
return new paddle::platform::CustomDeviceContext(place);
#endif
})
.def_static("create",
......@@ -2722,6 +2753,12 @@ All parameter, weight, gradient are variables in Paddle.
pybind11::gil_scoped_release release;
self.Run(scope, place);
})
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::CustomPlace &place) {
pybind11::gil_scoped_release release;
self.Run(scope, place);
})
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册