未验证 提交 ccafd2e5 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add eager mode support (#42034)

上级 0e0f7da6
...@@ -146,10 +146,13 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array, ...@@ -146,10 +146,13 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array,
zero_copy); zero_copy);
} else if (platform::is_npu_place(place)) { } else if (platform::is_npu_place(place)) {
SetTensorFromPyArray<platform::NPUPlace>(impl_ptr, array, place, zero_copy); SetTensorFromPyArray<platform::NPUPlace>(impl_ptr, array, place, zero_copy);
} else if (platform::is_custom_place(place)) {
SetTensorFromPyArray<platform::CustomPlace>(impl_ptr, array, place,
zero_copy);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of " "Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace")); "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/CustomPlace"));
} }
} }
......
...@@ -46,6 +46,7 @@ extern PyTypeObject* g_cpuplace_pytype; ...@@ -46,6 +46,7 @@ extern PyTypeObject* g_cpuplace_pytype;
extern PyTypeObject* g_xpuplace_pytype; extern PyTypeObject* g_xpuplace_pytype;
extern PyTypeObject* g_npuplace_pytype; extern PyTypeObject* g_npuplace_pytype;
extern PyTypeObject* g_cudapinnedplace_pytype; extern PyTypeObject* g_cudapinnedplace_pytype;
extern PyTypeObject* g_customplace_pytype;
extern PyTypeObject* g_framework_tensor_pytype; extern PyTypeObject* g_framework_tensor_pytype;
extern PyTypeObject* g_framework_lodtensorarray_pytype; extern PyTypeObject* g_framework_lodtensorarray_pytype;
extern PyTypeObject* g_custom_op_kernel_ctx_pytype; extern PyTypeObject* g_custom_op_kernel_ctx_pytype;
...@@ -377,10 +378,15 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) { ...@@ -377,10 +378,15 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) {
} else if (PyObject_IsInstance( } else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_cudapinnedplace_pytype))) { obj, reinterpret_cast<PyObject*>(g_cudapinnedplace_pytype))) {
place = ::pybind11::handle(obj).cast<platform::CUDAPinnedPlace>(); place = ::pybind11::handle(obj).cast<platform::CUDAPinnedPlace>();
} else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_customplace_pytype))) {
place = ::pybind11::handle(obj).cast<platform::CustomPlace>();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be " "argument (position %d) must be "
"one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), " "one "
"of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace,"
"CustomPlace), "
"but got %s", "but got %s",
arg_pos + 1, reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name)); arg_pos + 1, reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
} }
......
...@@ -193,6 +193,7 @@ PyTypeObject *g_xpuplace_pytype = nullptr; ...@@ -193,6 +193,7 @@ PyTypeObject *g_xpuplace_pytype = nullptr;
PyTypeObject *g_npuplace_pytype = nullptr; PyTypeObject *g_npuplace_pytype = nullptr;
PyTypeObject *g_cudapinnedplace_pytype = nullptr; PyTypeObject *g_cudapinnedplace_pytype = nullptr;
PyTypeObject *g_mluplace_pytype = nullptr; PyTypeObject *g_mluplace_pytype = nullptr;
PyTypeObject *g_customplace_pytype = nullptr;
PyTypeObject *g_framework_tensor_pytype = nullptr; PyTypeObject *g_framework_tensor_pytype = nullptr;
PyTypeObject *g_framework_lodtensorarray_pytype = nullptr; PyTypeObject *g_framework_lodtensorarray_pytype = nullptr;
PyTypeObject *g_custom_op_kernel_ctx_pytype = nullptr; PyTypeObject *g_custom_op_kernel_ctx_pytype = nullptr;
...@@ -2125,7 +2126,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2125,7 +2126,7 @@ All parameter, weight, gradient are variables in Paddle.
#endif #endif
return devices; return devices;
}); });
py::class_<platform::CustomPlace>(m, "CustomPlace", py::class_<platform::CustomPlace> customplace(m, "CustomPlace",
R"DOC( R"DOC(
CustomPlace is a descriptor of a device. CustomPlace is a descriptor of a device.
It represents a custom device on which a tensor will be allocated and a model will run. It represents a custom device on which a tensor will be allocated and a model will run.
...@@ -2135,7 +2136,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2135,7 +2136,9 @@ All parameter, weight, gradient are variables in Paddle.
import paddle import paddle
fake_cpu_place = paddle.CustomPlace("FakeCPU", 0) fake_cpu_place = paddle.CustomPlace("FakeCPU", 0)
)DOC") )DOC");
g_customplace_pytype = reinterpret_cast<PyTypeObject *>(customplace.ptr());
customplace
.def("__init__", .def("__init__",
[](platform::CustomPlace &self, const std::string &device_type, [](platform::CustomPlace &self, const std::string &device_type,
int dev_id) { int dev_id) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册