未验证 提交 bd165b94 编写于 作者: D duanyanhui 提交者: GitHub

[Custom Device] update get_device to custom and add custom_device api (#49721)

* update get_device to custom

* add custom_device api

* rm is_compiled_with_custom_device from framework

* add todo comments
上级 561f9013
...@@ -284,6 +284,20 @@ bool IsCompiledWithNPU() { ...@@ -284,6 +284,20 @@ bool IsCompiledWithNPU() {
#endif #endif
} }
bool IsCompiledWithCustomDevice(std::string device_type) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
return false;
#else
std::vector<std::string> device_types;
device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
if (std::count(device_types.begin(), device_types.end(), device_type)) {
return true;
} else {
return false;
}
#endif
}
bool IsCompiledWithIPU() { bool IsCompiledWithIPU() {
#ifndef PADDLE_WITH_IPU #ifndef PADDLE_WITH_IPU
return false; return false;
...@@ -1559,6 +1573,25 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1559,6 +1573,25 @@ All parameter, weight, gradient are variables in Paddle.
#endif #endif
return devices; return devices;
}); });
m.def("get_custom_device_count", [](const std::string &device_type) {
size_t device_count = 0;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// TODO(duanyanhui): Optimize DeviceManager::GetDeviceCount to support
// returning default device when only one device is registered in
// DeviceManager.
device_count = phi::DeviceManager::GetDeviceCount(device_type);
#else
VLOG(1) << string::Sprintf(
"Cannot use get_custom_device_count because you have "
"installed"
"CPU/GPU version PaddlePaddle.\n"
"If you want to use get_custom_device_count, please try to "
"install"
"CustomDevice version "
"PaddlePaddle by: pip install paddlepaddle\n");
#endif
return device_count;
});
py::class_<OperatorBase>(m, "Operator") py::class_<OperatorBase>(m, "Operator")
.def_static("create", .def_static("create",
...@@ -1806,6 +1839,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1806,6 +1839,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_ascend", IsCompiledWithAscend); m.def("is_compiled_with_ascend", IsCompiledWithAscend);
m.def("is_compiled_with_rocm", IsCompiledWithROCM); m.def("is_compiled_with_rocm", IsCompiledWithROCM);
m.def("is_compiled_with_npu", IsCompiledWithNPU); m.def("is_compiled_with_npu", IsCompiledWithNPU);
m.def("is_compiled_with_custom_device", IsCompiledWithCustomDevice);
m.def("is_compiled_with_ipu", IsCompiledWithIPU); m.def("is_compiled_with_ipu", IsCompiledWithIPU);
m.def("is_compiled_with_xpu", IsCompiledWithXPU); m.def("is_compiled_with_xpu", IsCompiledWithXPU);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN); m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
......
...@@ -368,6 +368,7 @@ from .device import is_compiled_with_mlu # noqa: F401 ...@@ -368,6 +368,7 @@ from .device import is_compiled_with_mlu # noqa: F401
from .device import is_compiled_with_cinn # noqa: F401 from .device import is_compiled_with_cinn # noqa: F401
from .device import is_compiled_with_cuda # noqa: F401 from .device import is_compiled_with_cuda # noqa: F401
from .device import is_compiled_with_rocm # noqa: F401 from .device import is_compiled_with_rocm # noqa: F401
from .device import is_compiled_with_custom_device # noqa: F401
from .device import XPUPlace # noqa: F401 from .device import XPUPlace # noqa: F401
# high-level api # high-level api
......
...@@ -38,6 +38,7 @@ __all__ = [ # noqa ...@@ -38,6 +38,7 @@ __all__ = [ # noqa
'is_compiled_with_rocm', 'is_compiled_with_rocm',
'is_compiled_with_npu', 'is_compiled_with_npu',
'is_compiled_with_mlu', 'is_compiled_with_mlu',
'is_compiled_with_custom_device',
'get_all_device_type', 'get_all_device_type',
'get_all_custom_device_type', 'get_all_custom_device_type',
'get_available_device', 'get_available_device',
...@@ -65,6 +66,24 @@ def is_compiled_with_npu(): ...@@ -65,6 +66,24 @@ def is_compiled_with_npu():
return core.is_compiled_with_npu() return core.is_compiled_with_npu()
def is_compiled_with_custom_device(device_type):
"""
Whether paddle was built with Paddle_CUSTOM_DEVICE .
Args:
std::string, the registered device type, like "npu".
Return:
bool, ``True`` if CustomDevice is supported, otherwise ``False``.
Examples:
.. code-block:: python
import paddle
support_npu = paddle.device.is_compiled_with_custom_device("npu")
"""
return core.is_compiled_with_custom_device(device_type)
def is_compiled_with_ipu(): def is_compiled_with_ipu():
""" """
Whether paddle was built with WITH_IPU=ON to support Graphcore IPU. Whether paddle was built with WITH_IPU=ON to support Graphcore IPU.
......
...@@ -646,6 +646,21 @@ def _current_expected_place(): ...@@ -646,6 +646,21 @@ def _current_expected_place():
"You are using MLU version Paddle, but your MLU device is not set properly. CPU device will be used by default." "You are using MLU version Paddle, but your MLU device is not set properly. CPU device will be used by default."
) )
_global_expected_place_ = core.CPUPlace() _global_expected_place_ = core.CPUPlace()
elif core.is_compiled_with_custom_device("npu"):
# TODO(duanyanhui): Optimize DeviceManager and Return all expected places when device registered in DeviceManager is greater than 1.
try:
device_count = core.get_custom_device_count("npu")
except Exception as e:
device_count = 0
if device_count > 0:
_global_expected_place_ = core.CustomPlace(
"npu", _custom_device_ids("npu")[0]
)
else:
warnings.warn(
"You are using NPU version Paddle, but your NPU device is not set properly. CPU device will be used by default."
)
_global_expected_place_ = core.CPUPlace()
else: else:
_global_expected_place_ = core.CPUPlace() _global_expected_place_ = core.CPUPlace()
...@@ -725,6 +740,15 @@ def _npu_ids(): ...@@ -725,6 +740,15 @@ def _npu_ids():
return device_ids return device_ids
def _custom_device_ids(device_type):
custom_devices_env = os.getenv("FLAGS_selected_" + device_type + "s")
if custom_devices_env:
device_ids = [int(s) for s in custom_devices_env.split(",")]
else:
device_ids = range(core.get_custom_device_count(device_type))
return device_ids
def _mlu_ids(): def _mlu_ids():
mlus_env = os.getenv("FLAGS_selected_mlus") mlus_env = os.getenv("FLAGS_selected_mlus")
if mlus_env: if mlus_env:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册