Commit 5e013d8c authored by Megvii Engine Team

refactor(xla): add xla acknowledgement

GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
Parent 4c7905f3
@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
 5. MACE
 Copyright 2018 Xiaomi Inc. All rights reserved.
+6. XLA
+Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+7. JAX
+Copyright 2018 The JAX Authors.
 Terms of Apache License Version 2.0
 ---------------------------------------------------
......
-# some code of this directory is from jax: https://github.com/google/jax
 from .build import build_xla
+# code of this directory is mainly from jax: https://github.com/google/jax
 try:
     import mge_xlalib as mge_xlalib
 except ModuleNotFoundError as err:
@@ -9,8 +11,10 @@ except ModuleNotFoundError as err:
     raise ModuleNotFoundError(msg)
 import gc
-import pathlib
+import os
 import platform
+import subprocess
+import sys
 import warnings
 from typing import Optional
@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features()
 xla_extension = xla_client._xla
 pytree = xla_client._xla.pytree
-jax_jit = xla_client._xla.jax_jit
+# we use some api in jaxlib
+xla_jit = xla_client._xla.jax_jit
 pmap_lib = xla_client._xla.pmap_lib
@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback)
 xla_extension_version: int = getattr(xla_client, "_version", 0)
 mlir_api_version = xla_client.mlir_api_version
-# Finds the CUDA install path
-def _cuda_path() -> Optional[str]:
-    _mgexlalib_path = pathlib.Path(mge_xlalib.__file__).parent
-    path = _mgexlalib_path.parent / "nvidia" / "cuda_nvcc"
-    if path.is_dir():
-        return str(path)
-    path = _mgexlalib_path / "cuda"
-    if path.is_dir():
-        return str(path)
-    return None
-cuda_path = _cuda_path()
+def _find_cuda_root_dir() -> Optional[str]:
+    cuda_root_dir = os.environ.get("CUDA_ROOT_DIR")
+    if cuda_root_dir is None:
+        try:
+            which = "where" if sys.platform == "win32" else "which"
+            with open(os.devnull, "w") as devnull:
+                nvcc = (
+                    subprocess.check_output([which, "nvcc"], stderr=devnull)
+                    .decode()
+                    .rstrip("\r\n")
+                )
+            cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
+        except Exception:
+            if sys.platform == "win32":
+                assert False, "xla not supported on windows"
+            else:
+                cuda_root_dir = "/usr/local/cuda"
+                if not os.path.exists(cuda_root_dir):
+                    cuda_root_dir = None
+    return cuda_root_dir
+cuda_path = _find_cuda_root_dir()
 transfer_guard_lib = xla_client._xla.transfer_guard_lib
......
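For illustration, the lookup order implemented by the new _find_cuda_root_dir above can be sketched standalone: an explicit CUDA_ROOT_DIR environment variable wins, then an nvcc binary found on the PATH, then the conventional /usr/local/cuda default. The helper name probe_cuda_root and the use of shutil.which (instead of shelling out to which/where) are assumptions of this sketch, not part of the diff.

import os
import shutil

def probe_cuda_root():
    # 1. explicit override via the CUDA_ROOT_DIR environment variable
    root = os.environ.get("CUDA_ROOT_DIR")
    if root is not None:
        return root
    # 2. derive the root from an nvcc binary found on the PATH
    #    (e.g. /usr/local/cuda/bin/nvcc -> /usr/local/cuda)
    nvcc = shutil.which("nvcc")
    if nvcc is not None:
        return os.path.dirname(os.path.dirname(nvcc))
    # 3. fall back to the conventional install location, else give up
    return "/usr/local/cuda" if os.path.exists("/usr/local/cuda") else None

print(probe_cuda_root())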
+# code of this file is mainly from jax: https://github.com/google/jax
 import logging
 import os
 import platform as py_platform
@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS
 logger = logging.getLogger(__name__)
 flags.DEFINE_string(
-    "jax_xla_backend", "", "Deprecated, please use --jax_platforms instead."
+    "xla_backend", "", "Deprecated, please use --xla_platforms instead."
 )
 flags.DEFINE_string(
-    "jax_backend_target",
-    os.getenv("JAX_BACKEND_TARGET", "").lower(),
+    "xla_backend_target",
+    os.getenv("XLA_BACKEND_TARGET", "").lower(),
     'Either "local" or "rpc:address" to connect to a remote service target.',
 )
 flags.DEFINE_string(
-    "jax_platform_name",
-    os.getenv("JAX_PLATFORM_NAME", "").lower(),
-    "Deprecated, please use --jax_platforms instead.",
+    "xla_platform_name",
+    os.getenv("XLA_PLATFORM_NAME", "").lower(),
+    "Deprecated, please use --xla_platforms instead.",
 )
 flags.DEFINE_bool(
-    "jax_disable_most_optimizations",
-    bool_env("JAX_DISABLE_MOST_OPTIMIZATIONS", False),
+    "xla_disable_most_optimizations",
+    bool_env("XLA_DISABLE_MOST_OPTIMIZATIONS", False),
     "Try not to do much optimization work. This can be useful if the cost of "
     "optimization is greater than that of running a less-optimized program.",
 )
 flags.DEFINE_integer(
-    "jax_xla_profile_version",
-    int_env("JAX_XLA_PROFILE_VERSION", 0),
+    "xla_profile_version",
+    int_env("XLA_PROFILE_VERSION", 0),
     "Optional profile version for XLA compilation. "
     "This is meaningful only when XLA is configured to "
     "support the remote compilation profile feature.",
 )
 flags.DEFINE_string(
-    "jax_cuda_visible_devices",
+    "xla_cuda_visible_devices",
     "all",
-    'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
+    'Restricts the set of CUDA devices that XLA will use. Either "all", or a '
     "comma-separate list of integer device IDs.",
 )
 flags.DEFINE_string(
-    "jax_rocm_visible_devices",
+    "xla_rocm_visible_devices",
     "all",
-    'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
+    'Restricts the set of ROCM devices that XLA will use. Either "all", or a '
     "comma-separate list of integer device IDs.",
 )
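The renamed flags above read their defaults from environment variables via os.getenv, bool_env and int_env, so a backend and optimization level can be pinned before this module is imported. A minimal sketch, assuming the XLA_* variable names shown in the diff:

import os

# Set before importing the XLA bridge so the flag defaults pick these up.
os.environ["XLA_PLATFORM_NAME"] = "cpu"              # force the CPU backend
os.environ["XLA_DISABLE_MOST_OPTIMIZATIONS"] = "1"   # faster compiles, slower programs
os.environ["XLA_BACKEND_TARGET"] = "local"           # or "rpc:<address>" for a remote service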
@@ -69,22 +70,22 @@ def get_compile_options(
 ) -> xla_client.CompileOptions:
     """Returns the compile options to use, as derived from flag values.
 
     Args:
       num_replicas: Number of replicas for which to compile.
       num_partitions: Number of partitions for which to compile.
-      device_assignment: Optional ndarray of jax devices indicating the assignment
+      device_assignment: Optional ndarray of xla devices indicating the assignment
         of logical replicas to physical devices (default inherited from
         xla_client.CompileOptions). Must be consistent with `num_replicas` and
         `num_partitions`.
       use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
         partitioning in XLA.
       use_auto_spmd_partitioning: boolean indicating whether to automatically
         generate XLA shardings for SPMD partitioner.
       auto_spmd_partitioning_mesh_shape: device mesh shape used to create
         auto_spmd_partitioning search space.
       auto_spmd_partitioning_mesh_ids: device ids used to create
         auto_spmd_partitioning search space.
     """
     compile_options = xla_client.CompileOptions()
     compile_options.num_replicas = num_replicas
     compile_options.num_partitions = num_partitions
@@ -130,12 +131,12 @@ def get_compile_options(
     if cuda_path is not None:
         debug_options.xla_gpu_cuda_data_dir = cuda_path
 
-    if FLAGS.jax_disable_most_optimizations:
+    if FLAGS.xla_disable_most_optimizations:
         debug_options.xla_backend_optimization_level = 0
         debug_options.xla_llvm_disable_expensive_passes = True
         debug_options.xla_test_all_input_layouts = False
 
-    compile_options.profile_version = FLAGS.jax_xla_profile_version
+    compile_options.profile_version = FLAGS.xla_profile_version
     return compile_options
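A minimal call sketch for get_compile_options as documented above, assuming the parameters not passed here keep their defaults (the docstring says device_assignment is inherited from xla_client.CompileOptions):

# Hypothetical usage; only the two required counts are supplied.
opts = get_compile_options(num_replicas=1, num_partitions=1)
# cuda_path discovered at import time ends up in debug_options.xla_gpu_cuda_data_dir,
# and FLAGS.xla_disable_most_optimizations drops xla_backend_optimization_level to 0.
print(opts.num_replicas, opts.num_partitions)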
@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"):
         partial(
             make_gpu_client,
             platform_name="cuda",
-            visible_devices_flag="jax_cuda_visible_devices",
+            visible_devices_flag="xla_cuda_visible_devices",
         ),
         priority=200,
     )
@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"):
         partial(
             make_gpu_client,
             platform_name="rocm",
-            visible_devices_flag="jax_rocm_visible_devices",
+            visible_devices_flag="xla_rocm_visible_devices",
         ),
         priority=200,
     )
 
 if hasattr(xla_client, "make_plugin_device_client"):
-    # It is assumed that if jax has been built with a plugin client, then the
+    # It is assumed that if xla has been built with a plugin client, then the
     # user wants to use the plugin client by default. Therefore, it gets the
     # highest priority.
     register_backend_factory(
@@ -229,11 +230,11 @@ def is_known_platform(platform: str):
 def canonicalize_platform(platform: str) -> str:
     """Replaces platform aliases with their concrete equivalent.
 
     In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
     hardware is actually present. We want to distinguish "cuda" and "rocm" for
     purposes such as MLIR lowering rules, but in many cases we don't want to
     force users to care.
     """
     platforms = _alias_to_platforms.get(platform, None)
     if platforms is None:
         return platform
@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str:
 def expand_platform_alias(platform: str) -> List[str]:
     """Expands, e.g., "gpu" to ["cuda", "rocm"].
 
     This is used for convenience reasons: we expect cuda and rocm to act similarly
     in many respects since they share most of the same code.
     """
     return _alias_to_platforms.get(platform, [platform])
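Per the docstrings above, and assuming the default alias table maps "gpu" to CUDA and ROCm, the two helpers behave as in this sketch:

expand_platform_alias("gpu")     # -> ["cuda", "rocm"]
expand_platform_alias("cpu")     # -> ["cpu"]   (non-aliases pass through unchanged)
canonicalize_platform("gpu")     # -> "cuda" or "rocm", whichever hardware is present
canonicalize_platform("cuda")    # -> "cuda"   (already concrete)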
@@ -270,11 +271,11 @@ def backends():
     with _backend_lock:
         if _backends:
             return _backends
-        if config.jax_platforms:
-            jax_platforms = config.jax_platforms.split(",")
+        if config.xla_platforms:
+            xla_platforms = config.xla_platforms.split(",")
             platforms = []
             # Allow platform aliases in the list of platforms.
-            for platform in jax_platforms:
+            for platform in xla_platforms:
                 platforms.extend(expand_platform_alias(platform))
             priorities = range(len(platforms), 0, -1)
             platforms_and_priorites = zip(platforms, priorities)
@@ -303,8 +304,8 @@ def backends():
                 # If the backend isn't built into the binary, or if it has no devices,
                 # we expect a RuntimeError.
                 err_msg = f"Unable to initialize backend '{platform}': {err}"
-                if config.jax_platforms:
-                    err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
+                if config.xla_platforms:
+                    err_msg += " (set XLA_PLATFORMS='' to automatically choose an available backend)"
                     raise RuntimeError(err_msg)
                 else:
                     _backends_errors[platform] = str(err)
@@ -315,12 +316,9 @@ def backends():
         if (
             py_platform.system() != "Darwin"
             and _default_backend.platform == "cpu"
-            and FLAGS.jax_platform_name != "cpu"
+            and FLAGS.xla_platform_name != "cpu"
         ):
-            logger.warning(
-                "No GPU/TPU found, falling back to CPU. "
-                "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
-            )
+            logger.warning("No GPU/TPU found, falling back to CPU. ")
         return _backends
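A hedged sketch of how the platform list is consumed: config.xla_platforms is a comma-separated string (the error message above suggests it is seeded from an XLA_PLATFORMS environment variable, which is an assumption here), aliases are expanded, and earlier entries receive higher priority:

raw = "cpu,gpu"                                  # e.g. the value of config.xla_platforms
platforms = []
for p in raw.split(","):
    platforms.extend(expand_platform_alias(p))   # "gpu" -> ["cuda", "rocm"]
priorities = range(len(platforms), 0, -1)        # earlier entries win
print(list(zip(platforms, priorities)))          # [('cpu', 3), ('cuda', 2), ('rocm', 1)]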
@@ -329,7 +327,7 @@ def _clear_backends():
     global _backends_errors
     global _default_backend
 
-    logger.info("Clearing JAX backend caches.")
+    logger.info("Clearing XLA backend caches.")
     with _backend_lock:
         _backends = {}
         _backends_errors = {}
@@ -351,11 +349,6 @@ def _init_backend(platform):
         raise RuntimeError(f"Could not initialize backend '{platform}'")
     if backend.device_count() == 0:
         raise RuntimeError(f"Backend '{platform}' provides no devices.")
-    # ccq: disable distributed_debug_log
-    # util.distributed_debug_log(("Initialized backend", backend.platform),
-    #                            ("process_index", backend.process_index()),
-    #                            ("device_count", backend.device_count()),
-    #                            ("local_devices", backend.local_devices()))
     logger.debug("Backend '%s' initialized", platform)
     return backend
@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None):
     if not isinstance(platform, (type(None), str)):
         return platform
 
-    platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None
+    platform = platform or FLAGS.xla_backend or FLAGS.xla_platform_name or None
 
     bs = backends()
     if platform is not None:
@@ -399,7 +392,7 @@ def get_device_backend(device=None):
 def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
     """Returns the total number of devices.
 
-    On most platforms, this is the same as :py:func:`jax.local_device_count`.
+    On most platforms, this is the same as :py:func:`xla.local_device_count`.
     However, on multi-process platforms where different devices are associated
     with different processes, this will return the total number of devices across
     all processes.
@@ -430,7 +423,7 @@ def devices(
     :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
     equal to ``device_count(backend)``. Local devices can be identified by
     comparing :attr:`Device.process_index` to the value returned by
-    :py:func:`jax.process_index`.
+    :py:func:`xla.process_index`.
 
     If ``backend`` is ``None``, returns all the devices from the default backend.
     The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
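A usage sketch for the device query helpers documented above (only attributes named in the docstrings are used):

# Enumerate every device known to the default backend.
for d in devices():
    print(d, d.process_index)   # each Device exposes its owning process via process_index

# Total device count across all processes; equals the local count in a
# single-process setup.
print(device_count())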
@@ -457,13 +450,13 @@ def local_devices(
     backend: Optional[Union[str, XlaBackend]] = None,
     host_id: Optional[int] = None,
 ) -> List[xla_client.Device]:
-    """Like :py:func:`jax.devices`, but only returns devices local to a given process.
+    """Like :py:func:`xla.devices`, but only returns devices local to a given process.
 
     If ``process_index`` is ``None``, returns devices local to this process.
 
     Args:
       process_index: the integer index of the process. Process indices can be
-        retrieved via ``len(jax.process_count())``.
+        retrieved via ``len(xla.process_count())``.
       backend: This is an experimental feature and the API is likely to change.
         Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
         ``'tpu'``.
@@ -473,7 +466,7 @@ def local_devices(
     """
     if host_id is not None:
         warnings.warn(
-            "The argument to jax.local_devices has been renamed from `host_id` to "
+            "The argument to xla.local_devices has been renamed from `host_id` to "
             "`process_index`. This alias will eventually be removed; please update "
             "your code."
         )
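And for local_devices, per its docstring (a sketch; process_index refers to the keyword documented above):

mine = local_devices()                    # devices local to this process
first = local_devices(process_index=0)    # devices attached to process 0
legacy = local_devices(host_id=0)         # deprecated spelling; warns, then forwards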
@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
     return get_backend(backend).process_index()
 
 
-# TODO: remove this sometime after jax 0.2.13 is released
 def host_id(backend=None):
     warnings.warn(
-        "jax.host_id has been renamed to jax.process_index. This alias "
+        "xla.host_id has been renamed to xla.process_index. This alias "
         "will eventually be removed; please update your code."
     )
     return process_index(backend)
 
 
 def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
-    """Returns the number of JAX processes associated with the backend."""
+    """Returns the number of XLA processes associated with the backend."""
     return max(d.process_index for d in devices(backend)) + 1
 
 
-# TODO: remove this sometime after jax 0.2.13 is released
 def host_count(backend=None):
     warnings.warn(
-        "jax.host_count has been renamed to jax.process_count. This alias "
+        "xla.host_count has been renamed to xla.process_count. This alias "
         "will eventually be removed; please update your code."
     )
     return process_count(backend)
 
 
-# TODO: remove this sometime after jax 0.2.13 is released
 def host_ids(backend=None):
     warnings.warn(
-        "jax.host_ids has been deprecated; please use range(jax.process_count()) "
-        "instead. jax.host_ids will eventually be removed; please update your "
+        "xla.host_ids has been deprecated; please use range(xla.process_count()) "
+        "instead. xla.host_ids will eventually be removed; please update your "
         "code."
     )
     return list(range(process_count(backend)))