From 8225a6a1d90a6a667aec12661ae5d17a2c052c4e Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Wed, 30 Jun 2021 11:30:02 +0800 Subject: [PATCH] [NPU] support set_device (#33815) * support set_device for NPU. * minor update doc and add more unit test. --- paddle/fluid/imperative/tracer.cc | 8 +++++ paddle/fluid/pybind/pybind.cc | 2 ++ python/paddle/device.py | 32 +++++++++++++++---- .../fluid/tests/unittests/test_device.py | 20 ++++++++++++ 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index a8ca788d3b6..3d97d68b5c7 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -194,6 +194,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with XPU if use XPUPlace.")); +#endif + } else if (platform::is_npu_place(place)) { +#ifdef PADDLE_WITH_ASCEND_CL + platform::SetNPUDeviceId( + BOOST_GET_CONST(platform::NPUPlace, place).device); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with NPU if use NPUPlace.")); #endif } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 883ade66d4f..a93ce4ecd48 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1718,6 +1718,8 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("get_device_id", + [](const platform::NPUPlace &self) { return self.GetDeviceId(); }) .def("__str__", string::to_string); py::class_(m, "Place") diff --git a/python/paddle/device.py b/python/paddle/device.py index 93e439ecf0a..cf445917dd5 100644 --- a/python/paddle/device.py +++ b/python/paddle/device.py @@ -133,12 +133,20 @@ def _convert_to_place(device): selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",") device_id = int(selected_xpus[0]) place = core.XPUPlace(device_id) + elif lower_device == 'npu': + if not core.is_compiled_with_npu(): + raise ValueError("The device should not be 'npu', " + "since PaddlePaddle is not compiled with NPU") + selected_npus = os.getenv("FLAGS_selected_npus", "0").split(",") + device_id = int(selected_npus[0]) + place = core.NPUPlace(device_id) else: avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) - if not avaliable_gpu_device and not avaliable_xpu_device: + avaliable_npu_device = re.match(r'npu:\d+', lower_device) + if not avaliable_gpu_device and not avaliable_xpu_device and not avaliable_npu_device: raise ValueError( - "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu' or 'xpu:x'" + "The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu' or 'npu:x'" ) if avaliable_gpu_device: if not core.is_compiled_with_cuda(): @@ -158,19 +166,28 @@ def _convert_to_place(device): device_id = device_info_list[1] device_id = int(device_id) place = core.XPUPlace(device_id) + if avaliable_npu_device: + if not core.is_compiled_with_npu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " + "not compiled with NPU".format(avaliable_npu_device)) + device_info_list = device.split(':', 1) + device_id = device_info_list[1] + device_id = int(device_id) + place = core.NPUPlace(device_id) return place def set_device(device): """ - Paddle supports running calculations on various types of devices, including CPU, GPU and XPU. + Paddle supports running calculations on various types of devices, including CPU, GPU, XPU and NPU. They are represented by string identifiers. This function can specify the global device which the OP will run. Parameters: device(str): This parameter determines the specific running device. - It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the - index of the GPUs or XPUs. + It can be ``cpu``, ``gpu``, ``xpu``, ``npu``, ``gpu:x``, ``xpu:x`` and ``npu:x``, + where ``x`` is the index of the GPUs, XPUs or NPUs. Examples: @@ -191,7 +208,7 @@ def set_device(device): def get_device(): """ This funciton can get the current global device of the program is running. - It's a string which is like 'cpu', 'gpu:x' and 'xpu:x'. if the global device is not + It's a string which is like 'cpu', 'gpu:x', 'xpu:x' and 'npu:x'. if the global device is not set, it will return a string which is 'gpu:x' when cuda is avaliable or it will return a string which is 'cpu' when cuda is not avaliable. @@ -213,5 +230,8 @@ def get_device(): elif isinstance(place, core.XPUPlace): device_id = place.get_device_id() device = 'xpu:' + str(device_id) + elif isinstance(place, core.NPUPlace): + device_id = place.get_device_id() + device = 'npu:' + str(device_id) return device diff --git a/python/paddle/fluid/tests/unittests/test_device.py b/python/paddle/fluid/tests/unittests/test_device.py index 08697a08044..fc3734c7874 100644 --- a/python/paddle/fluid/tests/unittests/test_device.py +++ b/python/paddle/fluid/tests/unittests/test_device.py @@ -49,6 +49,10 @@ class TestStaticDeviceManage(unittest.TestCase): if core.is_compiled_with_xpu(): self._test_device("xpu:0", core.XPUPlace) + def test_npu_device(self): + if core.is_compiled_with_npu(): + self._test_device("npu:0", core.NPUPlace) + class TestImperativeDeviceManage(unittest.TestCase): def test_cpu(self): @@ -87,6 +91,22 @@ class TestImperativeDeviceManage(unittest.TestCase): self.assertTrue(out.place.is_xpu_place()) self.assertEqual(device, "xpu:0") + def test_npu(self): + if core.is_compiled_with_npu(): + with fluid.dygraph.guard(): + paddle.set_device('npu:0') + out1 = paddle.zeros(shape=[1, 3], dtype='float32') + out2 = paddle.ones(shape=[1, 3], dtype='float32') + out3 = paddle.concat(x=[out1, out2], axis=0) + device = paddle.get_device() + self.assertEqual( + isinstance(framework._current_expected_place(), + core.NPUPlace), True) + self.assertTrue(out1.place.is_npu_place()) + self.assertTrue(out2.place.is_npu_place()) + self.assertTrue(out3.place.is_npu_place()) + self.assertEqual(device, "npu:0") + if __name__ == '__main__': unittest.main() -- GitLab