From df311526ddd809a5d9b5082c9dbc2f65a6299066 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Mon, 10 Jul 2023 17:29:37 +0800 Subject: [PATCH] [CustomDevice] add custom device support for Variable.set_value (#55272) --- paddle/fluid/pybind/place.cc | 2 ++ python/paddle/fluid/framework.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index c97bba9be8f..98b7609578e 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -640,6 +640,8 @@ void BindPlace(pybind11::module &m) { // NOLINT .def("ipu_device_id", [](platform::Place &self) { return self.device; }) .def("custom_device_id", [](platform::Place &self) { return self.device; }) + .def("custom_device_type", + [](platform::Place &self) { return self.GetDeviceType(); }) .def("set_place", [](platform::Place &self, const platform::Place &other) { self = other; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8a461760ef0..ec3f235a18b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2456,6 +2456,12 @@ class Variable(metaclass=VariableMetaClass): p = core.Place() p.set_place(t._place()) place = core.XPUPlace(p.xpu_device_id()) + elif p.is_custom_place(): + p = core.Place() + p.set_place(t._place()) + place = core.CustomPlace( + p.custom_device_type(), p.custom_device_id() + ) else: p = core.Place() p.set_place(t._place()) -- GitLab