未验证 提交 6a5b7e59 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] add is_compiled_with_rocm api, test=develop (#33043)

上级 8c6bbb48
......@@ -258,6 +258,7 @@ from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401
from .device import get_device # noqa: F401
from .fluid.framework import is_compiled_with_cuda # noqa: F401
from .fluid.framework import is_compiled_with_rocm # noqa: F401
from .device import is_compiled_with_xpu # noqa: F401
from .device import is_compiled_with_npu # noqa: F401
from .device import XPUPlace # noqa: F401
......@@ -384,6 +385,7 @@ __all__ = [ #noqa
'less_equal',
'triu',
'is_compiled_with_cuda',
'is_compiled_with_rocm',
'sin',
'dist',
'unbind',
......
......@@ -19,6 +19,7 @@ from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.framework import is_compiled_with_cuda #DEFINE_ALIAS
from paddle.fluid.framework import is_compiled_with_rocm #DEFINE_ALIAS
__all__ = [
'get_cudnn_version',
......@@ -33,6 +34,7 @@ __all__ = [
# 'CUDAPinnedPlace',
# 'CUDAPlace',
'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_npu'
]
......
......@@ -53,6 +53,7 @@ __all__ = [
'cuda_pinned_places',
'in_dygraph_mode',
'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_xpu',
'Variable',
'require_version',
......@@ -398,6 +399,21 @@ def is_compiled_with_cuda():
return core.is_compiled_with_cuda()
def is_compiled_with_rocm():
"""
Whether this whl package can be used to run the model on AMD or Hygon GPU(ROCm).
Returns (bool): `True` if ROCm is currently available, otherwise `False`.
Examples:
.. code-block:: python
import paddle
support_gpu = paddle.is_compiled_with_rocm()
"""
return core.is_compiled_with_rocm()
def cuda_places(device_ids=None):
"""
**Note**:
......
......@@ -42,10 +42,10 @@ if IS_WINDOWS and six.PY3:
from unittest.mock import Mock
_du_build_ext.get_export_symbols = Mock(return_value=None)
CUDA_HOME = find_cuda_home()
if core.is_compiled_with_rocm():
ROCM_HOME = find_rocm_home()
else:
CUDA_HOME = find_cuda_home()
CUDA_HOME = ROCM_HOME
def setup(**attr):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册