未验证 提交 bd165b94 编写于 作者: D duanyanhui 提交者: GitHub

[Custom Device] update get_device to custom and add custom_device api (#49721)

* update get_device to custom

* add custom_device api

* rm is_compiled_with_custom_device from framework

* add todo comments
上级 561f9013
......@@ -284,6 +284,20 @@ bool IsCompiledWithNPU() {
#endif
}
bool IsCompiledWithCustomDevice(std::string device_type) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
return false;
#else
std::vector<std::string> device_types;
device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
if (std::count(device_types.begin(), device_types.end(), device_type)) {
return true;
} else {
return false;
}
#endif
}
bool IsCompiledWithIPU() {
#ifndef PADDLE_WITH_IPU
return false;
......@@ -1559,6 +1573,25 @@ All parameter, weight, gradient are variables in Paddle.
#endif
return devices;
});
m.def("get_custom_device_count", [](const std::string &device_type) {
size_t device_count = 0;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// TODO(duanyanhui): Optimize DeviceManager::GetDeviceCount to support
// returning default device when only one device is registered in
// DeviceManager.
device_count = phi::DeviceManager::GetDeviceCount(device_type);
#else
VLOG(1) << string::Sprintf(
"Cannot use get_custom_device_count because you have "
"installed"
"CPU/GPU version PaddlePaddle.\n"
"If you want to use get_custom_device_count, please try to "
"install"
"CustomDevice version "
"PaddlePaddle by: pip install paddlepaddle\n");
#endif
return device_count;
});
py::class_<OperatorBase>(m, "Operator")
.def_static("create",
......@@ -1806,6 +1839,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("is_compiled_with_ascend", IsCompiledWithAscend);
m.def("is_compiled_with_rocm", IsCompiledWithROCM);
m.def("is_compiled_with_npu", IsCompiledWithNPU);
m.def("is_compiled_with_custom_device", IsCompiledWithCustomDevice);
m.def("is_compiled_with_ipu", IsCompiledWithIPU);
m.def("is_compiled_with_xpu", IsCompiledWithXPU);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
......
......@@ -368,6 +368,7 @@ from .device import is_compiled_with_mlu # noqa: F401
from .device import is_compiled_with_cinn # noqa: F401
from .device import is_compiled_with_cuda # noqa: F401
from .device import is_compiled_with_rocm # noqa: F401
from .device import is_compiled_with_custom_device # noqa: F401
from .device import XPUPlace # noqa: F401
# high-level api
......
......@@ -38,6 +38,7 @@ __all__ = [ # noqa
'is_compiled_with_rocm',
'is_compiled_with_npu',
'is_compiled_with_mlu',
'is_compiled_with_custom_device',
'get_all_device_type',
'get_all_custom_device_type',
'get_available_device',
......@@ -65,6 +66,24 @@ def is_compiled_with_npu():
return core.is_compiled_with_npu()
def is_compiled_with_custom_device(device_type):
"""
Whether paddle was built with Paddle_CUSTOM_DEVICE .
Args:
std::string, the registered device type, like "npu".
Return:
bool, ``True`` if CustomDevice is supported, otherwise ``False``.
Examples:
.. code-block:: python
import paddle
support_npu = paddle.device.is_compiled_with_custom_device("npu")
"""
return core.is_compiled_with_custom_device(device_type)
def is_compiled_with_ipu():
"""
Whether paddle was built with WITH_IPU=ON to support Graphcore IPU.
......
......@@ -646,6 +646,21 @@ def _current_expected_place():
"You are using MLU version Paddle, but your MLU device is not set properly. CPU device will be used by default."
)
_global_expected_place_ = core.CPUPlace()
elif core.is_compiled_with_custom_device("npu"):
# TODO(duanyanhui): Optimize DeviceManager and Return all expected places when device registered in DeviceManager is greater than 1.
try:
device_count = core.get_custom_device_count("npu")
except Exception as e:
device_count = 0
if device_count > 0:
_global_expected_place_ = core.CustomPlace(
"npu", _custom_device_ids("npu")[0]
)
else:
warnings.warn(
"You are using NPU version Paddle, but your NPU device is not set properly. CPU device will be used by default."
)
_global_expected_place_ = core.CPUPlace()
else:
_global_expected_place_ = core.CPUPlace()
......@@ -725,6 +740,15 @@ def _npu_ids():
return device_ids
def _custom_device_ids(device_type):
custom_devices_env = os.getenv("FLAGS_selected_" + device_type + "s")
if custom_devices_env:
device_ids = [int(s) for s in custom_devices_env.split(",")]
else:
device_ids = range(core.get_custom_device_count(device_type))
return device_ids
def _mlu_ids():
mlus_env = os.getenv("FLAGS_selected_mlus")
if mlus_env:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册