Commit 6af17835 authored by chengduoZH

expose CUDAPinnedPlace to Python

Parent f2c0b886
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -11,11 +11,16 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
+#include <map>
+#include <mutex>  // NOLINT // for call_once
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
 
 #include "paddle/fluid/pybind/protobuf.h"
 
-#include <mutex>  // for call_once
-#include <unordered_map>
 #include "paddle/fluid/framework/backward.h"
 #include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/executor.h"
@@ -32,7 +37,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/cond_op.h"
 #include "paddle/fluid/operators/net_op.h"
 #include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/fluid/pybind/const_value.h"
@@ -100,6 +104,14 @@ PYBIND11_PLUGIN(core) {
            [](Tensor &self, paddle::platform::CUDAPlace &place) {
              self.mutable_data<int>(place);
            })
+      .def("alloc_int",
+           [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
+             self.mutable_data<int>(place);
+           })
+      .def("alloc_float",
+           [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
+             self.mutable_data<float>(place);
+           })
       .def("set", PyCPUTensorSetFromArray<float>)
       .def("set", PyCPUTensorSetFromArray<int>)
       .def("set", PyCPUTensorSetFromArray<double>)
@@ -317,7 +329,17 @@ All parameter, weight, gradient are variables in Paddle.
 #else
                     return new paddle::platform::CUDADeviceContext(place);
 #endif
-                  });
+                  })
+      .def_static("create",
+                  [](paddle::platform::CUDAPinnedPlace& place)
+                      -> paddle::platform::DeviceContext* {
+#ifndef PADDLE_WITH_CUDA
+                    PADDLE_THROW(
+                        "CUDAPinnedPlace is not supported in CPU device.");
+#else
+                    return new paddle::platform::CUDAPinnedDeviceContext(place);
+#endif
+                  });
 // clang-format on
 #ifdef PADDLE_WITH_CUDA
   py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
@@ -330,6 +352,10 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("__str__", string::to_string<const platform::CPUPlace &>);
 
+  py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
+      .def(py::init<>())
+      .def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
+
   py::class_<platform::Place>(m, "Place")
       .def(py::init<>())
       .def("set_place",
@@ -339,7 +365,11 @@ All parameter, weight, gradient are variables in Paddle.
       .def("set_place",
            [](platform::Place &self, const platform::CUDAPlace &gpu_place) {
              self = gpu_place;
-           });
+           })
+      .def("set_place", [](platform::Place &self,
+                           const platform::CUDAPinnedPlace &gpu_place) {
+        self = gpu_place;
+      });
 
   py::class_<OperatorBase>(m, "Operator")
       .def_static("create",
@@ -363,6 +393,11 @@ All parameter, weight, gradient are variables in Paddle.
       .def("run",
            [](OperatorBase &self, const Scope &scope,
               const platform::CUDAPlace &place) { self.Run(scope, place); })
+      .def("run",
+           [](OperatorBase &self, const Scope &scope,
+              const platform::CUDAPinnedPlace &place) {
+             self.Run(scope, place);
+           })
       .def("type",
            [](const OperatorBase &op) -> std::string { return op.Type(); })
       .def("outputs",
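Taken together, these bindings make pinned (page-locked) host memory addressable from Python the same way as CPUPlace and CUDAPlace; CUDA can copy pinned buffers to the device asynchronously. Below is a minimal usage sketch that exercises only the bindings added above, assuming a CUDA-enabled build containing this commit and the usual fluid layout in which the compiled core module is importable as paddle.fluid.core:

# Minimal usage sketch of the newly exposed bindings (assumes a
# CUDA-enabled build of Paddle that includes this commit).
from paddle.fluid import core

# CUDAPinnedPlace designates page-locked host memory, which CUDA can
# transfer to the device asynchronously.
pinned = core.CUDAPinnedPlace()
print(pinned)  # uses the __str__ binding added in the diff

# The generic Place wrapper now accepts a CUDAPinnedPlace as well.
place = core.Place()
place.set_place(pinned)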