diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index c97bba9be8f2fd93f44f176613988c6bad1ec0b1..98b7609578e5308af0ffb14e7bd9c685e30d31a8 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -640,6 +640,8 @@ void BindPlace(pybind11::module &m) { // NOLINT .def("ipu_device_id", [](platform::Place &self) { return self.device; }) .def("custom_device_id", [](platform::Place &self) { return self.device; }) + .def("custom_device_type", + [](platform::Place &self) { return self.GetDeviceType(); }) .def("set_place", [](platform::Place &self, const platform::Place &other) { self = other; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8a461760ef0c9be7868c874fa4a1d9dd326b6066..ec3f235a18b8ebd2ccba276e3b2bb02cbd0d20de 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2456,6 +2456,12 @@ class Variable(metaclass=VariableMetaClass): p = core.Place() p.set_place(t._place()) place = core.XPUPlace(p.xpu_device_id()) + elif p.is_custom_place(): + p = core.Place() + p.set_place(t._place()) + place = core.CustomPlace( + p.custom_device_type(), p.custom_device_id() + ) else: p = core.Place() p.set_place(t._place())