refactor(xla): add xla acknowledgement

GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
......@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
5. MACE
Copyright 2018 Xiaomi Inc. All rights reserved.
6. XLA
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7. JAX
Copyright 2018 The JAX Authors.
Terms of Apache License Version 2.0
---------------------------------------------------
......
# some code of this directory is from jax: https://github.com/google/jax
from .build import build_xla
# code of this directory is mainly from jax: https://github.com/google/jax
try:
import mge_xlalib as mge_xlalib
except ModuleNotFoundError as err:
......@@ -9,8 +11,10 @@ except ModuleNotFoundError as err:
raise ModuleNotFoundError(msg)
import gc
import pathlib
import os
import platform
import subprocess
import sys
import warnings
from typing import Optional
......@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features()
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
# we use some api in jaxlib
xla_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
......@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback)
xla_extension_version: int = getattr(xla_client, "_version", 0)
mlir_api_version = xla_client.mlir_api_version
def _cuda_path() -> Optional[str]:
_mgexlalib_path = pathlib.Path(mge_xlalib.__file__).parent
path = _mgexlalib_path.parent / "nvidia" / "cuda_nvcc"
if path.is_dir():
return str(path)
path = _mgexlalib_path / "cuda"
if path.is_dir():
return str(path)
return None
cuda_path = _cuda_path()
# Finds the CUDA install path
def _find_cuda_root_dir() -> Optional[str]:
cuda_root_dir = os.environ.get("CUDA_ROOT_DIR")
if cuda_root_dir is None:
try:
which = "where" if sys.platform == "win32" else "which"
with open(os.devnull, "w") as devnull:
nvcc = (
subprocess.check_output([which, "nvcc"], stderr=devnull)
.decode()
.rstrip("\r\n")
)
cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
except Exception:
if sys.platform == "win32":
assert False, "xla not supported on windows"
else:
cuda_root_dir = "/usr/local/cuda"
if not os.path.exists(cuda_root_dir):
cuda_root_dir = None
return cuda_root_dir
cuda_path = _find_cuda_root_dir()
transfer_guard_lib = xla_client._xla.transfer_guard_lib
# code of this file is mainly from jax: https://github.com/google/jax
import logging
import os
import platform as py_platform
......@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS
logger = logging.getLogger(__name__)
flags.DEFINE_string(
"jax_xla_backend", "", "Deprecated, please use --jax_platforms instead."
"xla_backend", "", "Deprecated, please use --xla_platforms instead."
)
flags.DEFINE_string(
"jax_backend_target",
os.getenv("JAX_BACKEND_TARGET", "").lower(),
"xla_backend_target",
os.getenv("XLA_BACKEND_TARGET", "").lower(),
'Either "local" or "rpc:address" to connect to a remote service target.',
)
flags.DEFINE_string(
"jax_platform_name",
os.getenv("JAX_PLATFORM_NAME", "").lower(),
"Deprecated, please use --jax_platforms instead.",
"xla_platform_name",
os.getenv("XLA_PLATFORM_NAME", "").lower(),
"Deprecated, please use --xla_platforms instead.",
)
flags.DEFINE_bool(
"jax_disable_most_optimizations",
bool_env("JAX_DISABLE_MOST_OPTIMIZATIONS", False),
"xla_disable_most_optimizations",
bool_env("XLA_DISABLE_MOST_OPTIMIZATIONS", False),
"Try not to do much optimization work. This can be useful if the cost of "
"optimization is greater than that of running a less-optimized program.",
)
flags.DEFINE_integer(
"jax_xla_profile_version",
int_env("JAX_XLA_PROFILE_VERSION", 0),
"xla_profile_version",
int_env("XLA_PROFILE_VERSION", 0),
"Optional profile version for XLA compilation. "
"This is meaningful only when XLA is configured to "
"support the remote compilation profile feature.",
)
flags.DEFINE_string(
"jax_cuda_visible_devices",
"xla_cuda_visible_devices",
"all",
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
'Restricts the set of CUDA devices that XLA will use. Either "all", or a '
"comma-separate list of integer device IDs.",
)
flags.DEFINE_string(
"jax_rocm_visible_devices",
"xla_rocm_visible_devices",
"all",
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'Restricts the set of ROCM devices that XLA will use. Either "all", or a '
"comma-separate list of integer device IDs.",
)
......@@ -69,22 +70,22 @@ def get_compile_options(
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of xla devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
......@@ -130,12 +131,12 @@ def get_compile_options(
if cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = cuda_path
if FLAGS.jax_disable_most_optimizations:
if FLAGS.xla_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = FLAGS.jax_xla_profile_version
compile_options.profile_version = FLAGS.xla_profile_version
return compile_options
......@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"):
partial(
make_gpu_client,
platform_name="cuda",
visible_devices_flag="jax_cuda_visible_devices",
visible_devices_flag="xla_cuda_visible_devices",
),
priority=200,
)
......@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"):
partial(
make_gpu_client,
platform_name="rocm",
visible_devices_flag="jax_rocm_visible_devices",
visible_devices_flag="xla_rocm_visible_devices",
),
priority=200,
)
if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if jax has been built with a plugin client, then the
# It is assumed that if xla has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory(
......@@ -229,11 +230,11 @@ def is_known_platform(platform: str):
def canonicalize_platform(platform: str) -> str:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms = _alias_to_platforms.get(platform, None)
if platforms is None:
return platform
......@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str:
def expand_platform_alias(platform: str) -> List[str]:
"""Expands, e.g., "gpu" to ["cuda", "rocm"].
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
return _alias_to_platforms.get(platform, [platform])
......@@ -270,11 +271,11 @@ def backends():
with _backend_lock:
if _backends:
return _backends
if config.jax_platforms:
jax_platforms = config.jax_platforms.split(",")
if config.xla_platforms:
xla_platforms = config.xla_platforms.split(",")
platforms = []
# Allow platform aliases in the list of platforms.
for platform in jax_platforms:
for platform in xla_platforms:
platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1)
platforms_and_priorites = zip(platforms, priorities)
......@@ -303,8 +304,8 @@ def backends():
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
err_msg = f"Unable to initialize backend '{platform}': {err}"
if config.jax_platforms:
err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
if config.xla_platforms:
err_msg += " (set XLA_PLATFORMS='' to automatically choose an available backend)"
raise RuntimeError(err_msg)
else:
_backends_errors[platform] = str(err)
......@@ -315,12 +316,9 @@ def backends():
if (
py_platform.system() != "Darwin"
and _default_backend.platform == "cpu"
and FLAGS.jax_platform_name != "cpu"
and FLAGS.xla_platform_name != "cpu"
):
logger.warning(
"No GPU/TPU found, falling back to CPU. "
"(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
)
logger.warning("No GPU/TPU found, falling back to CPU. ")
return _backends
......@@ -329,7 +327,7 @@ def _clear_backends():
global _backends_errors
global _default_backend
logger.info("Clearing JAX backend caches.")
logger.info("Clearing XLA backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
......@@ -351,11 +349,6 @@ def _init_backend(platform):
raise RuntimeError(f"Could not initialize backend '{platform}'")
if backend.device_count() == 0:
raise RuntimeError(f"Backend '{platform}' provides no devices.")
# ccq: disable distributed_debug_log
# util.distributed_debug_log(("Initialized backend", backend.platform),
# ("process_index", backend.process_index()),
# ("device_count", backend.device_count()),
# ("local_devices", backend.local_devices()))
logger.debug("Backend '%s' initialized", platform)
return backend
......@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None):
if not isinstance(platform, (type(None), str)):
return platform
platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None
platform = platform or FLAGS.xla_backend or FLAGS.xla_platform_name or None
bs = backends()
if platform is not None:
......@@ -399,7 +392,7 @@ def get_device_backend(device=None):
def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`.
On most platforms, this is the same as :py:func:`xla.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
......@@ -430,7 +423,7 @@ def devices(
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by
:py:func:`jax.process_index`.
:py:func:`xla.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
......@@ -457,13 +450,13 @@ def local_devices(
backend: Optional[Union[str, XlaBackend]] = None,
host_id: Optional[int] = None,
) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
"""Like :py:func:`xla.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``len(jax.process_count())``.
retrieved via ``len(xla.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
......@@ -473,7 +466,7 @@ def local_devices(
"""
if host_id is not None:
warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to "
"The argument to xla.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code."
)
......@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
return get_backend(backend).process_index()
# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend=None):
warnings.warn(
"jax.host_id has been renamed to jax.process_index. This alias "
"xla.host_id has been renamed to xla.process_index. This alias "
"will eventually be removed; please update your code."
)
return process_index(backend)
def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the number of JAX processes associated with the backend."""
"""Returns the number of XLA processes associated with the backend."""
return max(d.process_index for d in devices(backend)) + 1
# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend=None):
warnings.warn(
"jax.host_count has been renamed to jax.process_count. This alias "
"xla.host_count has been renamed to xla.process_count. This alias "
"will eventually be removed; please update your code."
)
return process_count(backend)
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(backend=None):
warnings.warn(
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
"instead. jax.host_ids will eventually be removed; please update your "
"xla.host_ids has been deprecated; please use range(xla.process_count()) "
"instead. xla.host_ids will eventually be removed; please update your "
"code."
)
return list(range(process_count(backend)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部