未验证 提交 6a5b7e59 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] add is_compiled_with_rocm api, test=develop (#33043)

上级 8c6bbb48
...@@ -258,6 +258,7 @@ from .device import get_cudnn_version # noqa: F401 ...@@ -258,6 +258,7 @@ from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401 from .device import set_device # noqa: F401
from .device import get_device # noqa: F401 from .device import get_device # noqa: F401
from .fluid.framework import is_compiled_with_cuda # noqa: F401 from .fluid.framework import is_compiled_with_cuda # noqa: F401
from .fluid.framework import is_compiled_with_rocm # noqa: F401
from .device import is_compiled_with_xpu # noqa: F401 from .device import is_compiled_with_xpu # noqa: F401
from .device import is_compiled_with_npu # noqa: F401 from .device import is_compiled_with_npu # noqa: F401
from .device import XPUPlace # noqa: F401 from .device import XPUPlace # noqa: F401
...@@ -384,6 +385,7 @@ __all__ = [ #noqa ...@@ -384,6 +385,7 @@ __all__ = [ #noqa
'less_equal', 'less_equal',
'triu', 'triu',
'is_compiled_with_cuda', 'is_compiled_with_cuda',
'is_compiled_with_rocm',
'sin', 'sin',
'dist', 'dist',
'unbind', 'unbind',
......
...@@ -19,6 +19,7 @@ from paddle.fluid import core ...@@ -19,6 +19,7 @@ from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.framework import is_compiled_with_cuda #DEFINE_ALIAS from paddle.fluid.framework import is_compiled_with_cuda #DEFINE_ALIAS
from paddle.fluid.framework import is_compiled_with_rocm #DEFINE_ALIAS
__all__ = [ __all__ = [
'get_cudnn_version', 'get_cudnn_version',
...@@ -33,6 +34,7 @@ __all__ = [ ...@@ -33,6 +34,7 @@ __all__ = [
# 'CUDAPinnedPlace', # 'CUDAPinnedPlace',
# 'CUDAPlace', # 'CUDAPlace',
'is_compiled_with_cuda', 'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_npu' 'is_compiled_with_npu'
] ]
......
...@@ -53,6 +53,7 @@ __all__ = [ ...@@ -53,6 +53,7 @@ __all__ = [
'cuda_pinned_places', 'cuda_pinned_places',
'in_dygraph_mode', 'in_dygraph_mode',
'is_compiled_with_cuda', 'is_compiled_with_cuda',
'is_compiled_with_rocm',
'is_compiled_with_xpu', 'is_compiled_with_xpu',
'Variable', 'Variable',
'require_version', 'require_version',
...@@ -398,6 +399,21 @@ def is_compiled_with_cuda(): ...@@ -398,6 +399,21 @@ def is_compiled_with_cuda():
return core.is_compiled_with_cuda() return core.is_compiled_with_cuda()
def is_compiled_with_rocm():
"""
Whether this whl package can be used to run the model on AMD or Hygon GPU(ROCm).
Returns (bool): `True` if ROCm is currently available, otherwise `False`.
Examples:
.. code-block:: python
import paddle
support_gpu = paddle.is_compiled_with_rocm()
"""
return core.is_compiled_with_rocm()
def cuda_places(device_ids=None): def cuda_places(device_ids=None):
""" """
**Note**: **Note**:
......
...@@ -42,10 +42,10 @@ if IS_WINDOWS and six.PY3: ...@@ -42,10 +42,10 @@ if IS_WINDOWS and six.PY3:
from unittest.mock import Mock from unittest.mock import Mock
_du_build_ext.get_export_symbols = Mock(return_value=None) _du_build_ext.get_export_symbols = Mock(return_value=None)
CUDA_HOME = find_cuda_home()
if core.is_compiled_with_rocm(): if core.is_compiled_with_rocm():
ROCM_HOME = find_rocm_home() ROCM_HOME = find_rocm_home()
else: CUDA_HOME = ROCM_HOME
CUDA_HOME = find_cuda_home()
def setup(**attr): def setup(**attr):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册