未验证 提交 ccafd2e5 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add eager mode support (#42034)

上级 0e0f7da6
......@@ -146,10 +146,13 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array,
zero_copy);
} else if (platform::is_npu_place(place)) {
SetTensorFromPyArray<platform::NPUPlace>(impl_ptr, array, place, zero_copy);
} else if (platform::is_custom_place(place)) {
SetTensorFromPyArray<platform::CustomPlace>(impl_ptr, array, place,
zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/CustomPlace"));
}
}
......
......@@ -46,6 +46,7 @@ extern PyTypeObject* g_cpuplace_pytype;
extern PyTypeObject* g_xpuplace_pytype;
extern PyTypeObject* g_npuplace_pytype;
extern PyTypeObject* g_cudapinnedplace_pytype;
extern PyTypeObject* g_customplace_pytype;
extern PyTypeObject* g_framework_tensor_pytype;
extern PyTypeObject* g_framework_lodtensorarray_pytype;
extern PyTypeObject* g_custom_op_kernel_ctx_pytype;
......@@ -377,10 +378,15 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) {
} else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_cudapinnedplace_pytype))) {
place = ::pybind11::handle(obj).cast<platform::CUDAPinnedPlace>();
} else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_customplace_pytype))) {
place = ::pybind11::handle(obj).cast<platform::CustomPlace>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), "
"one "
"of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace,"
"CustomPlace), "
"but got %s",
arg_pos + 1, reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
......
......@@ -193,6 +193,7 @@ PyTypeObject *g_xpuplace_pytype = nullptr;
PyTypeObject *g_npuplace_pytype = nullptr;
PyTypeObject *g_cudapinnedplace_pytype = nullptr;
PyTypeObject *g_mluplace_pytype = nullptr;
PyTypeObject *g_customplace_pytype = nullptr;
PyTypeObject *g_framework_tensor_pytype = nullptr;
PyTypeObject *g_framework_lodtensorarray_pytype = nullptr;
PyTypeObject *g_custom_op_kernel_ctx_pytype = nullptr;
......@@ -2125,8 +2126,8 @@ All parameter, weight, gradient are variables in Paddle.
#endif
return devices;
});
py::class_<platform::CustomPlace>(m, "CustomPlace",
R"DOC(
py::class_<platform::CustomPlace> customplace(m, "CustomPlace",
R"DOC(
CustomPlace is a descriptor of a device.
It represents a custom device on which a tensor will be allocated and a model will run.
......@@ -2135,7 +2136,9 @@ All parameter, weight, gradient are variables in Paddle.
import paddle
fake_cpu_place = paddle.CustomPlace("FakeCPU", 0)
)DOC")
)DOC");
g_customplace_pytype = reinterpret_cast<PyTypeObject *>(customplace.ptr());
customplace
.def("__init__",
[](platform::CustomPlace &self, const std::string &device_type,
int dev_id) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册