From cf139400b725099f53536fce8655bbde0774ba89 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 13 Jun 2023 20:22:27 +0800
Subject: [PATCH] Revert "feat(imperative): add xla support"

This reverts commit 9dd518c76cde32e81d29562be4a4bc176354ee54.

GitOrigin-RevId: 5b49dab0570842c01019d333d242d36b7517fee4
---
 imperative/python/megengine/xla/__init__.py | 1 -
 imperative/python/megengine/xla/build.py | 118 --
 imperative/python/megengine/xla/compile.py | 872 -----------
 imperative/python/megengine/xla/device.py | 86 --
 imperative/python/megengine/xla/distribute.py | 91 --
 imperative/python/megengine/xla/dtype.py | 109 --
 imperative/python/megengine/xla/ir_utils.py | 486 ------
 .../python/megengine/xla/lib/__init__.py | 108 --
 imperative/python/megengine/xla/lib/config.py | 1330 -----------------
 .../python/megengine/xla/lib/mlir/__init__.py | 1 -
 .../xla/lib/mlir/dialects/__init__.py | 15 -
 .../python/megengine/xla/lib/xla_bridge.py | 503 -------
 imperative/python/megengine/xla/lower.py | 260 ----
 .../python/megengine/xla/rules/__init__.py | 13 -
 .../python/megengine/xla/rules/communicate.py | 82 -
 .../python/megengine/xla/rules/elemwise.py | 303 ----
 .../python/megengine/xla/rules/hlotensor.py | 203 ---
 .../python/megengine/xla/rules/indexing.py | 696 ---------
 imperative/python/megengine/xla/rules/math.py | 238 ---
 imperative/python/megengine/xla/rules/nn.py | 681 ---------
 .../python/megengine/xla/rules/normalize.py | 185 ---
 .../python/megengine/xla/rules/random.py | 85 --
 .../python/megengine/xla/rules/reduction.py | 175 ---
 .../python/megengine/xla/rules/tensor.py | 264 ----
 .../python/megengine/xla/rules/trivial.py | 53 -
 .../python/megengine/xla/rules/utils.py | 98 --
 imperative/python/megengine/xla/sharding.py | 454 ------
 imperative/python/megengine/xla/utils.py | 95 --
 28 files changed, 7605 deletions(-)
 delete mode 100644 imperative/python/megengine/xla/__init__.py
 delete mode 100644 imperative/python/megengine/xla/build.py
 delete mode 100644 imperative/python/megengine/xla/compile.py
 delete mode 100644 imperative/python/megengine/xla/device.py
 delete mode 100644 imperative/python/megengine/xla/distribute.py
 delete mode 100644 imperative/python/megengine/xla/dtype.py
 delete mode 100644 imperative/python/megengine/xla/ir_utils.py
 delete mode 100644 imperative/python/megengine/xla/lib/__init__.py
 delete mode 100644 imperative/python/megengine/xla/lib/config.py
 delete mode 100644 imperative/python/megengine/xla/lib/mlir/__init__.py
 delete mode 100644 imperative/python/megengine/xla/lib/mlir/dialects/__init__.py
 delete mode 100644 imperative/python/megengine/xla/lib/xla_bridge.py
 delete mode 100644 imperative/python/megengine/xla/lower.py
 delete mode 100644 imperative/python/megengine/xla/rules/__init__.py
 delete mode 100644 imperative/python/megengine/xla/rules/communicate.py
 delete mode 100644 imperative/python/megengine/xla/rules/elemwise.py
 delete mode 100644 imperative/python/megengine/xla/rules/hlotensor.py
 delete mode 100644 imperative/python/megengine/xla/rules/indexing.py
 delete mode 100644 imperative/python/megengine/xla/rules/math.py
 delete mode 100644 imperative/python/megengine/xla/rules/nn.py
 delete mode 100644 imperative/python/megengine/xla/rules/normalize.py
 delete mode 100644 imperative/python/megengine/xla/rules/random.py
 delete mode 100644 imperative/python/megengine/xla/rules/reduction.py
 delete mode 100644 imperative/python/megengine/xla/rules/tensor.py
 delete mode 100644 imperative/python/megengine/xla/rules/trivial.py
 delete mode 100644
imperative/python/megengine/xla/rules/utils.py delete mode 100644 imperative/python/megengine/xla/sharding.py delete mode 100644 imperative/python/megengine/xla/utils.py diff --git a/imperative/python/megengine/xla/__init__.py b/imperative/python/megengine/xla/__init__.py deleted file mode 100644 index 5dcd05e12..000000000 --- a/imperative/python/megengine/xla/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .build import build_xla diff --git a/imperative/python/megengine/xla/build.py b/imperative/python/megengine/xla/build.py deleted file mode 100644 index 6db28a127..000000000 --- a/imperative/python/megengine/xla/build.py +++ /dev/null @@ -1,118 +0,0 @@ -import os - -from ..distributed import get_rank, get_world_size, is_distributed -from .compile import MeshComputation, PmapComputation -from .device import get_xla_backend_and_device -from .distribute import initialize -from .ir_utils import DropoutMaskCanonicalizer, RngKeyAdder, TraceResult -from .lib import xla_client as xc -from .lower import lower -from .sharding import OpShardingSharding, _is_unspecified, make_unspec_sharding - -xla_extention = xc._xla -xe = xla_extention - -Backend = xe.Client - - -def build_xla( - mge_traced, - func_name=None, - device=None, - keep_unused=True, - donate_invars=None, - verbose=int(os.environ.get("MGE_VERBOSE_XLA_IR", "0")), - return_with_io=False, - return_device_array=False, -): - assert device == None, "cannot specify device now" - assert keep_unused == True, "keep_unused error" - assert donate_invars == None, "donate_invars error" - - # normalize megengine trace result for lowering - tr = TraceResult(mge_traced, func_name) - tr = RngKeyAdder()(tr) - tr = DropoutMaskCanonicalizer()(tr) - - if verbose and get_rank() == 0: - print("================ Mge Trace Result ================") - print(tr) - - in_is_global = (True,) * len(tr.inputs) - kept_var_idx = set(range(len(tr.inputs))) if keep_unused else set() - - # init for xla distributed and setup device - if is_distributed(): - initialize("127.0.0.1:12345", get_world_size(), get_rank(), [get_rank()]) - backend, device_assignment, platform = get_xla_backend_and_device(device) - - module, keepalive, host_callbacks = lower( - tr, backend, platform, None, None, donate_invars, - ) - - if not is_distributed(): - # setup sharding information - in_shardings = make_unspec_sharding(tr.inputs) - out_shardings = make_unspec_sharding(tr.outputs) - - in_shardings = tuple( - OpShardingSharding.get_replicated(device_assignment) - if _is_unspecified(i) - else i - for i in in_shardings - ) - - computation = MeshComputation( - tr.func_name, - module, - donated_invars=donate_invars, - trace_result=tr, - mesh=None, - in_shardings=in_shardings, - out_shardings=out_shardings, - spmd_lowering=False, - tuple_args=False, # for tpu - in_is_global=in_is_global, - auto_spmd_lowering=False, - unordered_effects=[], - ordered_effects=[], - host_callbacks=host_callbacks, - keepalive=keepalive, - kept_var_idx=kept_var_idx, - backend=backend, - device_assignment=device_assignment, - committed=False, # unknown - pmap_nreps=1, - return_device_array=return_device_array, - ) - else: - computation = PmapComputation( - tr.func_name, - module, - trace_result=tr, - unordered_effects=[], - ordered_effects=[], - tuple_args=False, # for tpu - in_is_global=in_is_global, - host_callbacks=host_callbacks, - keepalive=keepalive, - kept_var_idx=kept_var_idx, - backend=backend, - devices=None, - return_device_array=return_device_array, - world_size=get_world_size(), - rank=get_rank(), - ) - - if 
verbose and get_rank() == 0: - print("================ XLA HLO IR ================") - print(computation.as_text()) - compiled = computation.compile() - if verbose and get_rank() == 0: - print("================ XLA Execute Plan ================") - print(compiled.as_text()) - - ret = compiled.unsafe_call - if return_with_io: - return ret, tr.inputs, tr.outputs - return ret diff --git a/imperative/python/megengine/xla/compile.py b/imperative/python/megengine/xla/compile.py deleted file mode 100644 index 1d74884ae..000000000 --- a/imperative/python/megengine/xla/compile.py +++ /dev/null @@ -1,872 +0,0 @@ -import dataclasses -import os -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Set, Union - -import jaxlib -import numpy as np - -from ..distributed import is_distributed -from ..utils.dlpack import from_dlpack, to_dlpack -from . import ir_utils -from .lib import xla_bridge as xb -from .lib import xla_client as xc -from .lib.mlir import ir -from .lib.mlir.dialects import use_stablehlo -from .sharding import ( - _get_normalized_avals_and_shardings, - _get_op_sharding_shardings_from_executable, - _get_pmap_sharding, - _is_unspecified, - _pmap_sharding_spec, - is_op_sharding_replicated, - pmap_lib, - shard_args, -) -from .utils import safe_zip, unzip2 - -xla_extension = xc._xla -xe = xla_extension - - -def compile_impl(backend, computation: ir.Module, compile_options, host_callbacks): - sym_name = computation.operation.attributes["sym_name"] - module_name = ir.StringAttr(sym_name).value - - serialized_computation: Union[str, bytes, ir.Module] - if getattr(backend, "needs_str_ir", True): - serialized_computation = ir_utils.module_to_bytecode(computation) - else: - serialized_computation = computation - - supported_platforms = ["gpu"] - if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""): - supported_platforms.append("cpu") - - def backend_compile(backend, built_c, options, host_callbacks): - if host_callbacks: - return backend.compile( - built_c, compile_options=options, host_callbacks=host_callbacks - ) - return backend.compile(built_c, compile_options=options) - - return backend_compile( - backend, serialized_computation, compile_options, host_callbacks - ) - - -class InputsHandler: - __slots__ = ("handler", "local_devices", "in_shardings", "input_indices") - - def __init__(self, local_devices, in_shardings, input_indices): - self.handler = shard_args # partial(shard_args, local_devices, input_indices) - self.local_devices = local_devices - self.in_shardings = in_shardings - self.input_indices = input_indices - - def from_dlpack(self, dlpack): - return xe.dlpack_managed_tensor_to_buffer( - dlpack, None, self.local_devices[0].client - ) - - def __call__(self, input_buffers): - rst = [] - for idx, i in enumerate(input_buffers): - capsule = to_dlpack(i) - xla_array = self.from_dlpack(capsule) - - # if hasattr(i, "xla_array"): - # rst.append([i.xla_array]) - # else: - # r = self.handler(self.local_devices, [self.input_indices[idx],], [i,])[ - # 0 - # ] - # rst.append(r) - # i.xla_array = r[0] - rst.append([xla_array]) - return rst - - def __str__(self): - return ( - "InputsHandler(\n" - f"local_devices={self.local_devices},\n" - f"in_shardings={self.in_shardings},\n" - f"input_indices={self.input_indices})" - ) - - -class ResultsHandler: - __slots__ = ("handlers", "out_shardings", "out_avals", "return_device_array") - - def __init__( - self, - handlers=None, - out_shardings=None, - out_avals=None, - 
return_device_array=False, - ): - self.return_device_array = return_device_array - if handlers is None: - - def out_handler(bufs): - assert isinstance(bufs, list) and len(bufs) == 1 - assert isinstance(bufs[0], xe.ArrayImpl) - if not self.return_device_array: - return np.asarray(bufs[0]) - else: - return bufs[0] - - self.handlers = out_handler - self.out_shardings = out_shardings - self.out_avals = out_avals - - def __call__(self, out_bufs): - if isinstance(self.handlers, list): - return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)] - else: - return [self.handlers(bufs) for bufs in out_bufs] - - -class Executable(Protocol): - def call(self, *args_flat): - raise NotImplementedError - - def input_shardings(self): - raise NotImplementedError - - def output_shardings(self): - raise NotImplementedError - - def as_text(self) -> str: - raise NotImplementedError - - def cost_analysis(self) -> Any: - raise NotImplementedError - - def memory_analysis(self) -> Any: - raise NotImplementedError - - def runtime_executable(self) -> Any: - raise NotImplementedError - - def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: - return None - - -class XlaExecutable(Executable): - def xla_extension_executable(self): - raise NotImplementedError("should be overrided") - - def call(self, *args_flat): - raise NotImplementedError("should be overrided") - - def input_shardings(self): - raise NotImplementedError( - "should be overrided" - ) - - def output_shardings(self): - raise NotImplementedError( - "should be overrided" - ) - - def as_text(self) -> str: - xla_ext_exe = self.xla_extension_executable() - err_msg = ( - "text view unsupported on current XLA backend: " f"{type(xla_ext_exe)}" - ) - if not hasattr(xla_ext_exe, "hlo_modules"): - raise NotImplementedError(err_msg) - try: - return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()]) - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - - def cost_analysis(self) -> List[Dict[str, float]]: - xla_ext_exe = self.xla_extension_executable() - err_msg = ( - "cost analysis unsupported on current XLA backend: " f"{type(xla_ext_exe)}" - ) - # TODO: Unify/merge the two cost_analysis calls below. 
- if hasattr(xla_ext_exe, "client"): - try: - return [ - xla_extension.hlo_module_cost_analysis(xla_ext_exe.client, m) - for m in xla_ext_exe.hlo_modules() - ] - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - elif hasattr(xla_ext_exe, "cost_analysis"): - try: - return xla_ext_exe.cost_analysis() - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - else: - raise NotImplementedError(err_msg) - - def memory_analysis(self) -> Any: - xla_ext_exe = self.xla_extension_executable() - err_msg = ( - "memory analysis unsupported on current XLA backend: " - f"{type(xla_ext_exe)}" - ) - if not hasattr(xla_ext_exe, "get_compiled_memory_stats"): - raise NotImplementedError(err_msg) - try: - return xla_ext_exe.get_compiled_memory_stats() - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - - def runtime_executable(self) -> Any: - return self.xla_extension_executable() - - -# The logic to shard inputs, execute a replicated model, returning outputs -class ExecuteReplicated: - __slots__ = [ - "xla_executable", - "name", - "backend", - "in_handler", - "out_handler", - "has_unordered_effects", - "ordered_effects", - "keepalive", - "has_host_callbacks", - "_local_devices", - "kept_var_idx", - "__weakref__", - ] - - def __init__( - self, - xla_executable, - name, - backend, - in_handler: InputsHandler, - out_handler: ResultsHandler, - unordered_effects: Any, - ordered_effects: Any, - keepalive: Any, - has_host_callbacks: bool, - kept_var_idx: Set[int], - ): - self.xla_executable = xla_executable - self.name = name - self.backend = backend - self.in_handler = in_handler - self.out_handler = out_handler - self.has_unordered_effects = bool(unordered_effects) - self.ordered_effects = ordered_effects - self._local_devices = self.xla_executable.local_devices() - if ordered_effects: - assert len(self._local_devices) == 1 - self.keepalive = keepalive - self.has_host_callbacks = has_host_callbacks - self.kept_var_idx = kept_var_idx - - def __call__(self, *args): - args = [x for i, x in enumerate(args) if i in self.kept_var_idx] - input_bufs = self.in_handler(args) - assert not ( - self.ordered_effects - or self.has_unordered_effects - or self.has_host_callbacks - ) - - if True or not is_distributed(): - out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs) - return self.out_handler(out_bufs) - else: - results = self.xla_executable.execute_sharded(input_bufs) - outputs = results.disassemble_into_single_device_arrays() - assert isinstance(outputs, list) - out_bufs = [] - for oup in outputs: - assert isinstance(oup, list) and len(oup) == 1 - out_bufs.append(oup[0].device_buffers) - return self.out_handler(out_bufs) - - -@dataclasses.dataclass -class UnloadedMeshExecutable: - xla_executable: Any - trace_result: ir_utils.TraceResult - device_assignment: Sequence[xc.Device] - backend: xb.XlaBackend - input_shardings: Sequence[Any] - output_shardings: Sequence[Any] - committed: bool - are_out_shardings_from_xla: Sequence[bool] - pmap_nreps: int - name: str - unordered_effects: List[Any] - ordered_effects: List[Any] - keepalive: Sequence[Any] - host_callbacks: Sequence[Any] - kept_var_idx: Set[int] - 
auto_spmd_lowering: bool - return_device_array: bool = False - - def load(self): - def _get_input_indices(avals, shardings): - input_indices = [] - for aval, sharding in zip(avals, shardings): - proto = sharding._to_xla_op_sharding(len(aval.shape)) - if is_op_sharding_replicated(proto): - index = tuple( - (slice(None),) * len(aval.shape) - for _ in range(len(sharding.addressable_devices)) - ) - else: - assert False - input_indices.append(index) - return input_indices - - input_indices = _get_input_indices( - self.trace_result._var_inputs, self.input_shardings - ) - handle_inps = InputsHandler( - self.xla_executable.local_devices(), self.input_shardings, input_indices - ) - handle_oups = ResultsHandler(return_device_array=self.return_device_array) - - if self.pmap_nreps > 1: - assert False - else: - unsafe_call = ExecuteReplicated( - self.xla_executable, - self.name, - self.backend, - handle_inps, - handle_oups, - self.unordered_effects, - self.ordered_effects, - self.keepalive, - bool(self.host_callbacks), - self.kept_var_idx, - ) - - return MeshExecutable( - self.xla_executable, - unsafe_call, - self.trace_result, - self.input_shardings, - self.output_shardings, - self.auto_spmd_lowering, - self.kept_var_idx, - self.device_assignment, - ) - - @staticmethod - def from_hlo( - name: str, - computation, - mesh, - trace_result: ir_utils.TraceResult, - in_shardings, - out_shardings, - spmd_lowering: bool, - tuple_args: bool, - in_is_global: Sequence[bool], - auto_spmd_lowering: bool, - _allow_propagation_to_outputs: bool, - _allow_compile_replicated: bool, - unordered_effects, - ordered_effects, - host_callbacks, - keepalive, - kept_var_idx, - backend: xb.XlaBackend, - device_assignment: Sequence[xc.Device], - committed: bool, - pmap_nreps: int = 1, - return_device_array: bool = False, - ): - assert mesh == None - assert spmd_lowering == False - assert tuple_args == False - assert in_is_global == (True,) * len(trace_result.inputs) - assert auto_spmd_lowering == False - assert _allow_propagation_to_outputs == False - assert _allow_compile_replicated == True - assert unordered_effects == [] - assert ordered_effects == [] - assert host_callbacks == [] - assert keepalive == [] - assert committed == False - assert pmap_nreps == 1 - - dev: np.ndarray - if auto_spmd_lowering: - assert mesh is not None and spmd_lowering - dev = mesh.devices - num_replicas, num_partitions = 1, mesh.size - else: - dev = np.array(device_assignment) - if pmap_nreps > 1: - num_replicas, num_partitions = pmap_nreps, 1 - elif spmd_lowering: - num_replicas, num_partitions = 1, dev.size - else: - num_replicas, num_partitions = dev.size, 1 - - if pmap_nreps > 1: - xla_device_assignment = None - else: - xla_device_assignment = dev.reshape((num_replicas, num_partitions)) - - assert num_replicas == 1 and num_partitions == 1 - compile_options = xb.get_compile_options( - num_replicas=num_replicas, - num_partitions=num_partitions, - device_assignment=xla_device_assignment, - use_spmd_partitioning=spmd_lowering, - use_auto_spmd_partitioning=auto_spmd_lowering, - ) - if auto_spmd_lowering: - assert False - # tuple_args is only tpu related, so in mge we close it - compile_options.parameter_is_tupled_arguments = False - allow_propagation = [_allow_propagation_to_outputs] - compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = ( - allow_propagation - ) - assert hasattr(backend, "compile_replicated") == False - if _allow_compile_replicated and hasattr(backend, "compile_replicated"): - assert False - else: - 
xla_executable = compile_impl( - backend, computation, compile_options, host_callbacks - ) - - if auto_spmd_lowering: - assert False - elif out_shardings and any(_is_unspecified(o) for o in out_shardings): - assert mesh is None - _, out_shardings_xla = _get_op_sharding_shardings_from_executable( # type: ignore - xla_executable, - device_assignment, - len(trace_result.inputs), - len(trace_result.outputs), - ) - out_shardings_tuple = [ - (x, True) if _is_unspecified(o) else (o, False) - for x, o in safe_zip(out_shardings_xla, out_shardings) - ] - out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple) - else: - are_out_shardings_from_xla = (False,) * len(trace_result.outputs) - - input_avals, input_shardings = _get_normalized_avals_and_shardings( - trace_result._var_inputs, in_shardings, in_is_global - ) - - return UnloadedMeshExecutable( - xla_executable=xla_executable, - trace_result=trace_result, - device_assignment=device_assignment, - backend=backend, - input_shardings=input_shardings, - output_shardings=out_shardings, - committed=committed, - are_out_shardings_from_xla=are_out_shardings_from_xla, - pmap_nreps=pmap_nreps, - name=name, - unordered_effects=unordered_effects, - ordered_effects=ordered_effects, - keepalive=keepalive, - host_callbacks=host_callbacks, - kept_var_idx=kept_var_idx, - auto_spmd_lowering=auto_spmd_lowering, - return_device_array=return_device_array, - ) - - -class MeshExecutable(XlaExecutable): - __slots__ = [ - "xla_executable", - "unsafe_call", - "trace_result", - "_in_shardings", - "_out_shardings", - "_auto_spmd_lowering", - "_kept_var_idx", - "_device_assignment", - ] - - def __init__( - self, - xla_executable, - unsafe_call, - trace_result, - in_shardings, - out_shardings, - auto_spmd_lowering, - kept_var_idx, - device_assignment, - ): - self.xla_executable = xla_executable - self.unsafe_call = unsafe_call - self.trace_result = trace_result - self._in_shardings = in_shardings - self._out_shardings = out_shardings - self._auto_spmd_lowering = auto_spmd_lowering - self._kept_var_idx = kept_var_idx - self._device_assignment = device_assignment - - def xla_extension_executable(self): - return self.xla_executable - - def call(self, *args): - return self.unsafe_call(*args) - - def input_shardings(self): - return self._in_shardings - - def output_shardings(self): - return self._out_shardings - - -class Lowering(Protocol): - def compile(self) -> Executable: - raise NotImplementedError - - def as_text(self, dialect: Optional[str] = None) -> str: - raise NotImplementedError - - def compiler_ir(self, dialect: Optional[str] = None) -> Any: - raise NotImplementedError - - -class XlaLowering(Lowering): - def hlo(self) -> xc.XlaComputation: - raise NotImplementedError("must override") - - # Return an MHLO IR of computation - def mhlo(self) -> ir.Module: - if use_stablehlo: - module_str = xla_extension.mlir.stablehlo_to_mhlo( - ir_utils.module_to_bytecode(self.stablehlo()) - ) - with self.stablehlo().context: - return ir.Module.parse(module_str) - else: - raise NotImplementedError("must override") - - # Return a StableHLO IR of computation - def stablehlo(self) -> ir.Module: - if use_stablehlo: - raise NotImplementedError("must override") - else: - module_str = xla_extension.mlir.mhlo_to_stablehlo( - ir_utils.module_to_bytecode(self.mhlo()) - ) - with self.mhlo().context: - return ir.Module.parse(module_str) - - def compile(self) -> Executable: - raise NotImplementedError("must override") - - def as_text(self, dialect: Optional[str] = None) -> str: - if 
dialect is None: - dialect = "stablehlo" if use_stablehlo else "mhlo" - if dialect == "mhlo": - return str(self.mhlo()) - elif dialect == "stablehlo": - return str(self.stablehlo()) - elif dialect == "hlo": - return self.hlo().as_hlo_text() - else: - raise ValueError(f"unknown dialect: {dialect}") - - def compiler_ir(self, dialect: Optional[str] = None) -> Any: - if dialect is None: - dialect = "stablehlo" if use_stablehlo else "mhlo" - if dialect == "mhlo": - return self.mhlo() - elif dialect == "stablehlo": - return self.stablehlo() - elif dialect == "hlo": - return self.hlo() - else: - raise ValueError(f"unknown dialect: {dialect}") - - -class MeshComputation(XlaLowering): - _hlo: Optional[ir.Module] - _executable: Optional[MeshExecutable] - - def __init__( - self, - name: str, - hlo: Optional[ir.Module], - donated_invars: Sequence[bool], - **compile_args - ): - self._name = name - self._hlo = hlo - self._donated_invars = donated_invars - self.compile_args = compile_args - self._executable = None - - def _compile_unloaded( - self, - _allow_propagation_to_outputs: bool = False, - _allow_compile_replicated: bool = True, - ) -> Union[UnloadedMeshExecutable, MeshExecutable]: - return UnloadedMeshExecutable.from_hlo( - self._name, - self._hlo, - **self.compile_args, - _allow_propagation_to_outputs=_allow_propagation_to_outputs, - _allow_compile_replicated=_allow_compile_replicated, - ) - - def hlo(self) -> xc.XlaComputation: - return xe.mlir.mlir_module_to_xla_computation( - ir_utils.module_to_string(self._hlo), - use_tuple_args=self.compile_args["tuple_args"], - ) - - def mhlo(self) -> ir.Module: - if use_stablehlo: - return super().mhlo() - else: - return self._hlo - - def stablehlo(self) -> ir.Module: - if use_stablehlo: - return self._hlo - else: - return super().stablehlo() - - def compile( - self, - _allow_propagation_to_outputs: bool = False, - _allow_compile_replicated: bool = True, - ) -> MeshExecutable: - if self._executable is None: - executable = self._compile_unloaded( - _allow_propagation_to_outputs, _allow_compile_replicated - ) - if isinstance(executable, UnloadedMeshExecutable): - executable = executable.load() - self._executable = executable - return self._executable - - -class PmapExecutable(XlaExecutable): - __slots__ = [ - "xla_executable", - "_unsafe_call", - "build_unsafe_call", - "trace_result", - "_unloaded_executable", - ] - - def __init__( - self, - xla_executable, - build_unsafe_call, - trace_result, - unloaded_executable, - ): - self.xla_executable = xla_executable - self._unsafe_call = None - self.build_unsafe_call = build_unsafe_call - self.trace_result = trace_result - self._unloaded_executable = unloaded_executable - - @property - def unsafe_call(self) -> Callable[..., Any]: - if self._unsafe_call is None: - self._unsafe_call = self.build_unsafe_call() - return self._unsafe_call - - def xla_extension_executable(self): - return self.xla_executable - - def call(self, *args): - return self.unsafe_call(*args) - - -@dataclasses.dataclass -class UnloadedPmapExecutable: - compiled: Any - trace_result: ir_utils.TraceResult - backend: xb.XlaBackend - input_shardings: Sequence[Any] - output_shardings: Sequence[Any] - unordered_effects: List[Any] - ordered_effects: List[Any] - keepalive: Sequence[Any] - host_callbacks: Sequence[Any] - kept_var_idx: Set[int] - rank: int - return_device_array: bool = False - - @staticmethod - def from_hlo( - computation, - trace_result: ir_utils.TraceResult, - unordered_effects, - ordered_effects, - tuple_args, # for tpu - in_is_global, - 
host_callbacks, - keepalive, - kept_var_idx, - backend, - devices, - return_device_array, - world_size, - rank, - ): - assert unordered_effects == [] - assert ordered_effects == [] - assert host_callbacks == [] - assert keepalive == [] - assert tuple_args == False - assert in_is_global == (True,) * len(trace_result.inputs) - assert devices is None - - if devices is None: - if world_size > xb.device_count(backend): - assert ( - False - ), f"world_size={world_size} is bigger than device_count={xb.device_count(backend)}" - - devices = [ - d - for process_index in range(xb.process_count(backend)) - for d in xb.local_devices(process_index, backend) - ] - else: - assert False, "impossible" - - device_assignment: np.ndarray = np.array(devices).reshape((world_size, 1)) - - use_spmd_partitioning = False - compile_options = xb.get_compile_options( - num_replicas=world_size, - num_partitions=1, - device_assignment=device_assignment, - use_spmd_partitioning=use_spmd_partitioning, - ) - compile_options.parameter_is_tupled_arguments = tuple_args - compiled = compile_impl(backend, computation, compile_options, host_callbacks) - - process_index = xb.process_index(backend) - local_device_assignment = np.array( - [d for d in device_assignment.flat if d.process_index == process_index] - ) - - ishapes = [inp.shape for inp in trace_result._var_inputs] - input_sharding_specs = [ - _pmap_sharding_spec(1, 1, 1, None, ishape, 0) for ishape in ishapes - ] - in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs) - - oshapes = [out.shape for out in trace_result._var_outputs] - out_specs = [ - _pmap_sharding_spec(1, 1, 1, None, oshape, 0) for oshape in oshapes - ] - out_shardings = _get_pmap_sharding(local_device_assignment, out_specs) - - return UnloadedPmapExecutable( - compiled=compiled, - trace_result=trace_result, - backend=backend, - input_shardings=in_shardings, - output_shardings=out_shardings, - unordered_effects=unordered_effects, - ordered_effects=ordered_effects, - keepalive=keepalive, - host_callbacks=host_callbacks, - kept_var_idx=kept_var_idx, - rank=rank, - return_device_array=return_device_array, - ).load() - - def build_execute_fun(self): - input_indices = [] - ishapes = [inp.shape for inp in self.trace_result._var_inputs] - for ishape, isharding in safe_zip(ishapes, self.input_shardings): - spec = isharding.sharding_spec - assert len(spec.sharding) == len(ishape) + 1 - assert spec.sharding[0] == pmap_lib.Unstacked(1) - assert all(isinstance(s, pmap_lib.NoSharding) for s in spec.sharding[1:]) - input_indices.append( - ((tuple(slice(None, None, None) for _ in range(len(ishape)))),) - ) - handle_inps = InputsHandler( - self.compiled.local_devices(), self.input_shardings, input_indices - ) - handle_oups = ResultsHandler(return_device_array=self.return_device_array) - - execute_fun = ExecuteReplicated( - self.compiled, - "parallel computation", - self.backend, - handle_inps, - handle_oups, - self.unordered_effects, - self.ordered_effects, - self.keepalive, - bool(self.host_callbacks), - set(range(len(input_indices))), - ) - return execute_fun - - def load(self) -> PmapExecutable: - return PmapExecutable( - self.compiled, self.build_execute_fun, self.trace_result, self, - ) - - -class PmapComputation(XlaLowering): - _name: str - _hlo: ir.Module - _executable: Optional[PmapExecutable] - - def __init__(self, name, hlo: ir.Module, **compile_args): - self._name = name - self._executable = None - self._hlo = hlo - self.compile_args = compile_args - - def hlo(self) -> 
xc.XlaComputation: - return xe.mlir.mlir_module_to_xla_computation( - ir_utils.module_to_string(self._hlo), - use_tuple_args=self.compile_args["tuple_args"], - ) - - def mhlo(self) -> ir.Module: - return super().mhlo() - - def stablehlo(self) -> ir.Module: - return self._hlo - - def compile(self) -> PmapExecutable: - if self._executable is None: - self._executable = UnloadedPmapExecutable.from_hlo( - self._hlo, **self.compile_args - ) - return self._executable diff --git a/imperative/python/megengine/xla/device.py b/imperative/python/megengine/xla/device.py deleted file mode 100644 index bbb68462c..000000000 --- a/imperative/python/megengine/xla/device.py +++ /dev/null @@ -1,86 +0,0 @@ -import itertools as it -from typing import Sequence, Tuple, Union - -import numpy as np - -from ..core._imperative_rt.common import CompNode -from ..tensor import Parameter as MgeParameter -from ..tensor import Tensor as MgeTensor -from .dtype import ( - _np_types, - _python_scalar_dtypes, - _scalar_type_to_dtype, - canonicalize_arg, -) -from .lib import xla_bridge as xb -from .lib import xla_client as xc -from .utils import safe_zip - -xla_extention = xc._xla -xe = xla_extention - -Backend = xe.Client - -device_put_handlers = {} - - -def _device_put_nparray(x, device): - backend = xb.get_device_backend(device) - return (backend.buffer_from_pyval(x, device),) - - -def _device_put_scalar(x, device): - def cvt_scalar_to_nparray(x, dtype=None): - if dtype is None and type(x) in _python_scalar_dtypes: - dtype = _scalar_type_to_dtype(type(x), x) - return np.asarray(x, dtype) - - return _device_put_nparray(cvt_scalar_to_nparray(x), device) - - -def _device_put_device_array(x, device): - assert False - - -def _device_put_mge_tensor(x, device): - x = x.numpy() - return _device_put_nparray(x, device) - - -for nt in _np_types: - device_put_handlers[nt] = _device_put_nparray -for sc in _python_scalar_dtypes: - device_put_handlers[nt] = _device_put_scalar -device_put_handlers[xc._xla.DeviceArray] = _device_put_device_array -device_put_handlers[MgeTensor] = _device_put_mge_tensor -device_put_handlers[MgeParameter] = _device_put_mge_tensor - - -def _device_put_impl(x, device): - x = canonicalize_arg(x) - return device_put_handlers[type(x)](x, device) - - -def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool = False): - if replicate: - return list( - it.chain.from_iterable(_device_put_impl(x, device) for device in devices) - ) - else: - return list( - it.chain.from_iterable( - _device_put_impl(val, device) for val, device in safe_zip(x, devices) - ) - ) - - -def get_xla_backend_and_device(device=None) -> Tuple[Backend, Sequence[xc.Device]]: - assert device is None, "device assignment is not supported yet" - device_assignment = [xb.local_devices()[0]] - backend = xb.get_device_backend(device_assignment[0]) - platform = backend.platform - platform = xb.canonicalize_platform(platform) - - assert xb.is_known_platform(platform), f"{platform} is not known yet" - assert platform == "cuda", f"only cuda platfrom is supportted, but get {platform}" - return backend, device_assignment, platform diff --git a/imperative/python/megengine/xla/distribute.py b/imperative/python/megengine/xla/distribute.py deleted file mode 100644 index ea439a038..000000000 --- a/imperative/python/megengine/xla/distribute.py +++ /dev/null @@ -1,91 +0,0 @@ -import atexit -from typing import Any, Optional, Sequence, Union - -from .lib import xla_client as xc - -xla_extention = xc._xla -xe = xla_extention - - -class State: - process_id: 
int = 0 - service: Optional[Any] = None - client: Optional[Any] = None - preemption_sync_manager: Optional[Any] = None - visible_devices: Optional[str] = "all" - - def initialize( - self, - coordinator_address: str, - num_processes: int, - process_id: int, - local_device_ids: Optional[Union[int, Sequence[int]]] = None, - ): - if local_device_ids is None: - local_device_ids = [process_id] - elif isinstance(local_device_ids, int): - local_device_ids = [local_device_ids] - else: - local_device_ids = list(local_device_ids) - - assert local_device_ids == [process_id], f"{local_device_ids} .vs {process_id}" - - self.visible_devices = ",".join(str(x) for x in local_device_ids) - self.process_id = process_id - - if process_id == 0: - if self.service is not None: - raise RuntimeError("distributed.initialize should only be called once.") - self.service = xe.get_distributed_runtime_service( - coordinator_address, num_processes, use_coordination_service=True - ) - - if self.client is not None: - raise RuntimeError("distributed.initialize should only be called once.") - - # Set init_timeout to 5 min to leave time for all the processes to connect - self.client = xe.get_distributed_runtime_client( - coordinator_address, - process_id, - use_coordination_service=True, - init_timeout=300, - ) - self.client.connect() - self.initialize_preemption_sync_manager() - - def shutdown(self): - if self.client: - self.client.shutdown() - self.client = None - if self.service: - self.service.shutdown() - self.service = None - if self.preemption_sync_manager: - self.preemption_sync_manager = None - - def initialize_preemption_sync_manager(self): - if self.preemption_sync_manager is not None: - raise RuntimeError( - "Preemption sync manager should only be initialized once." - ) - self.preemption_sync_manager = xe.create_preemption_sync_manager() - self.preemption_sync_manager.initialize(self.client) - - -global_state = State() - - -def initialize( - coordinator_address: str, - num_processes: int, - process_id: int, - local_device_ids: Optional[Union[int, Sequence[int]]] = None, -): - global_state.initialize( - coordinator_address, num_processes, process_id, local_device_ids - ) - atexit.register(shutdown) - - -def shutdown(): - global_state.shutdown() diff --git a/imperative/python/megengine/xla/dtype.py b/imperative/python/megengine/xla/dtype.py deleted file mode 100644 index 724470035..000000000 --- a/imperative/python/megengine/xla/dtype.py +++ /dev/null @@ -1,109 +0,0 @@ -from functools import lru_cache, partial - -import numpy as np - -from ..tensor import Parameter as MgeParameter -from ..tensor import Tensor as MgeTensor -from .lib import xla_client as xc - -_python_scalar_dtype_to_npdtypes = { - bool: np.dtype("bool"), - int: np.dtype("int64"), - float: np.dtype("float64"), - complex: np.dtype('complex128'), -} - -_python_scalar_dtypes = list(_python_scalar_dtype_to_npdtypes.keys()) - -bfloat16 = xc.bfloat16 -_bfloat16_dtype = np.dtype(bfloat16) -_float_types = [ - _bfloat16_dtype, - np.dtype("float16"), - np.dtype("float32"), - np.dtype("float64"), -] - -_numpy_scalar_types = { - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.complex64, - np.complex128, - np.bool_, - np.longlong, - np.intc, -} | set(np.dtype(dt).type for dt in _float_types) - -_np_types = {np.ndarray} | _numpy_scalar_types - -_dtype_to_32bit_dtype = { - np.dtype("int64"): np.dtype("int32"), - np.dtype("uint64"): np.dtype("uint32"), - np.dtype("float64"): np.dtype("float32"), - 
np.dtype('complex128'): np.dtype('complex64'), -} - - -def _scalar_type_to_dtype(typ, value): - dtype = canonicalize_dtype(_python_scalar_dtype_to_npdtypes[typ]) - if typ is int and value is not None: - if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max: - raise OverflowError(f"Python int {value} too large to convert to {dtype}") - return dtype - - -# do not enable x64 because megengine only support x32 -@lru_cache(maxsize=None) -def canonicalize_dtype(dtype, x64_enabled=False, allow_opaque_dtype=False): - assert allow_opaque_dtype == False and x64_enabled == False - try: - dtype_ = np.dtype(dtype) - except TypeError as e: - raise TypeError(f"dtype {dtype!r} not understood") from e - - if x64_enabled: - return dtype_ - else: - return _dtype_to_32bit_dtype.get(dtype_, dtype_) - - -def _canonicalize_ndarray_dtype(x): - return np.asarray(x, canonicalize_dtype(x.dtype)) - - -def _canonicalize_python_scalar_dtype(typ, x): - return np.asarray(x, canonicalize_dtype(_scalar_type_to_dtype(typ, x))) - - -def _canonicalize_mgetensor_dtype(x: MgeTensor): - canonicalized = canonicalize_dtype(x.dtype) - if canonicalized != x.dtype: - return x.astype(canonicalized) - return x - - -canonicalize_args_handlers = {} - -canonicalize_args_handlers.update( - (t, _canonicalize_ndarray_dtype) for t in _numpy_scalar_types -) -canonicalize_args_handlers[np.ndarray] = _canonicalize_ndarray_dtype -canonicalize_args_handlers.update( - (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _python_scalar_dtypes -) -canonicalize_args_handlers[MgeTensor] = _canonicalize_mgetensor_dtype -canonicalize_args_handlers[MgeParameter] = _canonicalize_mgetensor_dtype - - -def canonicalize_arg(x): - typ = type(x) - handler = canonicalize_args_handlers.get(typ) - if handler: - return handler(x) - raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}") diff --git a/imperative/python/megengine/xla/ir_utils.py b/imperative/python/megengine/xla/ir_utils.py deleted file mode 100644 index 385407f77..000000000 --- a/imperative/python/megengine/xla/ir_utils.py +++ /dev/null @@ -1,486 +0,0 @@ -import io -from abc import ABC, abstractmethod -from functools import partial -from typing import Any, Callable, Dict, Sequence, Tuple - -import numpy as np - -from ..core._imperative_rt import ops as mops -from ..core._imperative_rt.core2 import OpInfo, VarInfo -from . 
import dtype -from .lib.mlir import ir -from .lib.mlir.dialects import hlo - -func_id = 0 - - -def _default_func_name(): - global func_id - func_id += 1 - return f"please_realize_func_name_system_{func_id}" - - -def _is_rng_op(opr): - return isinstance( - opr, - ( - mops.Dropout, - mops.BetaRNG, - mops.GammaRNG, - mops.GaussianRNG, - mops.PermutationRNG, - mops.PoissonRNG, - mops.ShuffleRNG, - mops.UniformRNG, - ), - ) - - -class AbstractVar: - def __init__(self, _id, _shape, _dtype) -> None: - self.id = _id - self.shape = _shape - self.dtype = _dtype - self.bound_data = None - - -class Pass(ABC): - def __init__(self) -> None: - pass - - @abstractmethod - def __call__(self, tr) -> Any: - pass - - -# because xla pass key as a tensor, while mge pass key as a param, so we need to add a -# rng key tensor to the graph and set it as the input of the graph and rng op -class RngKeyAdder(Pass): - def __call__(self, tr) -> Any: - has_rng_opr = False - - for eqn in tr.eqns: - if _is_rng_op(eqn.op): - has_rng_opr = True - break - - if not has_rng_opr: - return tr - - # it should be [2, np.uint64], however, megengine donot support np.uint64/np.int64/np.uint32 - inp_rng_state_var = AbstractVar(tr.next_vid, [2, 2], np.dtype(np.int32)) - tr.add_input(inp_rng_state_var) - - new_eqns = [] - for eqn in tr.eqns: - if not _is_rng_op(eqn.op): - new_eqns.append(eqn) - continue - - oup_rng_state_var = AbstractVar(tr.next_vid, [2, 2], np.dtype(np.int32)) - tr.add_var(oup_rng_state_var) - - inputs, outputs = list(eqn.inputs), list(eqn.outputs) - inputs.append(inp_rng_state_var.id) - outputs.append(oup_rng_state_var.id) - new_eqn = OpInfo(eqn.op, inputs, outputs, eqn.id, eqn.kind) - new_eqns.append(new_eqn) - inp_rng_state_var = oup_rng_state_var - - tr.eqns = new_eqns - tr.set_var_as_oup(inp_rng_state_var) - - return tr - - -# in megengine, dropout return a bit-mask while xla hard to represent, so we let xla -# return a uint8 mask, which means the mask is 8 times larger than mge -class DropoutMaskCanonicalizer(Pass): - def __call__(self, tr) -> Any: - for eqn in tr.eqns: - if not isinstance(eqn.op, mops.Dropout): - continue - - outputs = list(eqn.outputs) - mask_var = tr.vars[outputs[1]] - new_mask_var = AbstractVar( - mask_var.id, (int(np.prod(mask_var.shape)) * 8,), mask_var.dtype - ) - tr.vars[mask_var.id] = new_mask_var - - return tr - - -class TraceResult: - def __init__(self, traced, func_name=None) -> None: - self.func_name = func_name if func_name is not None else _default_func_name() - self.traced = traced - self.eqns = [] - self.vars = {} - self.inputs = [] - self.outputs = [] - self.consts = [] - self.custom_vid = 0 - - self.effects = [] - - for var in self.traced.vars: - self.add_var(var) - self.custom_vid = max(var.id + 1, self.custom_vid) - - if var.kind == "external" and var.inp_mark: - self.inputs.append(var.id) - - if var.data_required: - self.outputs.append(var.id) - - if var.kind == "const": - self.consts.append(var.id) - - for op in self.traced.ops: - self.eqns.append(op) - - @property - def _var_inputs(self): - return [self.vars[i] for i in self.inputs] - - @property - def _var_outputs(self): - return [self.vars[i] for i in self.outputs] - - @property - def _var_consts(self): - return [self.vars[i] for i in self.consts] - - @property - def next_vid(self): - ret = self.custom_vid - self.custom_vid += 1 - return ret - - def add_var(self, var): - assert var.id not in self.vars - self.vars[var.id] = var - - def add_input(self, inp_var): - self.add_var(inp_var) - self.inputs.append(inp_var.id) - - 
def set_var_as_oup(self, oup_var): - assert oup_var.id in self.vars - self.outputs.append(oup_var.id) - - def get_var(self, idx): - assert isinstance(idx, int) - return self.vars[idx] - - def is_input(self, var): - if isinstance(var, int): - var = self.vars[var] - return var.kind == "external" - - def is_output(self, var): - if isinstance(var, int): - var = self.vars[var] - return var.data_required - - def _str_var(self, var): - def _str_shape(shp): - return "x".join([str(d) for d in shp]) - - dtype_to_str = { - "float16": "f16", - "float32": "f32", - "int32": "i32", - "int64": "i64", - "uint8": "u8", - "uint32": "u32", - "uint64": "u64", - "bool": "i1-bool", - } - - if isinstance(var, int): - var = self.vars[var] - var_dtype = None - try: - var_dtype = dtype_to_str[str(var.dtype)] - except RuntimeError: - var_dtype = "unknown" - - var_bound_data = ( - ("," + ",".join(str(var.bound_data).split())) - if var.bound_data is not None and var.bound_data.size < 5 - else "" - ) - - return f"{var.id}%:<{_str_shape(var.shape)},{var_dtype}{var_bound_data}>" - - def _str_eqn(self, eqn): - inps = ", ".join(map(self._str_var, eqn.inputs)) - oups = ", ".join(map(self._str_var, eqn.outputs)) - str_op = str(eqn.op) - if isinstance(eqn.op, mops.Reduce): - assert str(eqn.op.mode).startswith("Reduce.Mode.") - str_op = str_op + str(eqn.op.mode)[len("Reduce.Mode.") :] - ret = f"{oups} = {str_op}({inps})" - return ret - - def __str__(self) -> str: - func_inps_str = ", ".join(map(self._str_var, self.inputs)) - func_oups_str = ", ".join(map(self._str_var, self.outputs)) - func_const_str = "\n ".join(map(self._str_var, self.consts)) - ret = f"{self.func_name}({func_inps_str}) -> ({func_oups_str}) {{\n " - if len(self.consts) > 0: - ret += f"const:\n {func_const_str}\n " - ret += "\n ".join(map(self._str_eqn, self.eqns)) - ret += "\n}" - return ret - - -_dtype_to_ir_type: Dict[np.dtype, Callable[[], ir.Type]] = { - np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1), - np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8), - np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16), - np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32), - np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64), - np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8), - np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16), - np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32), - np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64), - np.dtype(dtype.bfloat16): ir.BF16Type.get, - np.dtype(np.float16): ir.F16Type.get, - np.dtype(np.float32): ir.F32Type.get, - np.dtype(np.float64): ir.F64Type.get, - np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), - np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), -} - - -def mge_dtype_to_ir_type(mge_dtype): - mge_dtype = np.dtype(mge_dtype) - assert isinstance( - mge_dtype, np.dtype - ), f"arg should be numpy dtype, but is {mge_dtype}" - ir_type_factory = _dtype_to_ir_type[mge_dtype] - return ir_type_factory() - - -def mge_varinfo_to_ir_type(mge_varinfo): - assert isinstance(mge_varinfo, (VarInfo, AbstractVar)), "args should be VarInfo" - shape = mge_varinfo.shape - return ir.RankedTensorType.get(shape, mge_dtype_to_ir_type(mge_varinfo.dtype)) - - -def mge_varinfo_to_ir_type_tuple(mge_varinfo): - return (mge_varinfo_to_ir_type(mge_varinfo),) - - -def make_ir_type_according_meta(src_shape: Tuple, src_dtype: np.dtype): - return ir.RankedTensorType.get(src_shape, mge_dtype_to_ir_type(src_dtype)) - - 
-def make_ir_type_according_meta_tuple(src_shape: Tuple, src_dtype: np.dtype): - return (make_ir_type_according_meta(src_shape, src_dtype),) - - -_constant_handlers = {} - - -def _numpy_array_constant(x: np.ndarray, canonicalize_types) -> Sequence[ir.Value]: - if canonicalize_types: - x = np.asarray(x, dtype.canonicalize_dtype(x.dtype)) - element_type = mge_dtype_to_ir_type(x.dtype) - shape = x.shape - if x.dtype == np.bool_: - nelems = x.size - x = np.packbits(x, bitorder="little") - if nelems == 1: - x = np.array(0 if x.item() == 0 else 0xFF, np.uint8) - elif x.dtype == dtype.bfloat16: - x = x.view(np.uint16) - x = np.ascontiguousarray(x) - attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) - return (hlo.ConstantOp(attr).result,) - - -def _ndarray_constant_handler( - val: np.ndarray, canonicalize_types -) -> Sequence[ir.Value]: - if np.any(np.equal(0, val.strides)) and val.size > 0: - (zero_stride_axes,) = np.where(np.equal(0, val.strides)) - (other_axes,) = np.where(np.not_equal(0, val.strides)) - collapsed_val = val[ - tuple( - 0 if ax in zero_stride_axes else slice(None) - for ax in range(val.ndim) - ) - ] - if canonicalize_types: - collapsed_val = np.asarray( - collapsed_val, dtype.canonicalize_dtype(collapsed_val.dtype) - ) - out = hlo.BroadcastInDimOp( - ir.RankedTensorType.get( - val.shape, mge_dtype_to_ir_type(collapsed_val.dtype) - ), - _numpy_array_constant(collapsed_val, canonicalize_types=False)[0], - dense_int_elements(other_axes), - ).result - return (out,) - else: - return _numpy_array_constant(val, canonicalize_types) - - -_constant_handlers[np.ndarray] = _ndarray_constant_handler -for _scalar_type in [ - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.float16, - np.float32, - np.float64, - np.complex64, - np.complex128, - np.bool_, - np.longlong, - dtype.bfloat16, -]: - _constant_handlers[_scalar_type] = _ndarray_constant_handler - - -def _python_scalar_constant_handler(dtype, val, canonicalize_dtypes): - return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes) - - -for pt, dt in dtype._python_scalar_dtype_to_npdtypes.items(): - _constant_handlers[pt] = partial(_python_scalar_constant_handler, dt) - - -def _mge_varinfo_constant_handler(val, canonicalize_dtypes): - assert isinstance(val, VarInfo) - assert val.bound_data is not None and val.kind == "const" - assert isinstance(val.bound_data, np.ndarray) - return _numpy_array_constant( - np.asarray(val.bound_data, val.dtype), canonicalize_dtypes - ) - - -_constant_handlers[VarInfo] = _mge_varinfo_constant_handler - - -def ir_constant_tuple(val: Any, canonicalize_types: bool = True) -> Sequence[ir.Value]: - for t in type(val).__mro__: - handler = _constant_handlers.get(t) - if handler: - out = handler(val, canonicalize_types) - assert all(isinstance(v, ir.Value) for v in out), (type(val), out) - return out - assert False - - -def ir_constant(val: Any, canonicalize_types: bool = True) -> Sequence[ir.Value]: - values = ir_constant_tuple(val, canonicalize_types=canonicalize_types) - assert len(values) == 1 - return values[0] - - -def token_type() -> Sequence[ir.Type]: - return [hlo.TokenType.get()] - - -def dummy_token_type_tuple() -> Sequence[ir.Type]: - return make_ir_type_according_meta_tuple((0,), np.bool_) - - -def dummy_token() -> Sequence[ir.Value]: - return ir_constant_tuple(np.zeros(0, np.bool_)) - - -def i32_attr(i): - return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i) - - -def i64_attr(i): - return 
ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i) - - -def ui64_attr(i): - return ir.IntegerAttr.get(ir.IntegerType.get_unsigned(64), i) - - -def f32_attr(i): - return ir.FloatAttr.get(ir.F32Type.get(), i) - - -def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr: - lhs_prec = str(lhs_prec) - rhs_prec = str(rhs_prec) - - assert lhs_prec == "float32" - assert rhs_prec == "float32" - - dtype_to_precision = { - "float32": "DEFAULT", - } - precision = (dtype_to_precision[lhs_prec], dtype_to_precision[rhs_prec]) - return ir.ArrayAttr.get([hlo.PrecisionAttr.get(p) for p in precision]) - - -def dense_int_elements(xs) -> ir.DenseIntElementsAttr: - return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) - - -def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr: - a = np.packbits(np.array(xs, np.bool_), bitorder="little") - if len(xs) == 1: - a = np.array(0 if a.item() == 0 else 0xFF, np.uint8) - return ir.DenseElementsAttr.get( - a, type=ir.IntegerType.get_signless(1), shape=[len(xs)] - ) - - -def get_irnode_shape(irnode): - if isinstance(irnode, (list, tuple, ir.OpResultList)): - assert len(irnode) == 1 - irnode = irnode[0] - assert isinstance(irnode, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult)) - if not isinstance(irnode, ir.RankedTensorType): - irnode = ir.RankedTensorType(irnode.type) - return tuple(irnode.shape) - - -def get_irnode_dtype(irnode): - if isinstance(irnode, (list, tuple, ir.OpResultList)): - assert len(irnode) == 1 - irnode = irnode[0] - assert isinstance( - irnode, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult) - ), type(irnode) - if not isinstance(irnode, ir.RankedTensorType): - irnode = ir.RankedTensorType(irnode.type) - etype = irnode.element_type - - for k, v in _dtype_to_ir_type.items(): - if etype == v(): - return k - - assert False, f"unknown irnode {irnode}" - - -def module_to_string(module: ir.Module) -> str: - output = io.StringIO() - module.operation.print( - file=output, enable_debug_info=True, print_generic_op_form=False - ) - return output.getvalue() - - -def module_to_bytecode(module: ir.Module) -> bytes: - output = io.BytesIO() - module.operation.write_bytecode(file=output) - return output.getvalue() diff --git a/imperative/python/megengine/xla/lib/__init__.py b/imperative/python/megengine/xla/lib/__init__.py deleted file mode 100644 index b1caaf615..000000000 --- a/imperative/python/megengine/xla/lib/__init__.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import platform -import re -import warnings -from typing import Optional, Tuple - -import jaxlib.cpu_feature_guard as cpu_feature_guard -import jaxlib.ducc_fft as ducc_fft -import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error -import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error -import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error -import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error -import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error -import jaxlib.lapack as lapack -import jaxlib.xla_client as xla_client - -try: - import jaxlib as jaxlib -except ModuleNotFoundError as err: - raise ModuleNotFoundError( - "megengine with xla requires jaxlib to be installed." - ) from err - -# some version check code -""" -import jax.version -from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str -try: - import jaxlib.version -except Exception as err: - # jaxlib is too old to have version number. - msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.' 
- raise ImportError(msg) from err - - -# Checks the jaxlib version before importing anything else from jaxlib. -# Returns the jaxlib version string. -def check_jaxlib_version(jax_version: str, jaxlib_version: str, - minimum_jaxlib_version: str): - # Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version. - # PEP440 allows a number of non-numeric suffixes, which we allow also. - # We currently do not allow an epoch. - version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*") - def _parse_version(v: str) -> Tuple[int, ...]: - m = version_regex.match(v) - if m is None: - raise ValueError(f"Unable to parse jaxlib version '{v}'") - return tuple(int(x) for x in m.group(0).split('.')) - - _jax_version = _parse_version(jax_version) - _minimum_jaxlib_version = _parse_version(minimum_jaxlib_version) - _jaxlib_version = _parse_version(jaxlib_version) - - if _jaxlib_version < _minimum_jaxlib_version: - msg = (f'jaxlib is version {jaxlib_version}, but this version ' - f'of jax requires version >= {minimum_jaxlib_version}.') - raise RuntimeError(msg) - - if _jaxlib_version > _jax_version: - msg = (f'jaxlib version {jaxlib_version} is newer than and ' - f'incompatible with jax version {jax_version}. Please ' - 'update your jax and/or jaxlib packages.') - raise RuntimeError(msg) - - return _jaxlib_version - -version_str = jaxlib.version.__version__ -version = check_jaxlib_version( - jax_version=jax.version.__version__, - jaxlib_version=jaxlib.version.__version__, - minimum_jaxlib_version=jax.version._minimum_jaxlib_version) -""" - -# Before importing any C compiled modules from jaxlib, first import the CPU -# feature guard module to verify that jaxlib was compiled in a way that only -# uses instructions that are present on this machine. -cpu_feature_guard.check_cpu_features() - - -xla_extension = xla_client._xla -pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit -pmap_lib = xla_client._xla.pmap_lib - - -# Jaxlib code is split between the Jax and the Tensorflow repositories. -# Only for the internal usage of the JAX developers, we expose a version -# number that can be used to perform changes without breaking the main -# branch on the Jax github. -xla_extension_version = getattr(xla_client, "_version", 0) - - -# Version number for MLIR:Python APIs, provided by jaxlib. -mlir_api_version = xla_client.mlir_api_version - -try: - from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error -except: - tpu_driver_client = None # type: ignore - - -# TODO: check if we need the same for rocm. 
-cuda_path: Optional[str] -cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda") -if not os.path.isdir(cuda_path): - cuda_path = None - -transfer_guard_lib = xla_client._xla.transfer_guard_lib diff --git a/imperative/python/megengine/xla/lib/config.py b/imperative/python/megengine/xla/lib/config.py deleted file mode 100644 index cd8804cb4..000000000 --- a/imperative/python/megengine/xla/lib/config.py +++ /dev/null @@ -1,1330 +0,0 @@ -import contextlib -import functools -import itertools -import logging -import os -import sys -import threading -from typing import Any, Callable, Hashable, Iterator, List, NamedTuple, Optional - -import jaxlib.xla_client as xla_client - -jax_jit = xla_client._xla.jax_jit -transfer_guard_lib = xla_client._xla.transfer_guard_lib - - -logger = logging.getLogger(__name__) - - -def bool_env(varname: str, default: bool) -> bool: - val = os.getenv(varname, str(default)) - val = val.lower() - if val in ("y", "yes", "t", "true", "on", "1"): - return True - elif val in ("n", "no", "f", "false", "off", "0"): - return False - else: - raise ValueError(f"invalid truth value {val!r} for environment {varname!r}") - - -def int_env(varname: str, default: int) -> int: - return int(os.getenv(varname, str(default))) - - -class Config: - _HAS_DYNAMIC_ATTRIBUTES = True - - def __init__(self): - self.values = {} - self.meta = {} - self.FLAGS = NameSpace(self.read, self.update) - self.use_absl = False - self._contextmanager_flags = set() - self._update_hooks = {} - - def update(self, name, val): - self.check_exists(name) - if name not in self.values: - raise Exception(f"Unrecognized config option: {name}") - self.values[name] = val - - hook = self._update_hooks.get(name, None) - if hook: - hook(val) - - def read(self, name): - if name in self._contextmanager_flags: - raise AttributeError( - "For flags with a corresponding contextmanager, read their value " - f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`." 
- ) - return self._read(name) - - def _read(self, name): - try: - return self.values[name] - except KeyError: - raise AttributeError(f"Unrecognized config option: {name}") - - def add_option( - self, name, default, opt_type, meta_args, meta_kwargs, update_hook=None - ): - if name in self.values: - raise Exception(f"Config option {name} already defined") - self.values[name] = default - self.meta[name] = (opt_type, meta_args, meta_kwargs) - if update_hook: - self._update_hooks[name] = update_hook - update_hook(default) - - def check_exists(self, name): - if name not in self.values: - raise AttributeError(f"Unrecognized config option: {name}") - - def DEFINE_bool(self, name, default, *args, **kwargs): - update_hook = kwargs.pop("update_hook", None) - self.add_option(name, default, bool, args, kwargs, update_hook=update_hook) - - def DEFINE_integer(self, name, default, *args, **kwargs): - update_hook = kwargs.pop("update_hook", None) - self.add_option(name, default, int, args, kwargs, update_hook=update_hook) - - def DEFINE_float(self, name, default, *args, **kwargs): - update_hook = kwargs.pop("update_hook", None) - self.add_option(name, default, float, args, kwargs, update_hook=update_hook) - - def DEFINE_string(self, name, default, *args, **kwargs): - update_hook = kwargs.pop("update_hook", None) - self.add_option(name, default, str, args, kwargs, update_hook=update_hook) - - def DEFINE_enum(self, name, default, *args, **kwargs): - update_hook = kwargs.pop("update_hook", None) - self.add_option(name, default, "enum", args, kwargs, update_hook=update_hook) - - def config_with_absl(self): - # Run this before calling `app.run(main)` etc - import absl.flags as absl_FLAGS # noqa: F401 - from absl import app, flags as absl_flags - - self.use_absl = True - self.absl_flags = absl_flags - absl_defs = { - bool: absl_flags.DEFINE_bool, - int: absl_flags.DEFINE_integer, - float: absl_flags.DEFINE_float, - str: absl_flags.DEFINE_string, - "enum": absl_flags.DEFINE_enum, - } - - for name, val in self.values.items(): - flag_type, meta_args, meta_kwargs = self.meta[name] - absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) - app.call_after_init(lambda: self.complete_absl_config(absl_flags)) - - def complete_absl_config(self, absl_flags): - for name, _ in self.values.items(): - flag = absl_flags.FLAGS[name] - if flag.present: - self.update(name, flag.value) - - def parse_flags_with_absl(self): - global already_configured_with_absl - if not already_configured_with_absl: - # Extract just the --jax... flags (before the first --) from argv. In some - # environments (e.g. ipython/colab) argv might be a mess of things - # parseable by absl and other junk. - jax_argv = itertools.takewhile(lambda a: a != "--", sys.argv) - jax_argv = ["", *(a for a in jax_argv if a.startswith("--jax"))] - - import absl.flags - - self.config_with_absl() - absl.flags.FLAGS(jax_argv, known_only=True) - self.complete_absl_config(absl.flags) - already_configured_with_absl = True - - def define_bool_state( - self, - name: str, - default: bool, - help: str, - *, - update_global_hook: Optional[Callable[[bool], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None, - upgrade: bool = False, - extra_description: str = "" - ): - """Set up thread-local state and return a contextmanager for managing it. - - This function is a convenience wrapper. It defines a flag, environment - variable, and corresponding thread-local state, which can be managed via the - contextmanager it returns. 
- - The thread-local state value can be read via the ``config.`` - attribute, where ``config`` is the singleton ``Config`` instance. - - Args: - name: string, converted to lowercase to define the name of the config - option (and absl flag). It is converted to uppercase to define the - corresponding shell environment variable. - default: boolean, a default value for the option. - help: string, used to populate the flag help information as well as the - docstring of the returned context manager. - update_global_hook: a optional callback that is called with the updated - value of the global state when it is altered or set initially. - update_thread_local_hook: a optional callback that is called with the - updated value of the thread-local state when it is altered or set - initially. - upgrade: optional indicator that this flag controls a canonical feature - upgrade, so that it is `True` for the incoming functionality, `False` - for the outgoing functionality to be deprecated. - extra_description: string, optional: extra information to add to the - summary description. - - Returns: - A contextmanager to control the thread-local state value. - - Example: - - enable_foo = config.define_bool_state( - name='jax_enable_foo', - default=False, - help='Enable foo.') - - # Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo - # command-line flag can be used to control the process-level value of - # the configuration option, in addition to using e.g. - # ``config.update("jax_enable_foo", True)`` directly. We can also use a - # context manager: - - with enable_foo(True): - ... - - The value of the thread-local state or flag can be accessed via - ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is - an error. - - """ - name = name.lower() - self.DEFINE_bool( - name, bool_env(name.upper(), default), help, update_hook=update_global_hook - ) - self._contextmanager_flags.add(name) - - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - - setattr(Config, name, property(get_state)) - - return _StateContextManager( - name, - help, - update_thread_local_hook, - extra_description=extra_description, - default_value=True, - ) - - def define_enum_state( - self, - name: str, - enum_values: List[str], - default: Optional[str], - help: str, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, - ): - """Set up thread-local state and return a contextmanager for managing it. - Args: - name: string, converted to lowercase to define the name of the config - option (and absl flag). It is converted to uppercase to define the - corresponding shell environment variable. - enum_values: list of strings representing the possible values for the - option. - default: optional string, default value. - help: string, used to populate the flag help information as well as the - docstring of the returned context manager. - Returns: - A contextmanager to control the thread-local state value. - See docstring for ``define_bool_state``. 
- """ - name = name.lower() - default = os.getenv(name.upper(), default) - if default is not None and default not in enum_values: - raise ValueError(f'Invalid value "{default}" for JAX flag {name}') - self.DEFINE_enum( - name, - default, - enum_values=enum_values, - help=help, - update_hook=update_global_hook, - ) - self._contextmanager_flags.add(name) - - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - - setattr(Config, name, property(get_state)) - - def validate(new_val): - if new_val is not None and ( - type(new_val) is not str or new_val not in enum_values - ): - raise ValueError( - f"new enum value must be None or in {enum_values}, " - f"got {new_val} of type {type(new_val)}." - ) - - return _StateContextManager(name, help, update_thread_local_hook, validate) - - def define_int_state( - self, - name: str, - default: Optional[int], - help: str, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, - ): - """Set up thread-local state and return a contextmanager for managing it. - Args: - name: string, converted to lowercase to define the name of the config - option (and absl flag). It is converted to uppercase to define the - corresponding shell environment variable. - enum_values: list of strings representing the possible values for the - option. - default: optional int, default value. - help: string, used to populate the flag help information as well as the - docstring of the returned context manager. - Returns: - A contextmanager to control the thread-local state value. - See docstring for ``define_bool_state``. - """ - name = name.lower() - default_env = os.getenv(name.upper(), default) - if default_env is not None: - try: - default = int(default_env) - except ValueError: - raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}') - self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook) - self._contextmanager_flags.add(name) - - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - - setattr(Config, name, property(get_state)) - - def validate(new_val): - if new_val is not None and not isinstance(new_val, int): - raise ValueError( - f"new int config value must be None or of type int, " - f"got {new_val} of type {type(new_val)}" - ) - - return _StateContextManager(name, help, update_thread_local_hook, validate) - - def define_float_state( - self, - name: str, - default: Optional[float], - help: str, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, - ): - """Set up thread-local state and return a contextmanager for managing it. - Args: - name: string, converted to lowercase to define the name of the config - option (and absl flag). It is converted to uppercase to define the - corresponding shell environment variable. - enum_values: list of strings representing the possible values for the - option. - default: optional float, default value. - help: string, used to populate the flag help information as well as the - docstring of the returned context manager. - Returns: - A contextmanager to control the thread-local state value. - See docstring for ``define_bool_state``. 
- """ - name = name.lower() - default_env = os.getenv(name.upper(), default) - if default_env is not None: - try: - default = float(default_env) - except ValueError: - raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}') - self.DEFINE_float(name, default, help=help, update_hook=update_global_hook) - self._contextmanager_flags.add(name) - - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - - setattr(Config, name, property(get_state)) - - def validate(new_val): - if new_val is not None and not isinstance(new_val, (float, int)): - raise ValueError( - f"new float config value must be None or of type float, " - f"got {new_val} of type {type(new_val)}" - ) - - return _StateContextManager(name, help, update_thread_local_hook, validate) - - def define_string_state( - self, - name: str, - default: Optional[str], - help: str, - update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, - ): - """Set up thread-local state and return a contextmanager for managing it. - - See docstring for ``define_bool_state``. - - Args: - name: string, converted to lowercase to define the name of the config - option (and absl flag). It is converted to uppercase to define the - corresponding shell environment variable. - default: string, a default value for the option. - help: string, used to populate the flag help information as well as the - docstring of the returned context manager. - update_global_hook: an optional callback that is called with the updated - value of the global state when it is altered or set initially. - update_thread_local_hook: an optional callback that is called with the - updated value of the thread-local state when it is altered or set - initially. - - Returns: - A contextmanager to control the thread-local state value. - """ - - def validate(new_val): - if new_val is not None and not isinstance(new_val, str): - raise ValueError( - f"new string config value must be None or of type str," - f" got {new_val} of type {type(new_val)}." - ) - - return self.define_string_or_object_state( - name, default, help, update_global_hook, update_thread_local_hook, validate - ) - - def define_string_or_object_state( - self, - name: str, - default: Any, - help: str, - update_global_hook: Optional[Callable[[Any], None]] = None, - update_thread_local_hook: Optional[Callable[[Any], None]] = None, - validate_new_val_hook: Optional[Callable[[Any], None]] = None, - ): - """Set up thread-local state and return a contextmanager for managing it. - - Similar to ``define_string_state``, except the context manager will accept - any object, not just a string. Any value passed via commandline flag or - environment variable will be treated as a string. - - Args: - name: string, converted to lowercase to define the name of the config - option (and absl flag). It is converted to uppercase to define the - corresponding shell environment variable. - default: string, a default value for the option. - help: string, used to populate the flag help information as well as the - docstring of the returned context manager. - update_global_hook: an optional callback that is called with the updated - value of the global state when it is altered or set initially. - update_thread_local_hook: an optional callback that is called with the - updated value of the thread-local state when it is altered or set - initially. 
- validate_new_val_hook: an optional callback that is called with the new - value on any update, and should raise an error if the new value is - invalid. - - Returns: - A contextmanager to control the thread-local state value. - """ - name = name.lower() - default = os.getenv(name.upper(), default) - self.DEFINE_string(name, default, help=help, update_hook=update_global_hook) - self._contextmanager_flags.add(name) - - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - - setattr(Config, name, property(get_state)) - - return _StateContextManager( - name, help, update_thread_local_hook, validate_new_val_hook - ) - - def _trace_context(self): - """Returns a tuple of configuration values that affect tracing. - - These values are included in the cache key for linear_util.cache. - - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately.""" - tls = jax_jit.thread_local_state() - axis_env_state = () - context = tls.extra_jit_context - if context and context.axis_env_state is not None: - axis_env_state = context.axis_env_state - return ( - axis_env_state, - self.x64_enabled, - self.jax_numpy_rank_promotion, - self.jax_default_matmul_precision, - self.jax_dynamic_shapes, - self.jax_numpy_dtype_promotion, - self.jax_default_device, - self.jax_array, - self.jax_threefry_partitionable, - ) - - -class NoDefault: - pass - - -no_default = NoDefault() - - -class _StateContextManager: - def __init__( - self, - name, - help, - update_thread_local_hook, - validate_new_val_hook: Optional[Callable[[Any], None]] = None, - extra_description: str = "", - default_value: Any = no_default, - ): - self._name = name - self.__name__ = name[4:] if name.startswith("jax_") else name - self.__doc__ = ( - f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}" - ) - self._update_thread_local_hook = update_thread_local_hook - self._validate_new_val_hook = validate_new_val_hook - self._default_value = default_value - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_value is not no_default: - new_val = self._default_value # default_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError( - f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option." - ) - if self._validate_new_val_hook: - self._validate_new_val_hook(new_val) - prev_val = getattr(_thread_local_state, self._name, unset) - setattr(_thread_local_state, self._name, new_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - if prev_val is unset: - delattr(_thread_local_state, self._name) - if self._update_thread_local_hook: - self._update_thread_local_hook(None) - else: - setattr(_thread_local_state, self._name, prev_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(prev_val) - - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. 
- - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - config._update_hooks[self._name] = update_global_hook - update_global_hook(config._read(self._name)) - - -_thread_local_state = threading.local() - - -class _Unset: - pass - - -unset = _Unset() - - -class NameSpace: - def __init__(self, getter, setter): - # must use super because we override this class's __setattr__, see - # https://docs.python.org/3/reference/datamodel.html#object.__setattr__ - super().__setattr__("_getter", getter) - super().__setattr__("_setter", setter) - - def __getattr__(self, name): - return self._getter(name) - - def __setattr__(self, name, val): - self._setter(name, val) - - -config = Config() -flags = config -FLAGS = flags.FLAGS - -already_configured_with_absl = False - - -# The C++ JIT maintains its own copy of several configuration items as -# a global/thread-local state. These methods allow updates to part of the -# state when a configuration value changes. -class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: Optional[str] = None - numpy_dtype_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None - dynamic_shapes: bool = False - threefry_partitionable: bool = False - - -def _update_global_jit_state(**kw): - gs = jax_jit.global_state() - context = gs.extra_jit_context or _GlobalExtraJitContext() - gs.extra_jit_context = context._replace(**kw) - - -class _ThreadLocalExtraJitContext(NamedTuple): - """"A namedtuple containing states to add to the cache key. - - Just in time compilation (for jit, pmap, etc) behavior is configurable through - global and thread-local options, used in the cache key. - - The initialization, which uses both config.py and core.py is done using - `_update_thread_local_jit_state` in core.py to prevent circular imports. - """ - - dynamic_trace_state: Optional[Any] = None - axis_env_state: Hashable = () - numpy_rank_promotion: Optional[str] = None - numpy_dtype_promotion: Optional[str] = None - default_matmul_precision: Optional[Any] = None - dynamic_shapes: bool = False - - -class _ThreadLocalStateCache(threading.local): - """"A thread local cache for _ThreadLocalExtraJitContext - - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus - incurring dispatch overhead for comparing this python object during jit calls. - We want to duduplicate the objects that have the same hash/equality to also - have the same object ID, since the equality check is much faster if the object - IDs match. - """ - - def __init__(self): - self.canonicalize = functools.lru_cache(128)(lambda x: x) - - -_thread_local_state_cache = _ThreadLocalStateCache() - - -def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() - # After xla_client._version >= 70, the thread_local object will necessarily - # be initialized when accessed. 
The following line can be removed when the - # minimum jaxlib version is past version 70 - context = tls.extra_jit_context or _ThreadLocalExtraJitContext() - tmp = context._replace(**kw) - tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) - - -flags.DEFINE_integer( - "jax_tracer_error_num_traceback_frames", - int_env("JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5), - help="Set the number of stack frames in JAX tracer error messages.", -) - -flags.DEFINE_bool( - "jax_pprint_use_color", - bool_env("JAX_PPRINT_USE_COLOR", True), - help="Enable jaxpr pretty-printing with colorful syntax highlighting.", -) - -flags.DEFINE_bool( - "jax_host_callback_inline", - bool_env("JAX_HOST_CALLBACK_INLINE", False), - help="Inline the host_callback, if not in a staged context.", -) -flags.DEFINE_integer( - "jax_host_callback_max_queue_byte_size", - int_env("JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)), - help=( - "The size in bytes of the buffer used to hold outfeeds from each " - "device. When this capacity is reached consuming outfeeds from the " - "device is paused, thus potentially pausing the device computation, " - "until the Python callback consume more outfeeds." - ), - lower_bound=int(16 * 1e6), -) -flags.DEFINE_bool( - "jax_host_callback_outfeed", - bool_env("JAX_HOST_CALLBACK_OUTFEED", False), - help=( - "Use outfeed implementation for host_callback, even on CPU and GPU. " - "If false, use the CustomCall implementation. " - "Has no effect on TPU, since only the outfeed mechanism is implemented." - ), -) -flags.DEFINE_bool( - "jax_host_callback_ad_transforms", - bool_env("JAX_HOST_CALLBACK_AD_TRANSFORMS", False), - help=( - "Enable support for jvp/vjp for the host_callback primitives. Default is " - "False, which means that host_callback operates only on primals. " - "The flag exists only temporarily, for backward compatibility." - ), -) - -# # TODO: remove flag when XLA:CPU is improved. -# jax2tf_associative_scan_reductions = config.define_bool_state( -# name="jax2tf_associative_scan_reductions", -# default=False, -# help=( -# "JAX has two separate lowering rules for the cumulative reduction " -# "primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses " -# "a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. " -# "The latter has a slow implementation on CPUs and GPUs. " -# "By default, jax2tf uses the TPU lowering. Set this flag to True to " -# "use the associative scan lowering usage, and only if it makes a difference " -# "for your application. " -# "See the jax2tf README.md for more details." -# ), -# ) - -# jax2tf_default_experimental_native_lowering = config.define_bool_state( -# name="jax2tf_default_experimental_native_lowering", -# default=bool_env("JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False), -# help=( -# "DO NOT USE, highly experimental. Sets the default value of the " -# "experimental_native_lowering parameter to jax2tf.convert." -# ), -# ) - -# jax2tf_use_stablehlo = config.define_bool_state( -# name="jax2tf_use_stablehlo", -# default=bool_env("JAX2TF_USE_STABLEHLO", True), -# help=( -# "DO NOT USE, highly experimental. Use in conjunction with jax2tf " -# "experimental_native_lowering, to use StableHLO instead of MHLO as " -# "the serialization format." -# ), -# ) - -jax_platforms = config.define_string_state( - name="jax_platforms", - default=None, - help=( - "Comma-separated list of platform names specifying which platforms jax " - "should initialize. 
If any of the platforms in this list are not successfully " - "initialized, an exception will be raised and the program will be aborted. " - "The first platform in the list will be the default platform. " - "For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends " - "will be initialized, and the CPU backend will be used unless otherwise " - "specified. If TPU initialization fails, it will raise an exception. " - "By default, jax will try to initialize all available " - "platforms and will default to GPU or TPU if available, and fallback to CPU " - "otherwise." - ), -) - -# enable_checks = config.define_bool_state( -# name="jax_enable_checks", -# default=False, -# help="Turn on invariant checking for JAX internals. Makes things slower.", -# ) - -# check_tracer_leaks = config.define_bool_state( -# name="jax_check_tracer_leaks", -# default=False, -# help=( -# "Turn on checking for leaked tracers as soon as a trace completes. " -# "Enabling leak checking may have performance impacts: some caching " -# "is disabled, and other overheads may be added. Additionally, be aware " -# "that some Python debuggers can cause false positives, so it is recommended " -# "to disable any debuggers while leak checking is enabled." -# ), -# ) -# checking_leaks = functools.partial(check_tracer_leaks, True) - -# debug_nans = config.define_bool_state( -# name="jax_debug_nans", -# default=False, -# help=( -# "Add nan checks to every operation. When a nan is detected on the " -# "output of a jit-compiled computation, call into the un-compiled " -# "version in an attempt to more precisely identify the operation " -# "which produced the nan." -# ), -# ) - -# debug_infs = config.define_bool_state( -# name="jax_debug_infs", -# default=False, -# help=( -# "Add inf checks to every operation. When an inf is detected on the " -# "output of a jit-compiled computation, call into the un-compiled " -# "version in an attempt to more precisely identify the operation " -# "which produced the inf." -# ), -# ) - -# log_compiles = config.define_bool_state( -# name="jax_log_compiles", -# default=False, -# help=( -# "Log a message each time every time `jit` or `pmap` compiles an XLA " -# "computation. Logging is performed with `logging`. When this " -# "option is set, the log level is WARNING; otherwise the level is " -# "DEBUG." -# ), -# ) - -# parallel_functions_output_gda = config.define_bool_state( -# name="jax_parallel_functions_output_gda", -# default=False, -# help="If True, pjit will output GDAs.", -# ) - - -# def _update_jax_array_global(val): -# if val is not None and not val: -# raise ValueError( -# "jax.config.jax_array cannot be disabled after jax 0.4.7 release." -# " Please downgrade to jax and jaxlib 0.4.6 if you want to disable" -# " jax.config.jax_array." -# ) - - -# def _update_jax_array_thread_local(val): -# if val is not None and not val: -# raise ValueError( -# "jax.config.jax_array cannot be disabled after jax 0.4.7 release." -# " Please downgrade to jax and jaxlib 0.4.6 if you want to disable" -# " jax.config.jax_array." -# ) - - -# jax_array = config.define_bool_state( -# name="jax_array", -# default=True, -# upgrade=True, -# update_global_hook=_update_jax_array_global, -# update_thread_local_hook=_update_jax_array_thread_local, -# help=( -# "If True, new pjit behavior will be enabled and `jax.Array` will be " "used." 
-# ), -# ) - - -# jit_pjit_api_merge = config.define_bool_state( -# name="jax_jit_pjit_api_merge", -# default=False, -# upgrade=True, -# help=("If True, jit and pjit API will be merged."), -# ) - - -# spmd_mode = config.define_enum_state( -# name="jax_spmd_mode", -# enum_values=["allow_all", "allow_jit", "allow_pjit"], -# # TODO: Default to `allow_jit` when the training wheels come -# # off. -# default="allow_pjit", -# help=( -# "Decides whether Math on `jax.Array`'s that are not fully addressable " -# "(i.e. spans across multiple processes) is allowed. The options are: " -# "* allow_pjit: Default, only `pjit` computations are allowed to " -# " execute on non-fully addressable `jax.Array`s\n" -# "* allow_jit: `pjit` and `jax.jit` computations are allowed to " -# " execute on non-fully addressable `jax.Array`s\n" -# "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " -# " `jax.jit` and all other operations are allowed to " -# " execute on non-fully addresable `jax.Array`s." -# ), -# ) - - -# distributed_debug = config.define_bool_state( -# name="jax_distributed_debug", -# default=False, -# help=( -# "Enable logging useful for debugging multi-process distributed " -# "computations. Logging is performed with `logging` at WARNING " -# "level." -# ), -# ) - - -# enable_custom_prng = config.define_bool_state( -# name="jax_enable_custom_prng", -# default=False, -# upgrade=True, -# help=( -# "Enables an internal upgrade that allows one to define custom " -# "pseudo-random number generator implementations." -# ), -# ) - -# default_prng_impl = config.define_enum_state( -# name="jax_default_prng_impl", -# enum_values=["threefry2x32", "rbg", "unsafe_rbg"], -# default="threefry2x32", -# help=( -# "Select the default PRNG implementation, used when one is not " -# "explicitly provided at seeding time." -# ), -# ) - -# threefry_partitionable = config.define_bool_state( -# name="jax_threefry_partitionable", -# default=False, -# upgrade=True, -# help=( -# "Enables internal threefry PRNG implementation changes that " -# "render it automatically partitionable in some cases. For use " -# "with pjit and/or jax_array=True. Without this flag, using the " -# "standard jax.random pseudo-random number generation may result " -# "in extraneous communication and/or redundant distributed " -# "computation. With this flag, the communication overheads disappear " -# "in some cases." -# ), -# update_global_hook=lambda val: _update_global_jit_state(threefry_partitionable=val), -# update_thread_local_hook=lambda val: update_thread_local_jit_state( -# threefry_partitionable=val -# ), -# ) - -# enable_custom_vjp_by_custom_transpose = config.define_bool_state( -# name="jax_enable_custom_vjp_by_custom_transpose", -# default=False, -# upgrade=True, -# help=( -# "Enables an internal upgrade that implements `jax.custom_vjp` by " -# "reduction to `jax.custom_jvp` and `jax.custom_transpose`." -# ), -# ) - -# raise_persistent_cache_errors = config.define_bool_state( -# name="jax_raise_persistent_cache_errors", -# default=False, -# help=( -# "If true, exceptions raised when reading or writing to the " -# "persistent compilation cache will be allowed through, halting " -# "program execution if not manually caught. If false, exceptions are " -# "caught and raised as warnings, allowing program execution to " -# "continue. Defaults to false so cache bugs or intermittent issues " -# "are non-fatal." 
-# ), -# ) - -# persistent_cache_min_compile_time_secs = config.define_float_state( -# name="jax_persistent_cache_min_compile_time_secs", -# default=1, -# help=( -# "The minimum compile time of a computation to be written to the " -# "persistent compilation cache. This threshold can be raised to " -# "decrease the number of entries written to the cache." -# ), -# ) - -# hlo_source_file_canonicalization_regex = config.define_string_state( -# name="jax_hlo_source_file_canonicalization_regex", -# default=None, -# help=( -# "Used to canonicalize the source_path metadata of HLO instructions " -# "by removing the given regex. If set, re.sub() is called on each " -# "source_file with the given regex, and all matches are removed. " -# "This can be used to avoid spurious cache misses when using the " -# "persistent compilation cache, which includes HLO metadata in the " -# "cache key." -# ), -# ) - -# config.define_enum_state( -# name="jax_default_dtype_bits", -# enum_values=["32", "64"], -# default="64", -# help=( -# "Specify bit width of default dtypes, either 32-bit or 64-bit. " -# "This is a temporary flag that will be used during the process " -# "of deprecating the ``jax_enable_x64`` flag." -# ), -# ) - -# numpy_dtype_promotion = config.define_enum_state( -# name="jax_numpy_dtype_promotion", -# enum_values=["standard", "strict"], -# default="standard", -# help=( -# "Specify the rules used for implicit type promotion in operations " -# 'between arrays. Options are "standard" or "strict"; in strict-mode, ' -# "binary operations between arrays of differing strongly-specified " -# "dtypes will result in an error." -# ), -# update_global_hook=lambda val: _update_global_jit_state(numpy_dtype_promotion=val), -# update_thread_local_hook=lambda val: update_thread_local_jit_state( -# numpy_dtype_promotion=val -# ), -# ) - - -def _update_x64_global(val): - jax_jit.global_state().enable_x64 = val - - -def _update_x64_thread_local(val): - jax_jit.thread_local_state().enable_x64 = val - - -enable_x64 = config.define_bool_state( - name="jax_enable_x64", - default=False, - help="Enable 64-bit types to be used", - update_global_hook=_update_x64_global, - update_thread_local_hook=_update_x64_thread_local, -) - -# TODO: remove after fixing users of FLAGS.x64_enabled. -config._contextmanager_flags.remove("jax_enable_x64") - -Config.x64_enabled = Config.jax_enable_x64 # type: ignore - - -def _update_default_device_global(val): - jax_jit.global_state().default_device = val - - -def _update_default_device_thread_local(val): - jax_jit.thread_local_state().default_device = val - - -def _validate_default_device(val): - if val is not None and not isinstance(val, xla_client.Device): - # TODO: this is a workaround for non-PJRT Device types. Remove when - # all JAX backends use a single C++ device interface. - if "Device" in str(type(val)): - logger.info( - "Allowing non-`xla_client.Device` default device: %s, type: %s", - repr(val), - type(val), - ) - return - raise ValueError( - "jax.default_device must be passed a Device object (e.g. " - f"`jax.devices('cpu')[0]`), got: {repr(val)}" - ) - - -# TODO: default_device only accepts devices for now. Make it work with -# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). -default_device = config.define_string_or_object_state( - name="jax_default_device", - default=None, - help=( - "Configure the default device for JAX operations. Set to a Device " - 'object (e.g. 
``jax.devices("cpu")[0]``) to use that Device as the ' - "default device for JAX operations and jit'd function calls (there is " - "no effect on multi-device computations, e.g. pmapped function calls). " - "Set to None to use the system default device. See " - ":ref:`faq-data-placement` for more information on device placement." - ), - update_global_hook=_update_default_device_global, - update_thread_local_hook=_update_default_device_thread_local, - validate_new_val_hook=_validate_default_device, -) - - -def _update_disable_jit_global(val): - jax_jit.global_state().disable_jit = val - - -def _update_disable_jit_thread_local(val): - jax_jit.thread_local_state().disable_jit = val - - -disable_jit = config.define_bool_state( - name="jax_disable_jit", - default=False, - help=("Disable JIT compilation and just call original Python."), - update_global_hook=_update_disable_jit_global, - update_thread_local_hook=_update_disable_jit_thread_local, -) - - -numpy_rank_promotion = config.define_enum_state( - name="jax_numpy_rank_promotion", - enum_values=["allow", "warn", "raise"], - default="allow", - help=( - "Control NumPy-style automatic rank promotion broadcasting " - '("allow", "warn", or "raise").' - ), - update_global_hook=lambda val: _update_global_jit_state(numpy_rank_promotion=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - numpy_rank_promotion=val - ), -) - -default_matmul_precision = config.define_enum_state( - name="jax_default_matmul_precision", - enum_values=["bfloat16", "tensorfloat32", "float32"], - default=None, - help=( - "Control the default matmul and conv precision for 32bit inputs.\n\n" - "Some platforms, like TPU, offer configurable precision levels for " - "matrix multiplication and convolution computations, trading off " - "accuracy for speed. The precision can be controlled for each " - "operation; for example, see the :func:`jax.lax.conv_general_dilated` " - "and :func:`jax.lax.dot` docstrings. But it can be useful to control " - "the default behavior obtained when an operation is not given a " - "specific precision.\n\n" - "This option can be used to control the default precision " - "level for computations involved in matrix multiplication and " - "convolution on 32bit inputs. The levels roughly describe the " - "precision at which scalar products are computed. The 'bfloat16' " - "option is the fastest and least precise; 'float32' is similar to " - "full float32 precision; 'tensorfloat32' is intermediate.\n\n" - ), - update_global_hook=lambda val: _update_global_jit_state( - default_matmul_precision=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - default_matmul_precision=val - ), -) - -traceback_filtering = config.define_enum_state( - name="jax_traceback_filtering", - enum_values=["off", "tracebackhide", "remove_frames", "auto"], - default="auto", - help="Controls how JAX filters internal frames out of tracebacks.\n\n" - "Valid values are:\n" - ' * "off": disables traceback filtering.\n' - ' * "auto": use "tracebackhide" if running under a sufficiently ' - 'new IPython, or "remove_frames" otherwise.\n' - ' * "tracebackhide": adds "__tracebackhide__" annotations to ' - " hidden stack frames, which some traceback printers support.\n" - ' * "remove_frames": removes hidden frames from tracebacks, and adds ' - " the unfiltered traceback as a __cause__ of the exception.\n", -) - -# This flag is for internal use. -# TODO: Removes once we always enable cusparse lowering. 
-# TODO: Set to true after bug is fixed -bcoo_cusparse_lowering = config.define_bool_state( - name="jax_bcoo_cusparse_lowering", - default=False, - help=("Enables lowering BCOO ops to cuSparse."), -) - -# TODO: remove this flag when we ensure we only succeed at trace-staging -# if the intended backend can handle lowering the result -config.define_bool_state( - name="jax_dynamic_shapes", - default=bool(os.getenv("JAX_DYNAMIC_SHAPES", "")), - help=( - "Enables experimental features for staging out computations with " - "dynamic shapes." - ), - update_global_hook=lambda val: _update_global_jit_state(dynamic_shapes=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - dynamic_shapes=val - ), -) - -# TODO: Remove flag once coordination service has rolled out. -config.define_bool_state( - name="jax_coordination_service", - default=True, - help=( - "Use coordination service (experimental) instead of the default PjRT " - "distributed runtime." - ), -) - -config.define_bool_state( - name="jax_experimental_subjaxpr_lowering_cache", - default=False, - help="Enable using a cache for lowering subjaxprs.", -) - -# TODO: set default to True, then remove -config.define_bool_state( - name="jax_eager_pmap", - default=True, - upgrade=True, - help="Enable eager-mode pmap when jax_disable_jit is activated.", -) - -config.define_bool_state( - name="jax_experimental_unsafe_xla_runtime_errors", - default=False, - help=( - "Enable XLA runtime errors for jax.experimental.checkify.checks " - "on CPU and GPU. These errors are async, might get lost and are not " - "very readable. But, they crash the computation and enable you " - "to write jittable checks without needing to checkify. Does not " - "work under pmap/pjit." - ), -) - - -@contextlib.contextmanager -def explicit_device_put_scope() -> Iterator[None]: - """Indicates that the current context is an explicit device_put*() call.""" - state = transfer_guard_lib.thread_local_state() - prev = state.explicit_device_put - state.explicit_device_put = True - try: - yield - finally: - state.explicit_device_put = prev - - -@contextlib.contextmanager -def explicit_device_get_scope() -> Iterator[None]: - """Indicates that the current context is an explicit device_get() call.""" - state = transfer_guard_lib.thread_local_state() - prev = state.explicit_device_get - state.explicit_device_get = True - try: - yield - finally: - state.explicit_device_get = prev - - -def _update_transfer_guard(state, key, val): - """Applies the transfer guard level within transfer_guard_lib.""" - if val is None: - setattr(state, key, None) - elif val == "allow": - setattr(state, key, transfer_guard_lib.TransferGuardLevel.ALLOW) - elif val == "log": - setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG) - elif val == "disallow": - setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW) - elif val == "log_explicit": - setattr(state, key, transfer_guard_lib.TransferGuardLevel.LOG_EXPLICIT) - elif val == "disallow_explicit": - setattr(state, key, transfer_guard_lib.TransferGuardLevel.DISALLOW_EXPLICIT) - else: - assert False, f"Invalid transfer guard level {val}" - - -transfer_guard_host_to_device = config.define_enum_state( - name="jax_transfer_guard_host_to_device", - enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. - default=None, - help=( - "Select the transfer guard level for host-to-device transfers. 
" - 'Default is "allow".' - ), - update_global_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.global_state(), "host_to_device", val - ), - update_thread_local_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.thread_local_state(), "host_to_device", val - ), -) - -transfer_guard_device_to_device = config.define_enum_state( - name="jax_transfer_guard_device_to_device", - enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. - default=None, - help=( - "Select the transfer guard level for device-to-device transfers. " - 'Default is "allow".' - ), - update_global_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.global_state(), "device_to_device", val - ), - update_thread_local_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.thread_local_state(), "device_to_device", val - ), -) - -transfer_guard_device_to_host = config.define_enum_state( - name="jax_transfer_guard_device_to_host", - enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. - default=None, - help=( - "Select the transfer guard level for device-to-host transfers. " - 'Default is "allow".' - ), - update_global_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.global_state(), "device_to_host", val - ), - update_thread_local_hook=lambda val: _update_transfer_guard( - transfer_guard_lib.thread_local_state(), "device_to_host", val - ), -) - - -def _update_all_transfer_guard_global(val): - for name in ( - "jax_transfer_guard_host_to_device", - "jax_transfer_guard_device_to_device", - "jax_transfer_guard_device_to_host", - ): - config.update(name, val) - - -_transfer_guard = config.define_enum_state( - name="jax_transfer_guard", - enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], - # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard_*. - default=None, - help=( - "Select the transfer guard level for all transfers. This option is " - "set-only; the transfer guard level for a specific direction should " - "be read using the per-transfer direction option. " - 'Default is "allow".' - ), - update_global_hook=_update_all_transfer_guard_global, -) - - -@contextlib.contextmanager -def transfer_guard(new_val: str) -> Iterator[None]: - """A contextmanager to control the transfer guard level for all transfers. - - For more information, see - https://jax.readthedocs.io/en/latest/transfer_guard.html - - Args: - new_val: The new thread-local transfer guard level for all transfers. - - Yields: - None. 
- """ - with contextlib.ExitStack() as stack: - stack.enter_context(transfer_guard_host_to_device(new_val)) - stack.enter_context(transfer_guard_device_to_device(new_val)) - stack.enter_context(transfer_guard_device_to_host(new_val)) - stack.enter_context(_transfer_guard(new_val)) - yield diff --git a/imperative/python/megengine/xla/lib/mlir/__init__.py b/imperative/python/megengine/xla/lib/mlir/__init__.py deleted file mode 100644 index e2f1182d6..000000000 --- a/imperative/python/megengine/xla/lib/mlir/__init__.py +++ /dev/null @@ -1 +0,0 @@ -import jaxlib.mlir.ir as ir diff --git a/imperative/python/megengine/xla/lib/mlir/dialects/__init__.py b/imperative/python/megengine/xla/lib/mlir/dialects/__init__.py deleted file mode 100644 index 3ba818029..000000000 --- a/imperative/python/megengine/xla/lib/mlir/dialects/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import jaxlib.mlir.dialects.builtin as builtin -import jaxlib.mlir.dialects.chlo as chlo -import jaxlib.mlir.dialects.func as func -import jaxlib.mlir.dialects.mhlo as mhlo -import jaxlib.mlir.dialects.ml_program as ml_program -import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor -import jaxlib.mlir.dialects.stablehlo as stablehlo -import jaxlib.xla_client as xla_client - -# Alias that is set up to abstract away the transition from MHLO to StableHLO. -use_stablehlo = xla_client.mlir_api_version >= 42 -if use_stablehlo: - import jaxlib.mlir.dialects.stablehlo as hlo -else: - import jaxlib.mlir.dialects.mhlo as hlo # type: ignore[no-redef] diff --git a/imperative/python/megengine/xla/lib/xla_bridge.py b/imperative/python/megengine/xla/lib/xla_bridge.py deleted file mode 100644 index ae43643d4..000000000 --- a/imperative/python/megengine/xla/lib/xla_bridge.py +++ /dev/null @@ -1,503 +0,0 @@ -import logging -import os -import platform as py_platform -import threading -import warnings -from functools import lru_cache, partial -from typing import Any, Dict, List, Optional, Union - -import numpy as np -from jaxlib import xla_client - -from ..lib import cuda_path -from .config import bool_env, config, flags, int_env - -XlaBackend = xla_client._xla.Client - -ShardedBuffer = Any - -FLAGS = flags.FLAGS - -logger = logging.getLogger(__name__) - -flags.DEFINE_string( - "jax_xla_backend", "", "Deprecated, please use --jax_platforms instead." -) -flags.DEFINE_string( - "jax_backend_target", - os.getenv("JAX_BACKEND_TARGET", "").lower(), - 'Either "local" or "rpc:address" to connect to a remote service target.', -) -# TODO: warn when this is used once we test out --jax_platforms a bit -flags.DEFINE_string( - "jax_platform_name", - os.getenv("JAX_PLATFORM_NAME", "").lower(), - "Deprecated, please use --jax_platforms instead.", -) -flags.DEFINE_bool( - "jax_disable_most_optimizations", - bool_env("JAX_DISABLE_MOST_OPTIMIZATIONS", False), - "Try not to do much optimization work. This can be useful if the cost of " - "optimization is greater than that of running a less-optimized program.", -) -flags.DEFINE_integer( - "jax_xla_profile_version", - int_env("JAX_XLA_PROFILE_VERSION", 0), - "Optional profile version for XLA compilation. " - "This is meaningful only when XLA is configured to " - "support the remote compilation profile feature.", -) -flags.DEFINE_string( - "jax_cuda_visible_devices", - "all", - 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' - "comma-separate list of integer device IDs.", -) -flags.DEFINE_string( - "jax_rocm_visible_devices", - "all", - 'Restricts the set of ROCM devices that JAX will use. 
Either "all", or a ' - "comma-separate list of integer device IDs.", -) - - -def get_compile_options( - num_replicas: int, - num_partitions: int, - device_assignment=None, - use_spmd_partitioning: bool = True, - use_auto_spmd_partitioning: bool = False, - auto_spmd_partitioning_mesh_shape=[], - auto_spmd_partitioning_mesh_ids=[], -) -> xla_client.CompileOptions: - """Returns the compile options to use, as derived from flag values. - - Args: - num_replicas: Number of replicas for which to compile. - num_partitions: Number of partitions for which to compile. - device_assignment: Optional ndarray of jax devices indicating the assignment - of logical replicas to physical devices (default inherited from - xla_client.CompileOptions). Must be consistent with `num_replicas` and - `num_partitions`. - use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD - partitioning in XLA. - use_auto_spmd_partitioning: boolean indicating whether to automatically - generate XLA shardings for SPMD partitioner. - auto_spmd_partitioning_mesh_shape: device mesh shape used to create - auto_spmd_partitioning search space. - auto_spmd_partitioning_mesh_ids: device ids used to create - auto_spmd_partitioning search space. - """ - compile_options = xla_client.CompileOptions() - compile_options.num_replicas = num_replicas - compile_options.num_partitions = num_partitions - build_options = compile_options.executable_build_options - build_options.use_spmd_partitioning = use_spmd_partitioning - build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - if use_auto_spmd_partitioning: - build_options.auto_spmd_partitioning_mesh_shape = ( - auto_spmd_partitioning_mesh_shape - ) - build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids - if device_assignment is not None: - logger.debug( - "get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s", - num_replicas, - num_partitions, - device_assignment, - ) - device_assignment = np.array(device_assignment) - - # Allow 1D device assignment if num_partitions is 1. - if (device_assignment.ndim == 1) and (num_partitions == 1): - device_assignment = device_assignment[:, None] - - if num_replicas != device_assignment.shape[0]: - msg = "device_assignment does not match num_replicas: {} vs {}." - raise ValueError(msg.format(device_assignment, num_replicas)) - - if num_partitions != device_assignment.shape[1]: - msg = "device_assignment does not match num_partitions: {} vs {}." - raise ValueError(msg.format(device_assignment, num_partitions)) - - if device_assignment.dtype == object: - device_assignment = np.vectorize(lambda d: d.id, otypes=[int])( - device_assignment - ) - device_assignment = xla_client.DeviceAssignment.create(device_assignment) - assert device_assignment.replica_count() == num_replicas - assert device_assignment.computation_count() == num_partitions - compile_options.device_assignment = device_assignment - - debug_options = compile_options.executable_build_options.debug_options - if cuda_path is not None: - debug_options.xla_gpu_cuda_data_dir = cuda_path - - if FLAGS.jax_disable_most_optimizations: - debug_options.xla_backend_optimization_level = 0 - debug_options.xla_llvm_disable_expensive_passes = True - debug_options.xla_test_all_input_layouts = False - - compile_options.profile_version = FLAGS.jax_xla_profile_version - return compile_options - - -# Backends, in increasing order of preference. -# We have no particular opinion about how "backends" relate to "devices". 
For -# example, there could be multiple backends that provide the same kind of -# device. -_backend_factories = {} -_default_backend = None -_backends: Dict[str, Any] = {} -_backends_errors: Dict[str, str] = {} -_backend_lock = threading.Lock() - - -def register_backend_factory(name, factory, *, priority=0): - with _backend_lock: - if name in _backends: - raise RuntimeError(f"Backend {name} already initialized") - _backend_factories[name] = (factory, priority) - - -register_backend_factory( - "interpreter", xla_client.make_interpreter_client, priority=-100 -) -register_backend_factory( - "cpu", partial(xla_client.make_cpu_client, use_tfrt=True), priority=0 -) - - -def make_gpu_client(*, platform_name, visible_devices_flag): - from ..distribute import global_state - - visible_devices = global_state.visible_devices - if visible_devices != "all": - allowed_devices = {int(x) for x in visible_devices.split(",")} - else: - allowed_devices = None - return xla_client.make_gpu_client( - distributed_client=global_state.client, - node_id=global_state.process_id, - platform_name=platform_name, - allowed_devices=allowed_devices, - ) - - -if hasattr(xla_client, "make_gpu_client"): - register_backend_factory( - "cuda", - partial( - make_gpu_client, - platform_name="cuda", - visible_devices_flag="jax_cuda_visible_devices", - ), - priority=200, - ) - register_backend_factory( - "rocm", - partial( - make_gpu_client, - platform_name="rocm", - visible_devices_flag="jax_rocm_visible_devices", - ), - priority=200, - ) - -if hasattr(xla_client, "make_plugin_device_client"): - # It is assumed that if jax has been built with a plugin client, then the - # user wants to use the plugin client by default. Therefore, it gets the - # highest priority. - register_backend_factory( - "plugin", xla_client.make_plugin_device_client, priority=400 - ) - -_platform_aliases = { - "cuda": "gpu", - "rocm": "gpu", -} - -_alias_to_platforms: Dict[str, List[str]] = {} -for _platform, _alias in _platform_aliases.items(): - _alias_to_platforms.setdefault(_alias, []).append(_platform) - - -def is_known_platform(platform: str): - # A platform is valid if there is a registered factory for it. It does not - # matter if we were unable to initialize that platform; we only care that - # we've heard of it and it isn't, e.g., a typo. - return platform in _backend_factories.keys() or platform in _platform_aliases.keys() - - -def canonicalize_platform(platform: str) -> str: - """Replaces platform aliases with their concrete equivalent. - - In particular, replaces "gpu" with either "cuda" or "rocm", depending on which - hardware is actually present. We want to distinguish "cuda" and "rocm" for - purposes such as MLIR lowering rules, but in many cases we don't want to - force users to care. - """ - platforms = _alias_to_platforms.get(platform, None) - if platforms is None: - return platform - - b = backends() - for p in platforms: - if p in b.keys(): - return p - raise RuntimeError( - f"Unknown backend: '{platform}' requested, but no " - f"platforms that are instances of {platform} are present. " - "Platforms present are: " + ",".join(b.keys()) - ) - - -def expand_platform_alias(platform: str) -> List[str]: - """Expands, e.g., "gpu" to ["cuda", "rocm"]. - - This is used for convenience reasons: we expect cuda and rocm to act similarly - in many respects since they share most of the same code. 
- """ - return _alias_to_platforms.get(platform, [platform]) - - -def is_gpu(platform): - return platform in ("cuda", "rocm") - - -def backends(): - global _backends - global _backends_errors - global _default_backend - - with _backend_lock: - if _backends: - return _backends - if config.jax_platforms: - jax_platforms = config.jax_platforms.split(",") - platforms = [] - # Allow platform aliases in the list of platforms. - for platform in jax_platforms: - platforms.extend(expand_platform_alias(platform)) - priorities = range(len(platforms), 0, -1) - platforms_and_priorites = zip(platforms, priorities) - else: - platforms_and_priorites = ( - (platform, priority) - for platform, (_, priority) in _backend_factories.items() - ) - default_priority = -1000 - if hasattr(xla_client, "maybe_load_pjrt_plugins"): - xla_client.maybe_load_pjrt_plugins() - for platform, priority in platforms_and_priorites: - try: - backend = _init_backend(platform) - _backends[platform] = backend - - if priority > default_priority: - _default_backend = backend - default_priority = priority - except Exception as err: - if platform in ("cpu", "interpreter"): - # We always expect the CPU and interpreter backends to initialize - # successfully. - raise - else: - # If the backend isn't built into the binary, or if it has no devices, - # we expect a RuntimeError. - err_msg = f"Unable to initialize backend '{platform}': {err}" - if config.jax_platforms: - err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)" - raise RuntimeError(err_msg) - else: - _backends_errors[platform] = str(err) - logger.info(err_msg) - continue - # We don't warn about falling back to CPU on Mac OS, because we don't - # support anything else there at the moment and warning would be pointless. - if ( - py_platform.system() != "Darwin" - and _default_backend.platform == "cpu" - and FLAGS.jax_platform_name != "cpu" - ): - logger.warning( - "No GPU/TPU found, falling back to CPU. " - "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)" - ) - return _backends - - -def _clear_backends(): - global _backends - global _backends_errors - global _default_backend - - logger.info("Clearing JAX backend caches.") - with _backend_lock: - _backends = {} - _backends_errors = {} - _default_backend = None - - get_backend.cache_clear() - - -def _init_backend(platform): - factory, unused_priority = _backend_factories.get(platform, (None, None)) - if factory is None: - raise RuntimeError(f"Unknown backend '{platform}'") - - logger.debug("Initializing backend '%s'", platform) - backend = factory() - - if backend is None: - raise RuntimeError(f"Could not initialize backend '{platform}'") - if backend.device_count() == 0: - raise RuntimeError(f"Backend '{platform}' provides no devices.") - logger.debug("Backend '%s' initialized", platform) - return backend - - -def _get_backend_uncached(platform=None): - if not isinstance(platform, (type(None), str)): - return platform - - platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None - - bs = backends() - if platform is not None: - platform = canonicalize_platform(platform) - backend = bs.get(platform, None) - if backend is None: - if platform in _backends_errors: - raise RuntimeError( - f"Backend '{platform}' failed to initialize: " - f"{_backends_errors[platform]}" - ) - raise RuntimeError(f"Unknown backend {platform}") - return backend - else: - return _default_backend - - -@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence. 
-def get_backend(platform=None): - return _get_backend_uncached(platform) - - -def get_device_backend(device=None): - """Returns the Backend associated with `device`, or the default Backend.""" - if device is not None: - return device.client - return get_backend() - - -def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int: - """Returns the total number of devices. - - On most platforms, this is the same as :py:func:`jax.local_device_count`. - However, on multi-process platforms where different devices are associated - with different processes, this will return the total number of devices across - all processes. - - Args: - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - - Returns: - Number of devices. - - """ - return int(get_backend(backend).device_count()) - - -def local_device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int: - """Returns the number of devices addressable by this process.""" - return int(get_backend(backend).local_device_count()) - - -def devices( - backend: Optional[Union[str, XlaBackend]] = None -) -> List[xla_client.Device]: - """Returns a list of all devices for a given backend. - - .. currentmodule:: jaxlib.xla_extension - - Each device is represented by a subclass of :class:`Device` (e.g. - :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is - equal to ``device_count(backend)``. Local devices can be identified by - comparing :attr:`Device.process_index` to the value returned by - :py:func:`jax.process_index`. - - If ``backend`` is ``None``, returns all the devices from the default backend. - The default backend is generally ``'gpu'`` or ``'tpu'`` if available, - otherwise ``'cpu'``. - - Args: - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - - Returns: - List of Device subclasses. - """ - return get_backend(backend).devices() - - -def default_backend() -> str: - """Returns the platform name of the default XLA backend.""" - return get_backend(None).platform - - -def local_devices( - process_index: Optional[int] = None, - backend: Optional[Union[str, XlaBackend]] = None, - host_id: Optional[int] = None, -) -> List[xla_client.Device]: - """Like :py:func:`jax.devices`, but only returns devices local to a given process. - - If ``process_index`` is ``None``, returns devices local to this process. - - Args: - process_index: the integer index of the process. Process indices can be - retrieved via ``len(jax.process_count())``. - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - - Returns: - List of Device subclasses. - """ - if host_id is not None: - warnings.warn( - "The argument to jax.local_devices has been renamed from `host_id` to " - "`process_index`. This alias will eventually be removed; please update " - "your code." - ) - process_index = host_id - if process_index is None: - process_index = get_backend(backend).process_index() - if not (0 <= process_index < process_count()): - raise ValueError(f"Unknown process_index {process_index}") - return [d for d in devices(backend) if d.process_index == process_index] - - -def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int: - """Returns the integer process index of this process. 
- - On most platforms, this will always be 0. This will vary on multi-process - platforms though. - - Args: - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - - Returns: - Integer process index. - """ - return get_backend(backend).process_index() - -# returns the number of mge processes associated with the backend -def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int: - return max(d.process_index for d in devices(backend)) + 1 diff --git a/imperative/python/megengine/xla/lower.py b/imperative/python/megengine/xla/lower.py deleted file mode 100644 index 17c70739c..000000000 --- a/imperative/python/megengine/xla/lower.py +++ /dev/null @@ -1,260 +0,0 @@ -import dataclasses -import itertools -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import numpy as np - -from ..core._imperative_rt.core2 import OpInfo, VarInfo -from . import utils -from .device import xb -from .ir_utils import ( - TraceResult, - ir_constant_tuple, - mge_varinfo_to_ir_type_tuple, -) -from .lib import xla_client as xc -from .lib.mlir import dialects, ir -from .lib.mlir.dialects import func as func_dialect -from .rules import get_rule -from .rules.hlotensor import HLOTensor -from .rules.utils import _shape_equal -from .sharding import sharded_val - - -def make_ir_context() -> ir.Context: - context = ir.Context() - dialects.mhlo.register_mhlo_dialect(context) - dialects.chlo.register_dialect(context) - dialects.stablehlo.register_dialect(context) - return context - - -@dataclasses.dataclass -class ModuleContext: - context: ir.Context - module: ir.Module - ip: ir.InsertionPoint - symbol_table: ir.SymbolTable - backend_or_name: Optional[Union[str, xb.XlaBackend]] - platform: str - keepalives: List[Any] - channel_iterator: Iterator[int] - host_callbacks: List[Any] - - # Stores the value of varinfo that can be inferred in lowering process - inferred_values: Dict[VarInfo, np.ndarray] - - def __init__( - self, - backend_or_name: Optional[Union[str, xb.XlaBackend]], - platform: str, - keepalives: List[Any] = [], - host_callbacks: List[Any] = [], - context: Optional[ir.Context] = None, - module: Optional[ir.Module] = None, - ip: Optional[ir.InsertionPoint] = None, - symbol_table: Optional[ir.SymbolTable] = None, - ): - assert platform is not None - self.context = context or make_ir_context() - self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) - self.ip = ip or ir.InsertionPoint(self.module.body) - self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation) - self.backend_or_name = backend_or_name - self.platform = platform - self.keepalives = keepalives - self.host_callbacks = host_callbacks - self.inferred_values = {} - - @property - def backend(self) -> xb.XlaBackend: - if self.backend_or_name is None or isinstance(self.backend_or_name, str): - return xb.get_backend(self.backend_or_name) - return self.backend_or_name - - def replace(self, **kw): - return dataclasses.replace(self, **kw) - - def get_value(self, varinfo): - assert varinfo in self.inferred_values - return self.inferred_values[varinfo] - - def set_value(self, varinfo, value): - self.inferred_values[varinfo] = value - - -@dataclasses.dataclass -class LoweringRuleContext: - module_context: ModuleContext - op: OpInfo - vars_in: Sequence[VarInfo] - vars_out: Sequence[VarInfo] - param: Dict = None - - def replace(self, **kw): - return dataclasses.replace(self, 
**kw) - - -def _unwrap_singleton_ir_values(x): - return x[0] if len(x) == 1 else x - - -def _wrap_singleton_ir_values( - x: Union[ir.Value, Sequence[ir.Value]] -) -> Sequence[ir.Value]: - return (x,) if isinstance(x, ir.Value) else tuple(x) - - -def lowering_ops( - ctx: ModuleContext, trace_result: TraceResult, *args: Sequence[ir.Value], -): - # var_id -> ir.Value - env: Dict[int, Tuple[ir.Value, ...]] = {} - consts = list(map(ir_constant_tuple, trace_result._var_consts)) - - # read ir.Values from env according to var_ids - def read(var_ids): - assert isinstance(var_ids, (list, tuple)) - ret = [] - for vid in var_ids: - assert isinstance(vid, int) - ret.append(env[vid]) - return ret - - # update env with var_ids and ir.Values - def write(var_ids, hlo_nodes): - assert isinstance(var_ids, (list, tuple)) - assert isinstance(hlo_nodes, (map, list, tuple)) - hlo_nodes = list(hlo_nodes) - assert len(var_ids) == len(hlo_nodes), (len(var_ids), len(hlo_nodes)) - for vid, node in zip(var_ids, hlo_nodes): - assert vid not in env - env[vid] = node - - assert len(args) == len(trace_result.inputs) - assert len(consts) == len(trace_result.consts) - assert all(isinstance(v, ir.Value) for vs in consts for v in vs) - - # initialize env with inputs and consts - write(trace_result.inputs, args) - write(trace_result.consts, consts) - - for eqn in trace_result.eqns: - rule_ctx = LoweringRuleContext( - module_context=ctx, - op=eqn.op, - vars_in=[trace_result.vars[inp] for inp in eqn.inputs], - vars_out=[trace_result.vars[oup] for oup in eqn.outputs], - param=eqn.param, - ) - rule = get_rule(eqn.op) - - in_nodes = read(eqn.inputs) - hinps = [ - HLOTensor(irval, var.shape, var.dtype) - for var, irval in zip( - rule_ctx.vars_in, map(_unwrap_singleton_ir_values, in_nodes) - ) - ] - houps = rule(rule_ctx, *hinps) - if isinstance(houps, HLOTensor): - houps = [houps] - - out_nodes = [] - for out_id, hlo_out in zip(eqn.outputs, houps): - var_out = trace_result.vars[out_id] - assert _shape_equal( - var_out.shape, hlo_out.shape - ), f"{eqn.op}: {var_out.shape} != {hlo_out.shape}" - out_nodes.append(hlo_out.tensor) - out_nodes = tuple(map(_wrap_singleton_ir_values, out_nodes)) - write(eqn.outputs, out_nodes) - return read(trace_result.outputs) - - -def make_xla_graph( - ctx: ModuleContext, - name: str, - trace_result: TraceResult, - public: bool = True, - in_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None, - out_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None, - input_output_aliases: Optional[Sequence[Optional[int]]] = None, -) -> func_dialect.FuncOp: - assert public is True, "do not process the visibitity of function" - assert ( - in_shardings is None and out_shardings is None - ), "sharding when lowering is not supported yet" - assert ( - input_output_aliases is None or input_output_aliases == [] - ), "donated inputs are not supported yet" - - input_types = [ - mge_varinfo_to_ir_type_tuple(trace_result.vars[idx]) - for idx in trace_result.inputs - ] - output_types = [ - mge_varinfo_to_ir_type_tuple(trace_result.vars[idx]) - for idx in trace_result.outputs - ] - - flat_input_types = utils.flatten_list(input_types) - flat_output_types = utils.flatten_list(output_types) - assert len(flat_input_types) == len(trace_result.inputs) - assert len(flat_output_types) == len(trace_result.outputs) - - ftype = ir.FunctionType.get(flat_input_types, flat_output_types) - func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip) - func_op.attributes["sym_visibility"] = ir.StringAttr.get( - "public" if public else 
"private" - ) - ctx.symbol_table.insert(func_op) - - entry_block = func_op.add_entry_block() - with ir.InsertionPoint(entry_block): - flat_args = entry_block.arguments - unflattened_args = utils.unflatten_list(flat_args, map(len, input_types)) - outs = lowering_ops(ctx, trace_result, *unflattened_args) - flat_oups = utils.flatten_list(outs) - func_dialect.ReturnOp(flat_oups) - return func_op - - -def lower( - trace_result: TraceResult, - backend, - platform, - in_shardings=None, - out_shardings=None, - donated_invars=None, -): - assert donated_invars is None, "donated inputs are not supported yet" - assert trace_result.effects == [], "effect of trace is not supported" - - if in_shardings is not None: - trace_result.inputs = [ - sharded_val(inp, in_sharding) - for inp, in_sharding in zip(trace_result.inputs, in_shardings) - ] - if out_shardings is not None: - trace_result.outputs = [ - sharded_val(outp, out_sharding) - for outp, out_sharding in zip(trace_result.outputs, out_shardings) - ] - - ctx = ModuleContext(backend, platform) - - with ctx.context, ir.Location.unknown(ctx.context): - module_name = trace_result.func_name - ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) - assert trace_result.effects == [], "effect of trace is not supported" - make_xla_graph( - ctx, - "main", - trace_result, - public=True, - in_shardings=None, - out_shardings=None, - input_output_aliases=[], - ) - return ctx.module, ctx.keepalives, ctx.host_callbacks diff --git a/imperative/python/megengine/xla/rules/__init__.py b/imperative/python/megengine/xla/rules/__init__.py deleted file mode 100644 index 0e777bc97..000000000 --- a/imperative/python/megengine/xla/rules/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from . import ( - communicate, - elemwise, - indexing, - math, - nn, - normalize, - random, - reduction, - tensor, - trivial, -) -from .utils import get_rule diff --git a/imperative/python/megengine/xla/rules/communicate.py b/imperative/python/megengine/xla/rules/communicate.py deleted file mode 100644 index 46dcd9ca6..000000000 --- a/imperative/python/megengine/xla/rules/communicate.py +++ /dev/null @@ -1,82 +0,0 @@ -import itertools -from functools import partial -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir import ir -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .tensor import concat, split -from .utils import register_lower_rule - - -@register_lower_rule(mops.ParamPackConcat) -def parampack_concat_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - flattened = [] - for arg, var_in in zip(args[:-1], ctx.vars_in[:-1]): - ishape_1d = (int(np.prod(var_in.shape)),) - flattened.append(arg.reshape(ishape_1d)) - concated = concat(flattened, 0) - return concated - - -@register_lower_rule(mops.ParamPackSplit) -def parampack_split_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - offsets, shapes, var_outs = ctx.op.offsets, ctx.op.shapes, ctx.vars_out - assert (len(offsets) // 2) == len(shapes) == len(var_outs), "error params" - for var_out, shape in zip(var_outs, shapes): - assert tuple(var_out.shape) == tuple(shape), f"{var_out.shape} .vs {shape}" - - sections = [np.prod(shape) for shape in shapes] - for i, section in enumerate(sections): - assert section == offsets[2 * i + 1] - offsets[2 * i], "error offsets" - - pieces = split(args[0], sections, axis=0) - outputs = [piece.reshape(var_out.shape) for piece, var_out in zip(pieces, var_outs)] - return outputs - - -def _all_reduce(reducer, inp, world_size): - def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]): - groups = np.array( - list(itertools.zip_longest(*replica_groups, fillvalue=-1)), dtype=np.int64 - ).T - return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups)) - - replica_groups = _replica_groups_hlo([[i for i in range(world_size)]]) - hlo_cfgs = {} - - all_reduce_op = hlo.AllReduceOp( - inp.tensor.type, inp.tensor, replica_groups=replica_groups, **hlo_cfgs - ) - scalar_type = ir_utils.make_ir_type_according_meta(tuple(), inp.dtype) - reducer_region = all_reduce_op.regions[0].blocks.append(scalar_type, scalar_type) - with ir.InsertionPoint(reducer_region): - reducer_ret = reducer(*reducer_region.arguments) - hlo.ReturnOp(reducer_ret.results) - return HLOTensor(all_reduce_op.results) - - -all_reduce_sum = partial(_all_reduce, hlo.AddOp) -all_reduce_prod = partial(_all_reduce, hlo.MulOp) -all_reduce_min = partial(_all_reduce, hlo.MinOp) -all_reduce_max = partial(_all_reduce, hlo.MaxOp) - - -@register_lower_rule(mops.CollectiveComm) -def collective_comm_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1, "collective comm only support one input" - if ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_SUM: - ret = all_reduce_sum(args[0], ctx.op.nr_devices) - elif ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_PROD: - ret = all_reduce_prod(args[0], ctx.op.nr_devices) - elif ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_MIN: - ret = all_reduce_min(args[0], ctx.op.nr_devices) - elif ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_MAX: - ret = all_reduce_max(args[0], ctx.op.nr_devices) - else: - assert False, f"not support mode{ctx.op.mode}" - return ret diff --git a/imperative/python/megengine/xla/rules/elemwise.py b/imperative/python/megengine/xla/rules/elemwise.py deleted file mode 100644 index dc8803120..000000000 --- a/imperative/python/megengine/xla/rules/elemwise.py +++ /dev/null @@ -1,303 +0,0 @@ -import math -from functools import partial -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .utils import register_lower_rule - - -def _infer_elemwise_oshape(inp_shapes): - def _infer_binary_elemwise_oshape(lhs_shape, rhs_shape): - if len(lhs_shape) == 0: - return rhs_shape - if len(rhs_shape) == 0: - return lhs_shape - - if np.prod(lhs_shape) == 1 and len(rhs_shape) != 0: - return rhs_shape - if np.prod(rhs_shape) == 1 and len(rhs_shape) != 0: - return lhs_shape - - oshape = [] - if len(lhs_shape) == len(rhs_shape): - for l, r in zip(lhs_shape, rhs_shape): - if l == r: - oshape.append(l) - elif l == 1: - oshape.append(r) - elif r == 1: - oshape.append(l) - else: - assert False, f"infer elemwise shape error: {lhs_shape} {rhs_shape}" - else: - shorter = lhs_shape if len(lhs_shape) < len(rhs_shape) else rhs_shape - longer = lhs_shape if len(lhs_shape) > len(rhs_shape) else rhs_shape - - right_part = longer[-len(shorter) :] - for l, s in zip(right_part, shorter): - assert ( - l == s or s == 1 - ), f"infer elemwise shape error: {lhs_shape} {rhs_shape}" - oshape = longer - - return oshape - - oshape = tuple() - for ishape in inp_shapes: - oshape = _infer_binary_elemwise_oshape(ishape, oshape) - return oshape - - -def _infer_elemwise_odtype(inp_dtypes): - oup_dtype = inp_dtypes[0] - for inp_dtype in inp_dtypes: - assert ( - inp_dtype == oup_dtype - ), f"elemwise inputs has different dtype {inp_dtypes}" - return oup_dtype - - -def _compare(lhs, rhs, mode, comparison_type=None): - """ - mod: can be - 'EQ' (equal-to), - 'NE' (not equal-to), - 'GE' (greater-or-equal-than), - 'GT' (greater-than), - 'LE' (less-or-equal-than), - 'LT' (less-than) - comparision_type: can be 'UNSIGNED', 'SIGNED', 'FLOAT' - """ - lhs = HLOTensor(lhs) if not isinstance(lhs, HLOTensor) else lhs - rhs = HLOTensor(rhs) if not isinstance(rhs, HLOTensor) else rhs - oshape = _infer_elemwise_oshape([lhs.shape, rhs.shape]) - - lhs = lhs.broadcast_to(oshape) - rhs = rhs.broadcast_to(oshape) - - if comparison_type is None: - if lhs.dtype in [np.int64, np.int32, np.int16, np.int8]: - assert rhs.dtype in [np.int64, np.int32, np.int16, np.int8] - comparison_type = "SIGNED" - elif lhs.dtype in [np.uint64, np.uint32, np.uint16, np.uint8]: - assert rhs.dtype in [np.uint64, np.uint32, np.uint16, np.uint8] - comparison_type = "UNSIGNED" - elif lhs.dtype in [np.float64, np.float32, np.float16]: - assert rhs.dtype in [np.float64, np.float32, np.float16] - comparison_type = "FLOAT" - else: - assert False, f"invalid dtype for compare {lhs.dtype} .vs {rhs.dtype}" - - return HLOTensor( - hlo.CompareOp( - lhs.tensor, - rhs.tensor, - hlo.ComparisonDirectionAttr.get(mode), - compare_type=hlo.ComparisonTypeAttr.get(comparison_type), - ).result - ) - - -def _elemwise(hlo_op, inps): - hinps = [HLOTensor(inp) if not isinstance(inp, HLOTensor) else inp for inp in inps] - - ishapes = [inp.shape for inp in hinps] - idtypes = [inp.dtype for inp in hinps] - - oshape = _infer_elemwise_oshape(ishapes) - odtype = _infer_elemwise_odtype(idtypes) - - broadcasted_inps = [hinp.broadcast_to(oshape) for hinp in hinps] - results = hlo_op(*[binp.tensor for binp in broadcasted_inps]).results - assert len(results) == 1, f"elemwise op {hlo_op} should have only one output" - return HLOTensor(results[0], oshape, odtype) - - -def _elemwise_unary(hlo_op, a): - return _elemwise(hlo_op, [a]) - - -def _elemwise_binary(hlo_op, a, b): - return _elemwise(hlo_op, [a, b]) - - -neg = partial(_elemwise_unary, hlo.NegOp) -abs = partial(_elemwise_unary, hlo.AbsOp) -tanh = 
partial(_elemwise_unary, hlo.TanhOp) -exp = partial(_elemwise_unary, hlo.ExpOp) -sqrt = partial(_elemwise_unary, hlo.SqrtOp) -log = partial(_elemwise_unary, hlo.LogOp) - -add = partial(_elemwise_binary, hlo.AddOp) -sub = partial(_elemwise_binary, hlo.SubtractOp) -mul = partial(_elemwise_binary, hlo.MulOp) -div = partial(_elemwise_binary, hlo.DivOp) -pow = partial(_elemwise_binary, hlo.PowOp) - - -equal = partial(_compare, mode="EQ") -not_equal = partial(_compare, mode="NE") -greater = partial(_compare, mode="GT") -greater_equal = partial(_compare, mode="GE") -less = partial(_compare, mode="LT") -less_equal = partial(_compare, mode="LE") - - -def abs_grad(x, dy): - return (x / abs(x)) * dy - - -def tanh_grad(x, dy): - return (1.0 - tanh(x) ** 2.0) * dy - - -def bitcast(inp, oshape, odtype): - odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype - return HLOTensor( - hlo.BitcastConvertOp( - ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor - ).result - ) - - -def typecvt(inp, odtype): - odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype - return HLOTensor( - hlo.ConvertOp( - ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor - ).result - ) - - -def gelu(inp, approximate: bool = True): - if approximate: - sqrt_2_over_pi = np.sqrt(2.0 / np.pi) - a = inp ** 3.0 - b = 0.044715 * a - c = inp + b - d = sqrt_2_over_pi * c - e = tanh(d) - f = 1.0 + e - g = 0.5 * f - h = inp * g - else: - assert False, "only approximate gelu is supported" - return h - - -def erfcc(inp): - _a = abs(inp) - _b = 0.5 * _a - _c = 1.0 + _b - _d = 1.0 / _c - _e = _d * 0.17087277 - _f = -0.82215223 + _e - _g = _d * _f - _h = 1.48851587 + _g - _i = _d * _h - _j = -1.13520398 + _i - _k = _d * _j - _l = 0.27886807 + _k - _m = _d * _l - _n = -0.18628806 + _m - _o = _d * _n - _p = 0.09678418 + _o - _q = _d * _p - _r = 0.37409196 + _q - _s = _d * _r - _t = 1.00002368 + _s - _u = _d * _t - _v = inp * inp - _w = -_v - _x = _w - 1.26551223 - _y = _x + _u - _z = exp(_y) - _aa = _d * _z - _ab = 1.0 - _aa - _ac = -_ab - - _ad = (inp >= 0.0).astype(inp.dtype) - _ae = (inp < 0.0).astype(inp.dtype) - _af = _ad * _ab - _ag = _ae * _ac - ret = _af + _ag - return ret - - -def gelu_grad(x, dy, approximate: bool = True): - if approximate: - _a = x * x - _b = -0.5 * _a - _c = exp(_b) - phi = 0.3989422804014327 * _c - _d = x / math.sqrt(2.0) - _e = erfcc(_d) - _f = 1.0 + _e - normcdf_v = 0.5 * _f - _g = x * phi - _h = normcdf_v + _g - ret = dy * _h - else: - assert False - return ret - - -def relu(inp): - mask = (inp > 0.0).astype(inp.dtype) - return inp * mask - - -def relu_grad(x, dy): - mask = (x > 0.0).astype(x.dtype) - return dy * mask - - -# Elemwise.Mode is unhashable, so we convert it to str -mge_elemwise_to_xla = { - str(mops.Elemwise.Mode.ADD): add, - str(mops.Elemwise.Mode.MUL): mul, - str(mops.Elemwise.Mode.SUB): sub, - str(mops.Elemwise.Mode.EXP): exp, - str(mops.Elemwise.Mode.LOG): log, - str(mops.Elemwise.Mode.GELU): gelu, - str(mops.Elemwise.Mode.GELU_GRAD): gelu_grad, - str(mops.Elemwise.Mode.TRUE_DIV): div, - str(mops.Elemwise.Mode.NEGATE): neg, - str(mops.Elemwise.Mode.ABS): abs, - str(mops.Elemwise.Mode.ABS_GRAD): abs_grad, - str(mops.Elemwise.Mode.TANH): tanh, - str(mops.Elemwise.Mode.TANH_GRAD): tanh_grad, - str(mops.Elemwise.Mode.SQRT): sqrt, - str(mops.Elemwise.Mode.POW): pow, - str(mops.Elemwise.Mode.RELU): relu, - str(mops.Elemwise.Mode.EQ): equal, - str(mops.Elemwise.Mode.NEQ): not_equal, - str(mops.Elemwise.Mode.LT): less, - str(mops.Elemwise.Mode.LEQ): 
less_equal, - str(mops.Elemwise.Mode.SWITCH_GT0): relu_grad, -} - - -@register_lower_rule(mops.Elemwise) -def elemwise_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert isinstance(ctx.op, mops.Elemwise), "op should be elemwise here" - assert ( - len(ctx.vars_out) == 1 - ), f"elemwise output num should be 1, got {len(ctx.vars_out)}" - handle = mge_elemwise_to_xla[str(ctx.op.mode)] - oup = handle(*args) - return oup - - -@register_lower_rule(mops.ElemwiseMultiType) -def elemwise_multi_type_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - opr = ctx.op - mode = "Elemwise.Mode." + str(opr.mode).split(".")[-1] - handle = mge_elemwise_to_xla[mode] - oup = handle(*args).astype(opr.dtype) - return oup diff --git a/imperative/python/megengine/xla/rules/hlotensor.py b/imperative/python/megengine/xla/rules/hlotensor.py deleted file mode 100644 index 663d8e526..000000000 --- a/imperative/python/megengine/xla/rules/hlotensor.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Sequence - -import numpy as np - -from .. import ir_utils -from ..ir_utils import get_irnode_dtype, get_irnode_shape -from ..lib.mlir import ir -from .utils import _check_dtype, _check_shape - - -class HLOTensor: - def __init__(self, tensor, shape=None, dtype=None) -> None: - if isinstance(tensor, Sequence): - assert len(tensor) > 0, "cannot create HLOTensor from empty sequence" - if isinstance(tensor[0], int): - tensor = np.array(tensor) - else: - assert len(tensor) == 1, f"cannot create HLOTensor from {tensor}" - tensor = tensor[0] - if isinstance(tensor, ir.OpResultList): - assert len(tensor) == 1, f"cannot create HLOTensor from {tensor}" - tensor = tensor[0] - - if isinstance( - tensor, (int, float, np.int_, np.float16, np.float32, np.float64) - ): - tensor = ir_utils.ir_constant(tensor) - elif isinstance(tensor, np.ndarray): - tensor = ir_utils.ir_constant(tensor) - else: - pass - - assert isinstance( - tensor, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult) - ), type(tensor) - infered_shape = get_irnode_shape(tensor) - infered_dtype = get_irnode_dtype(tensor) - - _check_shape(infered_shape, shape) - _check_dtype(infered_dtype, dtype) - - self._tensor = tensor - self._shape = infered_shape - self._dtype = infered_dtype - - @property - def shape(self): - return tuple(self._shape) - - @property - def dtype(self): - return self._dtype - - @property - def ndim(self): - return len(self.shape) - - @property - def tensor(self): - return self._tensor - - def __str__(self): - return f"HLOTensor(shape={self.shape}, dtype={self.dtype})" - - def __eq__(self, rhs): - from .elemwise import equal - - return equal(self, rhs) - - def __ne__(self, rhs): - from .elemwise import not_equal - - return not_equal(self, rhs) - - def __gt__(self, rhs): - from .elemwise import greater - - return greater(self, rhs) - - def __ge__(self, rhs): - from .elemwise import greater_equal - - return greater_equal(self, rhs) - - def __lt__(self, rhs): - from .elemwise import less - - return less(self, rhs) - - def __le__(self, rhs): - from .elemwise import less_equal - - return less_equal(self, rhs) - - def __neg__(self): - from .elemwise import neg - - return neg(self) - - def __add__(self, rhs): - from .elemwise import add - - return add(self, rhs) - - def __radd__(self, rhs): - from .elemwise import add - - return add(rhs, self) - - def __sub__(self, rhs): - from .elemwise import sub - - return sub(self, rhs) - - def __rsub__(self, rhs): - from .elemwise import sub - - return sub(rhs, self) - - def __mul__(self, rhs): - from 
.elemwise import mul - - return mul(self, rhs) - - def __rmul__(self, rhs): - from .elemwise import mul - - return mul(rhs, self) - - def __truediv__(self, rhs): - from .elemwise import div - - return div(self, rhs) - - def __rtruediv__(self, rhs): - from .elemwise import div - - return div(rhs, self) - - def __pow__(self, rhs): - from .elemwise import pow - - return pow(self, rhs) - - def reshape(self, shape): - from .tensor import reshape - - return reshape(self, shape) - - def transpose(self, permutation): - from .tensor import transpose - - return transpose(self, permutation) - - def broadcast_to(self, shape, broadcast_dims=None): - from .tensor import broadcast_to - - return broadcast_to(self, shape, broadcast_dims) - - def bitcast(self, shape, dtype): - from .elemwise import bitcast - - return bitcast(self, shape, dtype) - - def astype(self, dtype): - from .elemwise import typecvt - - return typecvt(self, dtype) - - def sum(self, axis, keepdims=False): - from .reduction import sum - - return sum(self, axis, keepdims) - - def mean(self, axis, keepdims=False): - from .reduction import mean - - return mean(self, axis, keepdims) - - def prod(self, axis, keepdims=False): - from .reduction import prod - - return prod(self, axis, keepdims) - - def max(self, axis, keepdims=False): - from .reduction import max - - return max(self, axis, keepdims) - - def min(self, axis, keepdims=False): - from .reduction import min - - return min(self, axis, keepdims) - - def all(self, axis, keepdims=False): - from .reduction import all - - return all(self, axis, keepdims) - - def any(self, axis, keepdims=False): - from .reduction import any - - return any(self, axis, keepdims) diff --git a/imperative/python/megengine/xla/rules/indexing.py b/imperative/python/megengine/xla/rules/indexing.py deleted file mode 100644 index f32edb190..000000000 --- a/imperative/python/megengine/xla/rules/indexing.py +++ /dev/null @@ -1,696 +0,0 @@ -from collections import namedtuple -from enum import IntEnum -from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir import ir -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .utils import _parse_var_as_value, register_lower_rule - - -""" -case1: idx is a int - x[1] -module @jit_index { - func.func public @main(%arg0: tensor<16x128x224x224xf32> {mhlo.sharding = ""}) -> tensor<128x224x224xf32> { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<0> : tensor - %2 = mhlo.constant dense<0> : tensor - %3 = mhlo.constant dense<0> : tensor - %4 = "mhlo.dynamic_slice"(%arg0, %0, %1, %2, %3) {slice_sizes = dense<[1, 128, 224, 224]> : tensor<4xi64>} : (tensor<16x128x224x224xf32>, tensor, tensor, tensor, tensor) -> tensor<1x128x224x224xf32> - %5 = mhlo.reshape %4 : (tensor<1x128x224x224xf32>) -> tensor<128x224x224xf32> - return %5 : tensor<128x224x224xf32> - } -} - -case2: idx is a slice with step is 1 - x[1:10:1] -module @jit_index { - func.func public @main(%arg0: tensor<16x128x224x224xf32> {mhlo.sharding = ""}) -> tensor<9x128x224x224xf32> { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<0> : tensor - %2 = mhlo.constant dense<0> : tensor - %3 = mhlo.constant dense<0> : tensor - %4 = "mhlo.dynamic_slice"(%arg0, %0, %1, %2, %3) {slice_sizes = dense<[9, 128, 224, 224]> : tensor<4xi64>} : (tensor<16x128x224x224xf32>, tensor, tensor, tensor, tensor) -> tensor<9x128x224x224xf32> - return %4 : tensor<9x128x224x224xf32> - } -} - -case3: idx is a slice with step is not 1 - x[1:10:2] -module @jit_index { - func.func public @main(%arg0: tensor<16x128x224x224xf32> {mhlo.sharding = ""}) -> tensor<5x128x224x224xf32> { - %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[10, 128, 224, 224]> : tensor<4xi64>, start_indices = dense<[1, 0, 0, 0]> : tensor<4xi64>, strides = dense<[2, 1, 1, 1]> : tensor<4xi64>} : (tensor<16x128x224x224xf32>) -> tensor<5x128x224x224xf32> - return %0 : tensor<5x128x224x224xf32> - } -} - -case4: do index in multi dims - x[1, 10:20, 20:30, 3] -module @jit_index { - func.func public @main(%arg0: tensor<16x128x224x224xf32> {mhlo.sharding = ""}) -> tensor<10x10xf32> { - %0 = mhlo.constant dense<1> : tensor - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1xi32> - %2 = mhlo.constant dense<10> : tensor - %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1xi32> - %4 = mhlo.constant dense<20> : tensor - %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1xi32> - %6 = mhlo.constant dense<3> : tensor - %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1xi32> - %8 = "mhlo.concatenate"(%1, %3, %5, %7) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> - %9 = "mhlo.gather"(%arg0, %8) {dimension_numbers = #mhlo.gather, indices_are_sorted = true, slice_sizes = dense<[1, 10, 10, 1]> : tensor<4xi64>} : (tensor<16x128x224x224xf32>, tensor<4xi32>) -> tensor<10x10xf32> - return %9 : tensor<10x10xf32> - } -} - -case5: x[:, 1] -module @jit_index { - func.func public @main(%arg0: tensor<16x128x224x224xf32> {mhlo.sharding = ""}) -> tensor<16x224x224xf32> { - %0 = mhlo.constant dense<1> : tensor - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1xi32> - %2 = "mhlo.gather"(%arg0, %1) {dimension_numbers = #mhlo.gather, indices_are_sorted = true, slice_sizes = dense<[16, 1, 224, 224]> : tensor<4xi64>} : 
(tensor<16x128x224x224xf32>, tensor<1xi32>) -> tensor<16x224x224xf32> - return %2 : tensor<16x224x224xf32> - } -} - -x[:] -return x directly -""" - - -def _is_canonicalized_axis(sl: slice, axis_len: int): - return ( - (0 <= sl.start and sl.start < axis_len) - and (0 <= sl.stop and sl.stop <= axis_len) - and (0 < sl.step) - ) - - -def _canonicalize_slice_with_axis_len(sl: slice, axis_len: int): - """ - make slice canonicalized: 0 <= sl.start < axis_len and 0 <= sl.stop <= axis_len - """ - - def impl(idx, axis_len): - if idx < 0: - idx = idx + axis_len - assert idx >= 0 and idx <= axis_len, f"{idx}, {axis_len}" - if idx < 0: - idx = 0 - if idx > axis_len: - idx = axis_len - return idx - - assert isinstance(sl, slice) - start = impl(sl.start, axis_len) - stop = impl(sl.stop, axis_len) - - new_sl = slice(start, stop, sl.step) - - assert new_sl.step > 0, "step <= 0 is not supported now" - assert _is_canonicalized_axis( - new_sl, axis_len - ), f"slice {new_sl} is illegal for axis whose length is {axis_len}" - return new_sl - - -def _hslice_with_step_is_one(inp, slices): - """ - if inp_shape is N-dim, slices should contain N slice, slice can not None. - for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)]. - the step of slice should must be 1 - """ - assert all([sl.step == 1 for sl in slices]) - starts = [int(sl.start) for sl in slices] - slice_sizes = [int(max(0, sl.stop - sl.start)) for sl in slices] - - starts = [ir_utils.ir_constant(si) for si in starts] - slice_sizes = ir_utils.dense_int_elements(slice_sizes) - - return hlo.DynamicSliceOp(inp, starts, slice_sizes).results - - -def _hslice_with_any_step(inp, slices): - """ - if inp_shape is N-dim, slices should contain N slice, slice can not None - for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] - """ - starts = [int(sl.start) for sl in slices] - stops = [int(sl.stop) for sl in slices] - steps = [int(sl.step) for sl in slices] - - return hlo.SliceOp( - inp, - ir_utils.dense_int_elements(starts), - ir_utils.dense_int_elements(stops), - ir_utils.dense_int_elements(steps), - ).results - - -def index_with_slices(inp, slices): - """ - if inp_shape is N-dim, slices should contain N slice, slice can be None - for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None] - """ - assert isinstance(slices, Sequence), f"{slices}" - assert len(inp.shape) >= len(slices), f"{inp.shape}, {slices}" - slices = list(slices) + [None,] * (len(inp.shape) - len(slices)) - - slices = [ - sl if sl is not None else slice(0, axis_len, 1) - for (sl, axis_len) in zip(slices, inp.shape) - ] - slices = [ - _canonicalize_slice_with_axis_len(sl, axis_len) - for (sl, axis_len) in zip(slices, inp.shape) - ] - - all_step_is_one = all(sl.step == 1 for sl in slices) - if all_step_is_one: - return HLOTensor(_hslice_with_step_is_one(inp.tensor, slices)) - else: - return HLOTensor(_hslice_with_any_step(inp.tensor, slices)) - - -class IndexType(IntEnum): - DEFAULT = (0,) - INT = (1,) - SLICE = (2,) - NONE = (3,) - ELLIPSIS = (4,) - - -def _parse_subtensor_items_as_slices(srcitems, inp_shape, idx_vars): - inp_ndim = len(inp_shape) - - Item = namedtuple("Item", "axis axis_len slice_or_idx") - items = [] - inp_offset = 0 - for item in srcitems: - # items for: axis, start, step, end, is_index - axis, has_start, has_stop, has_step, is_idx = item - axis_len = inp_shape[axis] - is_slice = has_start or has_stop or has_step - assert is_slice ^ is_idx, f"cannot specify index idx and slice simultaneously" - - if is_slice: - start, 
stop, step = 0, axis_len, 1 - if has_start: - start = _parse_var_as_value(idx_vars[inp_offset]) - inp_offset += 1 - if has_stop: - stop = _parse_var_as_value(idx_vars[inp_offset]) - inp_offset += 1 - if has_step: - step = _parse_var_as_value(idx_vars[inp_offset]) - inp_offset += 1 - sl = _canonicalize_slice_with_axis_len(slice(start, stop, step), axis_len) - items.append(Item(axis, axis_len, sl)) - elif is_idx: - idx = _parse_var_as_value(idx_vars[inp_offset]) - inp_offset += 1 - if idx < 0: - idx = idx + axis_len - assert ( - 0 <= idx and idx < axis_len - ), f"idx {idx} out of range, shape {inp_shape}, axis {axis}" - items.append(Item(axis, axis_len, idx)) - else: - assert False - - slices = [None,] * inp_ndim - indices_type = [IndexType.DEFAULT,] * inp_ndim - for item in items: - # if item.slice_or_idx is int, that means it is a index, not a slice, so we need to reshape the result - if isinstance(item.slice_or_idx, int): - slices[item.axis] = slice(item.slice_or_idx, item.slice_or_idx + 1, 1) - indices_type[item.axis] = IndexType.INT - else: - slices[item.axis] = item.slice_or_idx - indices_type[item.axis] = IndexType.SLICE - return ( - slices, - indices_type, - any([isinstance(item.slice_or_idx, int) for item in items]), - ) - - -@register_lower_rule(mops.Subtensor) -def subtensor_lower( - ctx, *args: Union[ir.Value, Sequence[ir.Value]], explicit_type=False -): - assert len(ctx.op.slice_items) == 0 and len(ctx.vars_out) == 1 - opr, inp, inp_shape = ctx.op, args[0], ctx.vars_in[0].shape - slices, _, any_axis_is_index = _parse_subtensor_items_as_slices( - opr.items, inp_shape, ctx.vars_in[1:] - ) - oup = index_with_slices(inp, slices) - - if any_axis_is_index: - return oup.reshape(ctx.vars_out[0].shape) - else: - return oup - - -""" -# x.shape = (32, 16, 8), y.shape = (8,): x[10, 13] = y -_Indexer( - slice_shape=[8], - gather_slice_shape=[1, 1, 8], - gather_indices=Array([10, 13], dtype=int32), - dnums=GatherDimensionNumbers( - offset_dims=(0,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1) - ), - unique_indices=True, - indices_are_sorted=True, - reversed_y_dims=[], - newaxis_dims=(), -) - -# x.shape = (32, 16, 8), y.shape = (8,): x[10, slice(0, 3, 2)] = y -_Indexer( - slice_shape=[2, 8], - gather_slice_shape=[1, 1, 8], - gather_indices=Array([[10, 0], [10, 2]], dtype=int32), - dnums=GatherDimensionNumbers( - offset_dims=(1,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1) - ), - unique_indices=True, - indices_are_sorted=True, - reversed_y_dims=[], - newaxis_dims=(), -) - -# x.shape = (32, 16, 8), y.shape = (2, 8,): x[10, slice(0, 3, 2)] = y -_Indexer( - slice_shape=[2, 8], - gather_slice_shape=[1, 1, 8], - gather_indices=Array([[10, 0], [10, 2]], dtype=int32), - dnums=GatherDimensionNumbers( - offset_dims=(1,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1) - ), - unique_indices=True, - indices_are_sorted=True, - reversed_y_dims=[], - newaxis_dims=(), -) - - -# x.shape = (32, 16, 8), y.shape = (1,): x[10, slice(0, 3, 2)] = y -_Indexer( - slice_shape=[2, 8], - gather_slice_shape=[1, 1, 8], - gather_indices=Array([[10, 0], [10, 2]], dtype=int32), - dnums=GatherDimensionNumbers( - offset_dims=(1,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1) - ), - unique_indices=True, - indices_are_sorted=True, - reversed_y_dims=[], - newaxis_dims=(), -) - -# x.shape = (32, 16, 8), y.shape = (32, 2, 8): x[:, slice(0, 3, 2)] = y -_Indexer( - slice_shape=[32, 2, 8], - gather_slice_shape=[32, 1, 8], - gather_indices=Array([[0], [2]], dtype=int32), - dnums=GatherDimensionNumbers( - 
offset_dims=(0, 2), collapsed_slice_dims=(1,), start_index_map=(1,) - ), - unique_indices=True, - indices_are_sorted=True, - reversed_y_dims=[], - newaxis_dims=(), -) -""" - - -class GatherDimensionNumbers(NamedTuple): - offset_dims: Tuple[int, ...] - collapsed_slice_dims: Tuple[int, ...] - start_index_map: Tuple[int, ...] - - -class _Indexer(NamedTuple): - # The expected shape of the slice output. - slice_shape: Sequence[int] - # The slice shape to pass to lax.gather(). - gather_slice_shape: Sequence[int] - # The gather indices to use. - gather_indices: Any - # A GatherDimensionNumbers object describing the gather to perform. - dnums: GatherDimensionNumbers - - # Are the gather_indices known to be non-overlapping and/or sorted? - # (In practice, these translate to "there no advanced indices", because - # only advanced indices could lead to index repetition.) - unique_indices: bool - indices_are_sorted: bool - - # Slice dimensions that have negative strides, and so must be reversed after - # the gather. - reversed_y_dims: Sequence[int] - - # Keep track of any axes created by `newaxis`. These must be inserted for - # gathers and eliminated for scatters. - newaxis_dims: Sequence[int] - - -class ScatterDimensionNumbers(NamedTuple): - update_window_dims: Sequence[int] - inserted_window_dims: Sequence[int] - scatter_dims_to_operand_dims: Sequence[int] - - -def _static_idx(idx: slice, size): - if isinstance(size, int): - start, stop, step = idx.indices(size) - else: - raise TypeError(size) - - if (step < 0 and stop >= start) or (step > 0 and start >= stop): - return 0, 0, 1, False # sliced to size zero - - if step > 0: - return start, stop, step, False - else: - k = (start - stop - 1) % (-step) - return stop + k + 1, start + 1, -step, True - - -def _index_to_gather( - x_shape, indices, indices_type, normalize_indices: bool = True -) -> _Indexer: - assert len(indices) == len(indices_type), f"{len(indices)}, {len(indices_type)}" - assert len(indices) == len(x_shape), f"{len(indices)}, {len(x_shape)} " - - advanced_indexes: Optional[Sequence[Union[Sequence, np.ndarray]]] = None - x_axis = 0 # Current axis in x. - y_axis = 0 # Current axis in y, before collapsing. See below. - collapsed_y_axis = 0 # Current axis in y, after collapsing. - - # Scatter dimension numbers. - offset_dims: Sequence[int] = [] - collapsed_slice_dims: Sequence[int] = [] - start_index_map: Sequence[int] = [] - index_dtype = np.int32 - - # Gather indices. - # Pairs of (array, start_dim) values. These will be broadcast into - # gather_indices_shape, with the array dimensions aligned to start_dim, and - # then concatenated. - gather_indices: List[Tuple[Sequence, int]] = [] - gather_indices_shape: List[int] = [] - - # We perform three transformations to y before the scatter op, in order: - # First, y is broadcast to slice_shape. In general `y` only need broadcast to - # the right shape. - slice_shape: Sequence[int] = [] - - # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None` - # indices, which the scatter cannot remove itself. - newaxis_dims: Sequence[int] = [] - - # Finally, we reverse reversed_y_dims to handle slices with negative strides. - reversed_y_dims: Sequence[int] = [] - - gather_slice_shape: Sequence[int] = [] - - for i, (idx, idx_type) in enumerate(zip(indices, indices_type)): - if idx is None: - assert idx_type == IndexType.DEFAULT - indices_type[i] = IndexType.SLICE - indices[i] = slice(None, None, None) - - for idx, idx_type in zip(indices, indices_type): - # Handle basic int indexes. 
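        # A basic integer index i contributes a scalar start index for the
        # current x axis: the gathered slice has size 1 on that axis and the
        # axis is collapsed away (collapsed_slice_dims), so it adds no axis to
        # the output.  E.g. for x.shape = (32, 16, 8), indexing x[10, 13]
        # yields gather_indices [10, 13], gather_slice_shape [1, 1, 8] and
        # slice_shape [8], matching the examples above.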
- if idx_type == IndexType.INT: - gather_indices.append( - (np.array(idx.start, index_dtype), len(gather_indices_shape)) - ) - collapsed_slice_dims.append(x_axis) - gather_slice_shape.append(1) - start_index_map.append(x_axis) - x_axis += 1 - # # Handle np.newaxis (None) - # elif idx_type == IndexType.NONE: - # slice_shape.append(1) - # newaxis_dims.append(y_axis) - # y_axis += 1 - elif idx_type == IndexType.SLICE: - # Normalize the slice to use None when possible - start, stop, step = idx.start, idx.stop, idx.step - # Handle slice(None) and slice(None, None, -1) - if ( - start is None - and stop is None - and (step is None or isinstance(step, int) and step == -1) - ): - if step == -1: - reversed_y_dims.append(collapsed_y_axis) - slice_shape.append(x_shape[x_axis]) - gather_slice_shape.append(x_shape[x_axis]) - offset_dims.append(collapsed_y_axis) - collapsed_y_axis += 1 - y_axis += 1 - x_axis += 1 - # Handle slice index (only static, otherwise an error is raised) - else: - start, limit, stride, needs_rev = _static_idx( - slice(start, stop, step), x_shape[x_axis] - ) - if needs_rev: - reversed_y_dims.append(collapsed_y_axis) - if stride == 1: - idx = np.array(start, index_dtype) - gather_indices.append((idx, len(gather_indices_shape))) - slice_shape.append(limit - start) - gather_slice_shape.append(limit - start) - offset_dims.append(collapsed_y_axis) - start_index_map.append(x_axis) - else: - idx = np.arange(start, limit, stride, dtype=index_dtype) - size = idx.shape[0] - slice_shape.append(size) - gather_slice_shape.append(1) - gather_indices.append((idx, len(gather_indices_shape))) - gather_indices_shape.append(size) - - start_index_map.append(x_axis) - collapsed_slice_dims.append(x_axis) - - collapsed_y_axis += 1 - y_axis += 1 - x_axis += 1 - else: - msg = "Indexing mode not yet supported. 
Open a feature request!\n{}" - raise IndexError(msg.format(indices)) - - if len(gather_indices) == 0: - gather_indices_array = np.zeros((0,), dtype=index_dtype) - elif len(gather_indices) == 1: - g, _ = gather_indices[0] - gather_indices_array = np.expand_dims(g, (g.ndim,)) - else: - last_dim = len(gather_indices_shape) - gather_indices_shape.append(1) - - def _broadcast_to(src, tgt_shape, axises): - src_shape = src.shape - expanded_src_shape = [1,] * len(tgt_shape) - for i, ax in enumerate(axises): - expanded_src_shape[ax] = src_shape[i] - src = np.reshape(src, expanded_src_shape) - return np.broadcast_to(src, tgt_shape) - - gather_indices_array = np.concatenate( - [ - _broadcast_to(g, gather_indices_shape, tuple(range(i, i + g.ndim))) - for g, i in gather_indices - ], - last_dim, - ) - - dnums = GatherDimensionNumbers( - offset_dims=tuple(offset_dims), - collapsed_slice_dims=tuple(sorted(collapsed_slice_dims)), - start_index_map=tuple(start_index_map), - ) - return _Indexer( - slice_shape=slice_shape, - newaxis_dims=tuple(newaxis_dims), - gather_slice_shape=gather_slice_shape, - reversed_y_dims=reversed_y_dims, - dnums=dnums, - gather_indices=gather_indices_array, - unique_indices=advanced_indexes is None, - indices_are_sorted=advanced_indexes is None, - ) - - -def scatter( - x, - indices, - y, - dnums, - oup_var=None, - indices_are_sorted=False, - unique_indices=False, - mode=None, -): - scatter_dnums = hlo.ScatterDimensionNumbers.get( - update_window_dims=list(dnums.update_window_dims), - inserted_window_dims=list(dnums.inserted_window_dims), - scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(indices.shape) - 1, - ) - if oup_var is not None: - oshape, odtype = oup_var.shape, oup_var.dtype - else: - oshape, odtype = x.shape, x.dtype - - op = hlo.ScatterOp( - ir_utils.make_ir_type_according_meta_tuple(oshape, odtype), - [x.tensor], - ir_utils.ir_constant(indices), - [y.tensor], - scatter_dnums, - indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), - unique_indices=ir.BoolAttr.get(unique_indices), - ) - - scalar_type = ir_utils.make_ir_type_according_meta(tuple(), odtype) - update = op.update_computation.blocks.append(scalar_type, scalar_type) - - with ir.InsertionPoint(update): - hlo.ReturnOp((update.arguments[1],)) - return HLOTensor(op.results) - - -@register_lower_rule(mops.SetSubtensor) -def setsubtensor_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(ctx.vars_out) == 1 - opr, x, y = ctx.op, args[0], args[1] - - slices, indices_type, _ = _parse_subtensor_items_as_slices( - opr.items, x.shape, ctx.vars_in[2:] - ) - - indexer = _index_to_gather(x.shape, slices, indices_type) - if len(indexer.slice_shape) == 0 or np.prod(indexer.slice_shape) == 0: - return [x] - - y = y.broadcast_to(indexer.slice_shape) - if len(indexer.newaxis_dims) != 0: - assert False, "not support" - if len(indexer.reversed_y_dims) != 0: - assert False, "not support" - - dnums = ScatterDimensionNumbers( - update_window_dims=indexer.dnums.offset_dims, - inserted_window_dims=indexer.dnums.collapsed_slice_dims, - scatter_dims_to_operand_dims=indexer.dnums.start_index_map, - ) - - out = scatter( - x, - indexer.gather_indices, - y, - dnums, - indices_are_sorted=indexer.indices_are_sorted, - unique_indices=indexer.unique_indices, - mode=None, - ) - return out - - -def _check_tensor_indexing_arg(src, index, axis): - assert src.ndim - 1 == index.ndim, f"{src.shape} {index.shape}" - assert axis < src.ndim and 0 <= axis, f"invalid axis {axis} for shape 
{src.shape}" - - src_shape = list(src.shape) - del src_shape[axis] - assert src_shape == list(index.shape), f"{src.shape} {index.shape} {axis}" - - assert str(index.dtype) in [ - "int16", - "int32", - "int64", - "uint16", - "uint32", - "uint64", - ], f"{index.dtype}" - - -def indexing_with_tensor_index(src, index, axis=-1, keepdims=False): - """ - indexing select items from src according to index in one dimension. - src.ndim should equal to index.ndim + 1. - if the src.shape remove the axis-th element, it should equal to index.shape. - for example: - src.shape=(2, 4, 6), index.shape=(2, 4), axis=2, out.shape=(2, 4, 1), out[i, j, 1] = src[i, j, index[i, j]] - src.shape=(2, 4, 6), index.shape=(2, 6), axis=1, out.shape=(2, 1, 6), out[i, 1, j] = src[i, index[i, j], j] - src.shape=(3, 9), index.shape=(3,), axis=1, out.shape=(3, 1), out[i, 1] = src[i, index[i]] - src.shape=(3, 9), index.shape=(9,), axis=0, out.shape=(1, 9), out[1, i] = src[index[i], i] - """ - axis = (axis + src.ndim) if axis < 0 else axis - _check_tensor_indexing_arg(src, index, axis) - - arange_array = np.arange(src.shape[axis], dtype=index.dtype) - arange_array = HLOTensor(arange_array).broadcast_to( - src.shape, broadcast_dims=[axis] - ) - broadcast_dims = [i for i in range(src.ndim) if i != axis] - index_array = index.broadcast_to(src.shape, broadcast_dims=broadcast_dims) - - mask = (arange_array == index_array).astype(src.dtype) - return (src * mask).sum(axis, keepdims=keepdims) - - -@register_lower_rule(mops.IndexingOneHot) -def indexing_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert ( - len(ctx.vars_out) == 1 and len(ctx.vars_in) == 2 and len(args) == 2 - ), f"{len(ctx.vars_out)}, {len(ctx.vars_in)}, {len(args)}" - assert ctx.op.ndim == args[0].ndim, f"{ctx.op.ndim}, {args[0].shape}" - return indexing_with_tensor_index(args[0], args[1], ctx.op.axis, keepdims=True) - - -def indexing_set_with_tensor_index(src, value, index, axis): - """ - indexing set value to src according to index in one dimension. - value shape should can be broadcast or reshape to index shape. 
- if value shape not equal to index shape, it should be broadcast to index shape firstly - examples: - src.shape=(2, 4, 6), value.shape=(2, 4), index.shape=(2, 4), axis=2, out.shape=(2, 4, 6) - out[i, j, k] = value[i, j] if k == index[i, j] else src[i, j, k] - - src.shape=(2, 4, 6), value.shape=(2, 6), index.shape=(2, 6), axis=1, out.shape=(2, 4, 6) - out[i, j, k] = value[i, k] if j == index[i, k] else src[i, j, k] - """ - axis = (axis + src.ndim) if axis < 0 else axis - _check_tensor_indexing_arg(src, index, axis) - - value = value if isinstance(value, HLOTensor) else HLOTensor(value) - assert src.dtype == value.dtype, f"{src.dtype}, {value.dtype}" - - arange_array = np.arange(src.shape[axis]).astype(index.dtype) - arange_array = HLOTensor(arange_array).broadcast_to(src.shape, [axis]) - broadcast_dims = [i for i in range(src.ndim) if i != axis] - index_array = index.broadcast_to(src.shape, broadcast_dims=broadcast_dims) - - mask1 = (arange_array == index_array).astype(src.dtype) - mask2 = (arange_array != index_array).astype(value.dtype) - - return mask1 * value + mask2 * src - - -@register_lower_rule(mops.IndexingSetOneHot) -def indexing_set_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert ( - len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3 and len(args) == 3 - ), f"{len(ctx.vars_out)}, {len(ctx.vars_in)}, {len(args)}" - - assert ctx.op.ndim == args[0].ndim, f"{ctx.op.ndim}, {args[0].shape}" - return indexing_set_with_tensor_index(args[0], args[2], args[1], ctx.op.axis) diff --git a/imperative/python/megengine/xla/rules/math.py b/imperative/python/megengine/xla/rules/math.py deleted file mode 100644 index 818498da8..000000000 --- a/imperative/python/megengine/xla/rules/math.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import Sequence, Union - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .utils import _can_broadcast_to, _shape_equal, register_lower_rule - - -@register_lower_rule(mops.Dot) -def dot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert ( - len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 and len(args) == 2 - ), f"{len(ctx.vars_in)}, {len(ctx.vars_out)}, {len(args)}" - assert args[0].ndim == 1 and args[1].ndim == 1, f"{args[0].shape}, {args[1].shape}" - assert args[0].shape[0] == args[1].shape[0], f"{args[0].shape}, {args[1].shape}" - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=tuple(), - rhs_batching_dimensions=tuple(), - lhs_contracting_dimensions=(0,), - rhs_contracting_dimensions=(0,), - ) - - return [ - HLOTensor( - hlo.DotGeneralOp( - ir_utils.make_ir_type_according_meta((), ctx.vars_out[0].dtype), - args[0].tensor, - args[1].tensor, - dot_dnums, - precision_config=ir_utils.precision_attr(args[0].dtype, args[1].dtype), - ).result - ).reshape(ctx.vars_out[0].shape) - ] - - -@register_lower_rule(mops.MatrixMul) -def matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 and len(args) == 2 - assert ( - ctx.op.compute_mode == mops.BatchedMatrixMul.ComputeMode.DEFAULT - ), f"{ctx.op.compute_mode}" - assert ctx.op.format == mops.BatchedMatrixMul.Format.DEFAULT, f"{ctx.op.format}" - assert ctx.op.dimA == len(args[0].shape) and ctx.op.dimB == len( - args[1].shape - ), f"{ctx.op.dimA}, {ctx.op.dimB}, {args[0].shape}, {args[1].shape}" - assert args[0].ndim >= 2 and args[1].ndim >= 2, f"{args[0].shape}, {args[1].shape}" - lhs, rhs = args[0], args[1] - - # in mge batchmatmul, [a, b, c, d] * [a, b, c, f] -> [a, b, f, d] - # but in mge matmul, dims [:-1] is interpreted as one edge of matrix - # that means [a, b, c, d] * [a, b, c, f] -> [a*b*c, d] * [a*b*c, f] -> [f, d] - if lhs.ndim > 2 and rhs.ndim > 2: - lhs = lhs.reshape(shape=(-1, lhs.shape[-1])) - rhs = rhs.reshape(shape=(-1, rhs.shape[-1])) - - lhs_reduce_axis = lhs.ndim - 2 if ctx.op.transposeA else lhs.ndim - 1 - rhs_reduce_axis = rhs.ndim - 1 if ctx.op.transposeB else rhs.ndim - 2 - assert ( - lhs.shape[lhs_reduce_axis] == rhs.shape[rhs_reduce_axis] - ), f"reduce axis length mismatch: {lhs.shape}, {rhs.shape}, {lhs_reduce_axis}, {rhs_reduce_axis}" - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=tuple(), - rhs_batching_dimensions=tuple(), - lhs_contracting_dimensions=(lhs_reduce_axis,), - rhs_contracting_dimensions=(rhs_reduce_axis,), - ) - - return [ - HLOTensor( - hlo.DotGeneralOp( - ir_utils.mge_varinfo_to_ir_type(ctx.vars_out[0]), - lhs.tensor, - rhs.tensor, - dot_dnums, - precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype), - ).result - ) - ] - - -def _bmm_shape_helper(lhs_shape, rhs_shape, lhs_transpose, rhs_transpose): - lhs_reduce_axis = len(lhs_shape) - 2 if lhs_transpose else len(lhs_shape) - 1 - rhs_reduce_axis = len(rhs_shape) - 1 if rhs_transpose else len(rhs_shape) - 2 - - # get the shape of inputs after transpose - lhs_shape, rhs_shape = list(lhs_shape), list(rhs_shape) - if lhs_transpose: - lhs_shape[-2], lhs_shape[-1] = lhs_shape[-1], lhs_shape[-2] - if rhs_transpose: - rhs_shape[-2], rhs_shape[-1] = rhs_shape[-1], rhs_shape[-2] - # get the batch info of inputs - lhs_prefix, rhs_prefix = lhs_shape[:-2], rhs_shape[:-2] - - # only the batch of input_a can broadcast to input_b supported - assert _can_broadcast_to(lhs_prefix, rhs_prefix) or _can_broadcast_to( - 
rhs_prefix, lhs_prefix - ), f"{lhs_shape}, {rhs_shape}" - - # get the batch axis of input_a and input_b, for example: - # (3, 4, 5) * (3, 5, 6), the batch axis is (0,) and (0,) - # (3, 4, 5) * (2, 3, 5, 6), the batch axis is (0,) and (1,) - # (2, 3, 4, 5) * (2, 3, 5, 6), the batch axis is (0, 1) and (0, 1) - lhs_batch_axis, rhs_batch_axis = [], [] - min_len = min(len(lhs_shape), len(rhs_shape)) - for i in range(-3, -min_len - 1, -1): - if lhs_shape[i] == rhs_shape[i]: - lhs_batch_axis.append(i) - rhs_batch_axis.append(i) - - elif lhs_shape[i] == 1 or rhs_shape[i] == 1: - lhs_batch_axis.append(i) - rhs_batch_axis.append(i) - - else: - break - - lhs_batch_axis = [val + len(lhs_shape) for val in lhs_batch_axis] - rhs_batch_axis = [val + len(rhs_shape) for val in rhs_batch_axis] - lhs_batch_axis.sort() - rhs_batch_axis.sort() - - assert len(lhs_batch_axis) == len(lhs_prefix) or len(rhs_batch_axis) == len( - rhs_prefix - ), f"{lhs_batch_axis}, {rhs_batch_axis}, {lhs_prefix}, {rhs_prefix}, {lhs_shape}, {rhs_shape}" - - # for case [m, ... , n, a, b] * [i, ..., j, m, ..., n, b, c] - if _can_broadcast_to(lhs_prefix, rhs_prefix): - # [m, ..., n] - batched_part = [rhs_prefix[ax] for ax in rhs_batch_axis] - # [i, ..., j] - nonbatched_part = rhs_prefix[0 : len(rhs_prefix) - len(rhs_batch_axis)] - - # in xla, [m, ... , n, a, b] * [i, ..., j, m, ..., n, b, c] -> [m, ..., n, a, i, ..., j, c] - # in mge, [m, ... , n, a, b] * [i, ..., j, m, ..., n, b, c] -> [i, ..., j, m, ..., n, a, c] - # so we need permute - xla_oshape = [*batched_part, lhs_shape[-2], *nonbatched_part, rhs_shape[-1]] - nonbatched_perm = [ - idx + 1 + len(batched_part) for idx in range(len(nonbatched_part)) - ] - batched_perm = [idx for idx in range(len(batched_part))] - permutation = [ - *nonbatched_perm, - *batched_perm, - len(batched_part), - len(xla_oshape) - 1, - ] - # for case [i, ..., j, m, ..., n, a, b] * [m, ..., n, b, c] - else: - # [m, ..., n] - batched_part = [lhs_prefix[ax] for ax in lhs_batch_axis] - # [i, ..., j] - nonbatched_part = lhs_prefix[0 : len(lhs_prefix) - len(lhs_batch_axis)] - - # in xla, [i, ..., j, m, ... , n, a, b] * [m, ..., n, b, c] -> [m, ..., n, i, ..., j, a, c] - # in mge, [i, ..., j, m, ... 
, n, a, b] * [m, ..., n, b, c] -> [i, ..., j, m, ..., n, a, c] - # so we need permute - xla_oshape = [*batched_part, *nonbatched_part, lhs_shape[-2], rhs_shape[-1]] - nonbatched_perm = [ - idx + len(batched_part) for idx in range(len(nonbatched_part)) - ] - batched_perm = [idx for idx in range(len(batched_part))] - permutation = [ - *nonbatched_perm, - *batched_perm, - len(xla_oshape) - 2, - len(xla_oshape) - 1, - ] - - return ( - lhs_reduce_axis, - rhs_reduce_axis, - lhs_batch_axis, - rhs_batch_axis, - xla_oshape, - permutation, - ) - - -@register_lower_rule(mops.BatchedMatrixMul) -def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 and len(args) == 2 - assert ( - ctx.op.compute_mode == mops.BatchedMatrixMul.ComputeMode.DEFAULT - ), f"{ctx.op.compute_mode}" - assert ctx.op.format == mops.BatchedMatrixMul.Format.DEFAULT, f"{ctx.op.format}" - assert ctx.op.dimA == len(args[0].shape) and ctx.op.dimB == len( - args[1].shape - ), f"{ctx.op.dimA}, {ctx.op.dimB}, {args[0].shape}, {args[1].shape}" - assert args[0].ndim >= 2 and args[1].ndim >= 2, f"{args[0].shape}, {args[1].shape}" - lhs, rhs = args[0], args[1] - - ( - lhs_reduce_axis, - rhs_reduce_axis, - lhs_batch_axis, - rhs_batch_axis, - xla_oshape, - permutation, - ) = _bmm_shape_helper(lhs.shape, rhs.shape, ctx.op.transposeA, ctx.op.transposeB) - - # in xla, [3, 4, 5, 6] * [3, 1, 6, 7] is illegal, so we broadcast [3, 1, 6, 7] -> [3, 4, 6, 7] - if _can_broadcast_to(lhs.shape[:-2], rhs.shape[:-2]): - lshape = [ - rhs.shape[r] if lhs.shape[l] == 1 else lhs.shape[l] - for l, r in zip(lhs_batch_axis, rhs_batch_axis) - ] - lshape = [*lshape, *lhs.shape[-2:]] - if not _shape_equal(lshape, lhs.shape): - lhs = lhs.broadcast_to(lshape) - else: - assert _can_broadcast_to(rhs.shape[:-2], lhs.shape[:-2]) - rshape = [ - lhs.shape[l] if rhs.shape[r] == 1 else rhs.shape[r] - for l, r in zip(lhs_batch_axis, rhs_batch_axis) - ] - rshape = [*rshape, *rhs.shape[-2:]] - if not _shape_equal(rshape, rhs.shape): - rhs = rhs.broadcast_to(rshape) - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=list(lhs_batch_axis), - rhs_batching_dimensions=list(rhs_batch_axis), - lhs_contracting_dimensions=(lhs_reduce_axis,), # the reduce axis in lhs - rhs_contracting_dimensions=(rhs_reduce_axis,), # the reduce axis in rhs - ) - - return HLOTensor( - hlo.DotGeneralOp( - ir_utils.make_ir_type_according_meta(xla_oshape, ctx.vars_out[0].dtype), - lhs.tensor, - rhs.tensor, - dot_dnums, - precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype), - ).result - ).transpose(permutation) diff --git a/imperative/python/megengine/xla/rules/nn.py b/imperative/python/megengine/xla/rules/nn.py deleted file mode 100644 index 31dc4f68d..000000000 --- a/imperative/python/megengine/xla/rules/nn.py +++ /dev/null @@ -1,681 +0,0 @@ -from functools import partial -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir import ir -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .indexing import index_with_slices -from .reduction import _get_max_identity, _get_sum_identity -from .tensor import fill, pad, reshape -from .utils import register_lower_rule - - -@register_lower_rule(mops.Convolution) -def convolution_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert isinstance(ctx.op, mops.Convolution) - assert len(args) == 2, "convolution requires 2 arguments" - assert len(ctx.vars_in) == 2, "convolution requires 2 input variables" - assert len(ctx.vars_out) == 1, "convolution requires 1 output variable" - - opr = ctx.op - inp, weight = args[0], args[1] - - if opr.format == mops.AdaptivePooling.Format.NCHW: - inp_spec, weight_spec, out_spec = (0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3) - dnums = hlo.ConvDimensionNumbers.get( - input_batch_dimension=inp_spec[0], - input_feature_dimension=inp_spec[1], - input_spatial_dimensions=list(inp_spec[2:]), - kernel_output_feature_dimension=weight_spec[0], - kernel_input_feature_dimension=weight_spec[1], - kernel_spatial_dimensions=list(weight_spec[2:]), - output_batch_dimension=out_spec[0], - output_feature_dimension=out_spec[1], - output_spatial_dimensions=list(out_spec[2:]), - ) - ic = inp.shape[1] # NCHW - oc = weight.shape[0] # OIHW or O11HW for dwconv - else: - assert False, "only nchw supported" - - num_spatial_dims = len(weight_spec) - 2 - window_reversal = ir_utils.dense_bool_elements([False] * num_spatial_dims) - - if opr.sparse == mops.BatchConvBias.Sparse.DENSE: - feature_group_count, batch_group_count = 1, 1 - else: - assert ic == oc, "dwconv only support ic == oc" - assert len(weight.shape) == 5, "mge dpconv weight dim is 5" - feature_group_count, batch_group_count = ic, 1 - - if opr.format == mops.AdaptivePooling.Format.NCHW: - assert ( - weight.shape[1] == 1 and weight.shape[2] == 1 - ), f"weight shape error: {weight.shape}" - xla_weight_shape = [weight.shape[i] for i in [0, 2, 3, 4]] - weight = reshape(weight, xla_weight_shape) - - feature_group_count = ir_utils.i64_attr(feature_group_count) - batch_group_count = ir_utils.i64_attr(batch_group_count) - - window_strides = (opr.stride_h, opr.stride_w) - window_strides = ir_utils.dense_int_elements(window_strides) - - padding = ((opr.pad_h, opr.pad_h), (opr.pad_w, opr.pad_w)) - padding = ir_utils.dense_int_elements(padding) - - assert opr.dilate_h == 1 and opr.dilate_w == 1, "dilate_conv is not support now" - inp_dilation = (opr.dilate_h, opr.dilate_w) - weight_dilation = (opr.dilate_h, opr.dilate_w) - inp_dilation = ir_utils.dense_int_elements(inp_dilation) - weight_dilation = ir_utils.dense_int_elements(weight_dilation) - - window_reversal = ir_utils.dense_bool_elements([False] * num_spatial_dims) - precision = ir_utils.precision_attr(inp.dtype, weight.dtype) - - return HLOTensor( - hlo.ConvolutionOp( - ir_utils.mge_varinfo_to_ir_type(ctx.vars_out[0]), - inp.tensor, - weight.tensor, - dimension_numbers=dnums, - feature_group_count=feature_group_count, - batch_group_count=batch_group_count, - window_strides=window_strides, - padding=padding, - lhs_dilation=inp_dilation, - rhs_dilation=weight_dilation, - window_reversal=window_reversal, - precision_config=precision, - ).result, - ctx.vars_out[0].shape, - ctx.vars_out[0].dtype, - ) - - -def _dilate_shape(shape, dilation): - """Utility function for computing the shape resulting from a dilation.""" - if not np.all(np.greater(dilation, 0)): - msg = "All dilations must be positive, got {}." 
- raise TypeError(msg.format(dilation)) - dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation) - - def dilate_dim(d, dilation): - return 0 if d == 0 else 1 + dilation * (d - 1) - - return tuple(map(dilate_dim, shape, dilation)) - - -def _conv_general_vjp_lhs_padding( - in_shape, - window_dimensions, - window_strides, - out_shape, - padding, - lhs_dilation, - rhs_dilation, -): - lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation) - rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation) - out_dilated_shape = _dilate_shape(out_shape, window_strides) - pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1 - pad_after = ( - np.add(lhs_dilated_shape, rhs_dilated_shape) - - 1 - - out_dilated_shape - - pad_before - ) - return list(zip(pad_before, pad_after)) - - -def _conv_general_vjp_rhs_padding( - in_shape, - window_dimensions, - window_strides, - out_shape, - padding, - lhs_dilation, - rhs_dilation, -): - def diff_shape(s1, s2): - return tuple(map(lambda a, b: a - b, s1, s2)) - - if len(in_shape) == 0: # 0D conv - return [] - lhs_dilated_shape = _dilate_shape(in_shape, lhs_dilation) - rhs_dilated_shape = _dilate_shape(window_dimensions, rhs_dilation) - out_dilated_shape = _dilate_shape(out_shape, window_strides) - pads_lo = tuple(map(lambda p: p[0], padding)) - pads_from_lhs = diff_shape(out_dilated_shape, lhs_dilated_shape) - pads_from_rhs = diff_shape( - diff_shape(rhs_dilated_shape, pads_lo), (1,) * len(pads_lo) - ) - pads_hi = tuple(map(lambda *s: sum(s), pads_from_lhs, pads_from_rhs)) - return list(zip(pads_lo, pads_hi)) - - -@register_lower_rule("ConvolutionBackwardDataV2") -def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 3 and len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3 - assert ( - ctx.param["dilate_h"] == 1 and ctx.param["dilate_w"] == 1 - ), "dilate_conv is not support now" - - weight, dout, inp = args[0], args[1], args[2] - if ctx.param["format"] == mops.AdaptivePooling.Format.NCHW: - dnums = ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3)) - inp_spec, weight_spec, out_spec = dnums - inp_hw, weight_hw, out_hw = map(lambda s: s[2:], dnums) - inp_dilation = (1, 1) - weight_dilation = (ctx.param["dilate_h"], ctx.param["dilate_w"]) - window_strides = (ctx.param["stride_h"], ctx.param["stride_w"]) - ph, pw = ctx.param["pad_h"], ctx.param["pad_w"] - padding = ((ph, ph), (pw, pw)) - weight_shape = weight.shape - inp_shape = inp.shape - ic = inp.shape[1] # NCHW - oc = weight.shape[0] # OIHW or O11HW for dwconv - t_weight_spec = (weight_spec[1], weight_spec[0]) + weight_spec[2:] - dnums = hlo.ConvDimensionNumbers.get( - input_batch_dimension=out_spec[0], - input_feature_dimension=out_spec[1], - input_spatial_dimensions=list(out_spec[2:]), - kernel_output_feature_dimension=t_weight_spec[0], - kernel_input_feature_dimension=t_weight_spec[1], - kernel_spatial_dimensions=list(t_weight_spec[2:]), - output_batch_dimension=inp_spec[0], - output_feature_dimension=inp_spec[1], - output_spatial_dimensions=list(inp_spec[2:]), - ) - - if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE: - feature_group_count, batch_group_count = 1, 1 - else: - assert ic == oc, "only support dpwise conv currently" - assert len(weight.shape) == 5, "mge dpconv weight dim is 5" - feature_group_count, batch_group_count = ic, 1 - weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]] - weight = weight.reshape(weight_shape) - - padding = _conv_general_vjp_lhs_padding( - np.take(inp_shape, inp_hw), - 
np.take(weight_shape, weight_hw), - window_strides, - np.take(dout.shape, out_hw), - padding, - inp_dilation, - weight_dilation, - ) - - rev_filter = HLOTensor( - hlo.ReverseOp(weight.tensor, ir_utils.dense_int_elements(weight_hw)).result - ) - window_reversal = ir_utils.dense_bool_elements([False] * (len(weight_spec) - 2)) - precision = ir_utils.precision_attr(rev_filter.dtype, dout.dtype) - return HLOTensor( - hlo.ConvolutionOp( - ir_utils.mge_varinfo_to_ir_type(ctx.vars_out[0]), - dout.tensor, - rev_filter.tensor, - dimension_numbers=dnums, - feature_group_count=ir_utils.i64_attr(feature_group_count), - batch_group_count=ir_utils.i64_attr(batch_group_count), - window_strides=ir_utils.dense_int_elements(inp_dilation), - padding=ir_utils.dense_int_elements(padding), - lhs_dilation=ir_utils.dense_int_elements(window_strides), - rhs_dilation=ir_utils.dense_int_elements(weight_dilation), - window_reversal=window_reversal, - precision_config=precision, - ).result - ) - else: - assert False, "only nchw supported" - - -@register_lower_rule("ConvolutionBackwardFilterV2") -def conv_backward_filter_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert ( - ctx.param["dilate_h"] == 1 and ctx.param["dilate_w"] == 1 - ), "dilate_conv is not support now" - assert len(args) == 3 and len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3 - inp, dout, weight = args[0], args[1], args[2] - - if ctx.param["format"] == mops.AdaptivePooling.Format.NCHW: - dnums = ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3)) - _, weight_spec, _ = dnums - inp_hw, weight_hw, out_hw = map(lambda s: s[2:], dnums) - inp_trans, weight_trans, out_trans = map(lambda s: (s[1], s[0]) + s[2:], dnums) - inp_dilation = (1, 1) - weight_dilation = (ctx.param["dilate_h"], ctx.param["dilate_w"]) - window_strides = (ctx.param["stride_h"], ctx.param["stride_w"]) - ph, pw = ctx.param["pad_h"], ctx.param["pad_w"] - padding = ((ph, ph), (pw, pw)) - weight_shape = weight.shape - inp_shape = inp.shape - ic = inp.shape[1] # NCHW - oc = weight.shape[0] # OIHW or O11HW for dwconv - if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE: - feature_group_count, batch_group_count = 1, 1 - else: - assert ic == oc, "only support dpwise conv currently" - assert len(weight.shape) == 5, "mge dpconv weight dim is 5" - feature_group_count, batch_group_count = ic, 1 - weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]] - - if batch_group_count > 1: - feature_group_count = batch_group_count - batch_group_count = 1 - elif feature_group_count > 1: - batch_group_count = feature_group_count - feature_group_count = 1 - padding = _conv_general_vjp_rhs_padding( - np.take(inp_shape, inp_hw), - np.take(weight_shape, weight_hw), - window_strides, - np.take(dout.shape, out_hw), - padding, - inp_dilation, - weight_dilation, - ) - - dnums = hlo.ConvDimensionNumbers.get( - input_batch_dimension=inp_trans[0], - input_feature_dimension=inp_trans[1], - input_spatial_dimensions=list(inp_trans[2:]), - kernel_output_feature_dimension=out_trans[0], - kernel_input_feature_dimension=out_trans[1], - kernel_spatial_dimensions=list(out_trans[2:]), - output_batch_dimension=weight_trans[0], - output_feature_dimension=weight_trans[1], - output_spatial_dimensions=list(weight_trans[2:]), - ) - if batch_group_count > 1: - oup = ir.RankedTensorType.get( - [weight_shape[1], weight_shape[0]] + weight_shape[2:], - ir_utils.mge_dtype_to_ir_type(ctx.vars_out[0].dtype), - ) - else: - oup = ir_utils.mge_varinfo_to_ir_type(ctx.vars_out[0]) - window_reversal = ir_utils.dense_bool_elements([False] 
* (len(weight_spec) - 2)) - precision = ir_utils.precision_attr(inp.dtype, dout.dtype) - rst = HLOTensor( - hlo.ConvolutionOp( - oup, - inp.tensor, - dout.tensor, - dimension_numbers=dnums, - feature_group_count=ir_utils.i64_attr(feature_group_count), - batch_group_count=ir_utils.i64_attr(batch_group_count), - window_strides=ir_utils.dense_int_elements(weight_dilation), - padding=ir_utils.dense_int_elements(padding), - lhs_dilation=ir_utils.dense_int_elements(inp_dilation), - rhs_dilation=ir_utils.dense_int_elements(window_strides), - window_reversal=window_reversal, - precision_config=precision, - ).result - ) - if batch_group_count > 1: - rst = rst.reshape(ctx.vars_out[0].shape) - return rst - else: - assert False, "only nchw supported" - - -def _pooling( - reducer, - unit_factory, - inp, - stride, - kernel, - padding, - base_dilation=None, - kernel_dilation=None, - oshape=None, -): - """ - if pooling on H and W, - stride: len(stride) need to be equal to len(inp.shape) - for NCHW, should be (1, 1, stride_h, stride_w) - for NHWC, should be (1, stride_h, stride_w, 1) - kernel: similar to stride, len(kernel) also need to be equal to len(inp.shape) - padding: similar - for NCHW, should be ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)) or (0, 0, pad_h, pad_w) - for NHWC, should be ((0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)) or (0, pad_h, pad_w, 0) - """ - ishape, idtype = inp.shape, inp.dtype - assert oshape is not None, "pooling shape infer is not supported" - assert len(ishape) == len(oshape), f"shape error: {ishape} {oshape}" - - def check_param(param, info): - assert len(ishape) == len( - param - ), f"pooling: illegal {info} {param} for {ishape}" - - base_dilation = base_dilation if base_dilation is not None else (1, 1, 1, 1) - kernel_dilation = kernel_dilation if kernel_dilation is not None else (1, 1, 1, 1) - padding = [(p, p) if isinstance(p, int) else p for p in padding] - - check_param(stride, "stride") - check_param(kernel, "kernel") - check_param(padding, "padding") - check_param(base_dilation, "base_dilation") - check_param(kernel_dilation, "kernel_dilation") - - rw = hlo.ReduceWindowOp( - ir_utils.make_ir_type_according_meta_tuple(oshape, idtype), - [inp.tensor], - ir_utils.ir_constant_tuple(unit_factory(idtype)), - ir_utils.dense_int_elements(kernel), - window_strides=ir_utils.dense_int_elements(stride), - base_dilations=ir_utils.dense_int_elements(base_dilation), - window_dilations=ir_utils.dense_int_elements(kernel_dilation), - padding=ir.DenseIntElementsAttr.get( - np.asarray(padding, np.int64), shape=(len(padding), 2) - ), - ) - scalar_type = ir_utils.make_ir_type_according_meta(tuple(), idtype) - reducer_region = rw.regions[0].blocks.append(scalar_type, scalar_type) - with ir.InsertionPoint(reducer_region): - hlo.ReturnOp(reducer(*reducer_region.arguments)) - return HLOTensor(rw.result) - - -maxpooling = partial(_pooling, hlo.MaxOp, _get_max_identity) -sumpooling = partial(_pooling, hlo.AddOp, _get_sum_identity) - - -def avgpooling( - inp, - stride, - kernel, - padding, - count_include_pad, - base_dilation=None, - kernel_dilation=None, - oshape=None, -): - sum_pool = sumpooling( - inp, stride, kernel, padding, base_dilation, kernel_dilation, oshape=oshape - ) - if count_include_pad: - ret = sum_pool / float(np.prod(kernel)) - else: - # for inp[a,b,c,d], kernel[1,1,2,2], oshape[a,b,e,f] - # div_ishape=[1,1,c,d], div_oshape=[1,1,e,f] - div_ishape = [i if k != 1 else 1 for (k, i) in zip(kernel, inp.shape)] - div_oshape = [o if k != 1 else 1 for (k, o) in zip(kernel, 
oshape)] - divider = fill(1.0, div_ishape, inp.dtype) - divider = sumpooling(divider, stride, kernel, padding, oshape=div_oshape) - ret = sum_pool / divider - return ret - - -def _get_adaptive_pool_param(ishape, oshape, tensor_format): - assert len(ishape) == 4 and len(oshape) == 4, "only 2-d pooling supported" - if not isinstance(tensor_format, str): - tensor_format = str(tensor_format) - - ishape_hw, oshape_hw = None, None - if tensor_format in str(mops.AdaptivePooling.Format.NCHW): - ishape_hw, oshape_hw = ishape[2:4], oshape[2:4] - elif tensor_format in str(mops.AdaptivePooling.Format.NHWC): - ishape_hw, oshape_hw = ishape[1:3], oshape[1:3] - else: - assert False, f"adaptive pooling only nchw or nhwc, get {tensor_format}" - - stride_hw = [(isize // osize) for isize, osize in zip(ishape_hw, oshape_hw)] - kernel_hw = [ - (isize - (osize - 1) * stride) - for isize, osize, stride in zip(ishape_hw, oshape_hw, stride_hw) - ] - - stride, kernel = None, None - if tensor_format in str(mops.AdaptivePooling.Format.NCHW): - stride = (1, 1, *stride_hw) - kernel = (1, 1, *kernel_hw) - elif tensor_format in str(mops.AdaptivePooling.Format.NHWC): - stride = (1, *stride_hw, 1) - kernel = (1, *kernel_hw, 1) - else: - assert False, f"adaptive pooling only nchw or nhwc, get {tensor_format}" - padding = (0, 0, 0, 0) - - return kernel, stride, padding - - -def _select_and_scatter( - inp, source, init_value, kernel, stride, padding, selector, scatter -): - oshape, odtype = inp.shape, inp.dtype - scalar_type = ir_utils.make_ir_type_according_meta(tuple(), odtype) - op = hlo.SelectAndScatterOp( - ir_utils.make_ir_type_according_meta(oshape, odtype), - inp.tensor, - source.tensor, - HLOTensor(init_value).tensor, - window_dimensions=ir_utils.dense_int_elements(kernel), - window_strides=ir_utils.dense_int_elements(stride), - padding=ir.DenseIntElementsAttr.get( - np.asarray(padding, np.int64), shape=(len(padding), 2) - ), - ) - - select_block = op.select.blocks.append(scalar_type, scalar_type) - with ir.InsertionPoint(select_block): - blockargs = [HLOTensor(blockarg) for blockarg in select_block.arguments] - hlo.ReturnOp([selector(*blockargs).tensor]) - - scatter_block = op.scatter.blocks.append(scalar_type, scalar_type) - with ir.InsertionPoint(scatter_block): - blockargs = [HLOTensor(blockarg) for blockarg in scatter_block.arguments] - hlo.ReturnOp([scatter(*blockargs).tensor]) - - return HLOTensor(op.result) - - -def maxpooling_grad( - x, - dy, - kernel, - stride, - padding, - base_dilation=None, - kernel_dilation=None, - expand_padding=True, -): - assert base_dilation is None and kernel_dilation is None - assert expand_padding == True - padding = [(p, p) if isinstance(p, int) else p for p in padding] - dxdtype, dxshape = x.dtype, x.shape - assert dxdtype == "float32" or dxdtype == "float16" - - org_padding, new_padding = padding, padding - if expand_padding: - pads = [(lo, hi, 0) for (lo, hi) in padding] - padded_x = pad(x, _get_max_identity(dxdtype), pads) - new_padding = [(0, 0) for _ in padding] - - selector = lambda x, y: x >= y - scatter = lambda x, y: x + y - out = _select_and_scatter( - padded_x, dy, 0.0, kernel, stride, new_padding, selector, scatter - ) - - if expand_padding: - start_indices = [lo for (lo, hi) in org_padding] - stop_indices = [lo + d for ((lo, hi), d) in zip(org_padding, dxshape)] - slices = [ - slice(start, stop, 1) for start, stop in zip(start_indices, stop_indices) - ] - out = index_with_slices(out, slices) - - return out - - -def avgpooling_grad( - x, - dy, - kernel, - stride, - 
padding, - base_dilation=None, - kernel_dilation=None, - count_include_pad=True, -): - padding = [(p, p) if isinstance(p, int) else p for p in padding] - base_dilation = base_dilation if base_dilation is not None else (1, 1, 1, 1) - kernel_dilation = kernel_dilation if kernel_dilation is not None else (1, 1, 1, 1) - - if count_include_pad: - dy = dy / float(np.prod(kernel)) - else: - div_ishape = [i if k != 1 else 1 for (k, i) in zip(kernel, x.shape)] - div_oshape = [o if k != 1 else 1 for (k, o) in zip(kernel, dy.shape)] - divider = fill(1.0, div_ishape, dy.dtype) - divider = sumpooling(divider, stride, kernel, padding, oshape=div_oshape) - dy = dy / divider - - pads = _conv_general_vjp_lhs_padding( - x.shape, kernel, stride, dy.shape, padding, base_dilation, kernel_dilation - ) - padding_dy_config = [(lo, hi, st - 1) for (lo, hi), st in zip(pads, stride)] - padded_dy = pad(dy, _get_sum_identity(dy.dtype), padding_dy_config) - - ret = sumpooling( - padded_dy, - stride=base_dilation, - kernel=kernel, - padding=[(0, 0)] * len(x.shape), - base_dilation=(1, 1, 1, 1), - kernel_dilation=kernel_dilation, - oshape=x.shape, - ) - return ret - - -@register_lower_rule(mops.AdaptivePooling) -def adaptive_pooling_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(ctx.vars_in) == 2 and len(args) == 2 and len(ctx.vars_out) == 1 - assert ctx.op.shape == ctx.vars_in[1].bound_data.tolist() and len(ctx.op.shape) == 2 - - ishape, oshape = ctx.vars_in[0].shape, ctx.vars_out[0].shape - kernel, stride, padding = _get_adaptive_pool_param(ishape, oshape, ctx.op.format) - - if ctx.op.mode == mops.AdaptivePooling.Mode.AVERAGE: - return avgpooling( - args[0], stride, kernel, padding, count_include_pad=True, oshape=oshape - ) - elif ctx.op.mode == mops.AdaptivePooling.Mode.AVERAGE_COUNT_EXCLUDE_PADDING: - return avgpooling( - args[0], stride, kernel, padding, count_include_pad=False, oshape=oshape - ) - else: - assert ( - ctx.op.mode == mops.AdaptivePooling.Mode.MAX - ), f"unknown adaptive pooling mode {ctx.op.mode}" - return maxpooling(args[0], stride, kernel, padding, oshape=oshape) - - -@register_lower_rule("AdaptivePoolingBackwardV1") -def adaptive_pooling_grad_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - # for forward: y = adaptive_pool(x, tshape) - # for backward: dx = adaptive_pool_grad(x, tshape, y, dy) - assert len(args) == 4 and len(ctx.vars_in) == 4 and len(ctx.vars_out) == 1 - var_x, _, var_y, _ = ctx.vars_in - x, dy = args[0], args[3] - tensor_format, pool_mode = ctx.param["format"], ctx.param["mode"] - kernel, stride, padding = _get_adaptive_pool_param( - var_x.shape, var_y.shape, tensor_format - ) - - if pool_mode in str(mops.AdaptivePooling.Mode.AVERAGE): - return avgpooling_grad(x, dy, kernel, stride, padding, count_include_pad=True) - elif pool_mode in str(mops.AdaptivePooling.Mode.AVERAGE_COUNT_EXCLUDE_PADDING): - return avgpooling_grad(x, dy, kernel, stride, padding, count_include_pad=False) - else: - assert pool_mode in str( - mops.AdaptivePooling.Mode.MAX - ), f"unknown adaptive pooling mode {pool_mode}" - return maxpooling_grad(x, dy, kernel, stride, padding) - - -def _get_pool_param(kernel_hw, stride_hw, padding_hw, tensor_format): - assert len(kernel_hw) == 2 and len(stride_hw) == 2 and len(padding_hw) == 2 - # for backward, the tensor format is str - if not isinstance(tensor_format, str): - tensor_format = str(tensor_format) - - stride, kernel, padding = None, None, None - if tensor_format in str(mops.AdaptivePooling.Format.NCHW): - stride = (1, 1, 
*stride_hw) - kernel = (1, 1, *kernel_hw) - padding = (0, 0, *padding_hw) - elif tensor_format in str(mops.AdaptivePooling.Format.NHWC): - stride = (1, *stride_hw, 1) - kernel = (1, *kernel_hw, 1) - padding = (0, *padding_hw, 0) - else: - assert False, f"adaptive pooling only nchw or nhwc, get {tensor_format}" - - return kernel, stride, padding - - -@register_lower_rule(mops.Pooling) -def pooling_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1, f"pooling should have only 1 input, but give {len(args)}" - assert len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 - assert ( - args[0].ndim == 4 - ), f"pooling only support 4d tensor, but give {args[0].shape}" - opr = ctx.op - kernel, stride, padding = _get_pool_param( - (opr.window_h, opr.window_w), - (opr.stride_h, opr.stride_w), - (opr.pad_h, opr.pad_w), - opr.format, - ) - - oshape, odtype = ctx.vars_out[0].shape, ctx.vars_out[0].dtype - if opr.mode == mops.AdaptivePooling.Mode.AVERAGE: - return avgpooling( - args[0], stride, kernel, padding, count_include_pad=True, oshape=oshape - ) - elif opr.mode == mops.AdaptivePooling.Mode.AVERAGE_COUNT_EXCLUDE_PADDING: - return avgpooling( - args[0], stride, kernel, padding, count_include_pad=False, oshape=oshape - ) - else: - assert ( - opr.mode == mops.AdaptivePooling.Mode.MAX - ), f"unknown adaptive pooling mode {opr.mode}" - return maxpooling(args[0], stride, kernel, padding, oshape=oshape) - - -@register_lower_rule("PoolingBackwardV1") -def pooling_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - # for forward: y = pool(x) - # for backward: dx = pool_grad(x, y, dy) - assert len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 1 - tensor_format, pool_mode = ctx.param["format"], ctx.param["mode"] - - kernel, stride, padding = _get_pool_param( - (ctx.param["window_h"], ctx.param["window_w"]), - (ctx.param["stride_h"], ctx.param["stride_w"]), - (ctx.param["pad_h"], ctx.param["pad_w"]), - tensor_format, - ) - - x, dy = args[0], args[2] - if pool_mode in str(mops.AdaptivePooling.Mode.AVERAGE): - return avgpooling_grad(x, dy, kernel, stride, padding, count_include_pad=True) - elif pool_mode in str(mops.AdaptivePooling.Mode.AVERAGE_COUNT_EXCLUDE_PADDING): - return avgpooling_grad(x, dy, kernel, stride, padding, count_include_pad=False) - else: - assert pool_mode in str( - mops.AdaptivePooling.Mode.MAX - ), f"unknown adaptive pooling mode {pool_mode}" - return maxpooling_grad(x, dy, kernel, stride, padding) diff --git a/imperative/python/megengine/xla/rules/normalize.py b/imperative/python/megengine/xla/rules/normalize.py deleted file mode 100644 index 38a9c987a..000000000 --- a/imperative/python/megengine/xla/rules/normalize.py +++ /dev/null @@ -1,185 +0,0 @@ -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir.dialects import hlo -from .elemwise import sqrt -from .hlotensor import HLOTensor -from .utils import register_lower_rule - - -@register_lower_rule(mops.BatchNorm) -def batch_norm_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - if ctx.op.fwd_mode == mops.BatchNorm.FwdMode.TRAINING: - # training mode will return the new running mean and var, so return 6 args - assert ( - len(args) == 5 and len(ctx.vars_in) == 5 and len(ctx.vars_out) == 6 - ), f"len(args): {len(args)}, len(ctx.vars_in): {len(ctx.vars_in)}, len(ctx.vars_out): {len(ctx.vars_out)}" - else: - assert ctx.op.fwd_mode == mops.BatchNorm.FwdMode.INFERENCE, f"{ctx.op.fwd_mode}" - # inference mode will not return the new running mean and var, so return 4 args - assert ( - len(args) == 5 and len(ctx.vars_in) == 5 and len(ctx.vars_out) == 4 - ), f"len(args): {len(args)}, len(ctx.vars_in): {len(ctx.vars_in)}, len(ctx.vars_out): {len(ctx.vars_out)}" - - assert ctx.op.param_dim == "DIM_1C11", f"ctx.op.param_dim: {ctx.op.param_dim}" - - channel_dim = 1 # because param_dim is DIM_1C11 - C = args[1].shape[channel_dim] - - inp, weight, bias, running_mean, running_var = ( - args[0], - args[1], - args[2], - args[3], - args[4], - ) - unused = HLOTensor( - np.random.random(ctx.vars_out[-2].shape).astype(ctx.vars_out[-2].dtype) - ) - - if ctx.op.fwd_mode == mops.BatchNorm.FwdMode.TRAINING: - rst = hlo.BatchNormTrainingOp( - inp.tensor, - weight.reshape((C,)).tensor, - bias.reshape((C,)).tensor, - ir_utils.f32_attr(ctx.op.epsilon), - ir_utils.i64_attr(channel_dim), - ).results - assert len(rst) == 3, f"len(rst): {len(rst)}" - oup, batch_mean, batch_var = ( - HLOTensor(rst[0]), - HLOTensor(rst[1]).reshape((1, C, 1, 1)), - HLOTensor(rst[2]).reshape((1, C, 1, 1)), - ) - - running_mean = ( - running_mean * (1 - ctx.op.avg_factor) + batch_mean * ctx.op.avg_factor - ) - running_var = ( - running_var * (1 - ctx.op.avg_factor) + batch_var * ctx.op.avg_factor - ) - return running_mean, running_var, batch_mean, batch_var, unused, oup - else: - rst = hlo.BatchNormInferenceOp( - inp.tensor, - weight.reshape((C,)).tensor, - bias.reshape((C,)).tensor, - running_mean.reshape((C,)).tensor, - running_var.reshape((C,)).tensor, - ir_utils.f32_attr(ctx.op.epsilon), - ir_utils.i64_attr(channel_dim), - ).results - assert len(rst) == 1, f"len(rst): {len(rst)}" - oup = HLOTensor(rst[0]) - return running_mean, running_var, unused, oup - - -@register_lower_rule(mops.BatchNormBackward) -def batch_norm_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert ( - len(args) == 6 and len(ctx.vars_in) == 6 and len(ctx.vars_out) == 3 - ), f"len(args): {len(args)}, len(ctx.vars_in): {len(ctx.vars_in)}, len(ctx.vars_out): {len(ctx.vars_out)}" - assert ( - ctx.op.fwd_mode == "TRAINING" and ctx.op.param_dim == "DIM_1C11" - ), f"ctx.op.fwd_mode: {ctx.op.fwd_mode}, ctx.op.param_dim: {ctx.op.param_dim}" - channel_dim = 1 # because param_dim is DIM_1C11 - C = args[4].shape[channel_dim] - - inp, grad, mean, var, weight = ( - args[0], - args[1], - args[2].reshape((C,)), - args[3].reshape((C,)), - args[4].reshape((C,)), - ) - rst = hlo.BatchNormGradOp( - inp.tensor, - weight.tensor, - mean.tensor, - var.tensor, - grad.tensor, - ir_utils.f32_attr(ctx.op.epsilon), - ir_utils.i64_attr(channel_dim), - ).results - return [ - HLOTensor(rst[1]).reshape(ctx.vars_out[0].shape), - HLOTensor(rst[2]).reshape(ctx.vars_out[1].shape), - HLOTensor(rst[0]), - ] - - -def _normalize_lower( - x, axes, affine, eps, w=None, b=None, -): - x_mean = 
x.mean(axes, True) - x_mean_sqr = x_mean * x_mean - x_sqr = x * x - x_sqr_mean = x_sqr.mean(axes, True) - var = x_sqr_mean - x_mean_sqr - var_plus_eps = var + eps - std = sqrt(var_plus_eps) - rstd = 1.0 / std - delta = x - x_mean - normalized = delta * rstd - - if affine: - normalized = normalized * w + b - - return [normalized, x_mean, rstd] - - -@register_lower_rule(mops.LayerNorm) -def layer_norm_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert ctx.op.normalized_dim > 0 - if ctx.op.affine: - assert len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 3 - x, w, b = args[0], args[1], args[2] - else: - assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 3 - x, w, b = args[0], None, None - - axes = list(range(x.ndim - ctx.op.normalized_dim, x.ndim)) - rets = _normalize_lower(x, axes, ctx.op.affine, ctx.op.eps, w, b,) - rets[1] = rets[1].reshape(ctx.vars_out[1].shape) - rets[2] = rets[2].reshape(ctx.vars_out[2].shape) - return rets - - -@register_lower_rule("LayerNormBackward") -def layer_norm_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - grad = args[0] - if ctx.param["affine"]: - inp = args[1] - weight = args[2] - mean = args[3] - rstd = args[4] - else: - inp = args[1] - mean = args[2] - rstd = args[3] - - reduced_shape = mean.shape + (1,) * (inp.ndim - mean.ndim) - reduce_axes = list(range(mean.ndim, inp.ndim)) - axes_divider = np.prod([inp.shape[i] for i in reduce_axes]).astype("float32") - - mean = mean.reshape(reduced_shape) - rstd = rstd.reshape(reduced_shape) - delta = inp - mean - a = grad * weight if ctx.param["affine"] else grad - reducea = a.sum(reduce_axes, True) - b = (a * delta * rstd ** 2.0).sum(reduce_axes, True) - x1 = a * rstd - x2 = -1.0 * rstd / axes_divider * reducea - x3 = -1.0 * rstd / axes_divider * (inp - mean) * b - x4 = rstd / axes_divider * ((inp - mean) / axes_divider * b).sum(reduce_axes, True) - dx = x1 + x2 + x3 + x4 - - if ctx.param["affine"]: - unreduce_axes = list(range(inp.ndim - weight.ndim)) - dbias = grad.sum(unreduce_axes, keepdims=False) - dweight = (delta * rstd * grad).sum(unreduce_axes, keepdims=False) - return [dx, dweight, dbias] - return [dx] diff --git a/imperative/python/megengine/xla/rules/random.py b/imperative/python/megengine/xla/rules/random.py deleted file mode 100644 index da062ea00..000000000 --- a/imperative/python/megengine/xla/rules/random.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib import xla_client as xc -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .utils import _shape_equal, register_lower_rule - -RandomAlgorithm = xc.ops.RandomAlgorithm -RandomAlgorithm.__str__ = lambda algorithm: algorithm.name - - -def _rng_algorithm(algorithm: RandomAlgorithm): - assert algorithm == RandomAlgorithm.RNG_THREE_FRY - if algorithm == RandomAlgorithm.RNG_THREE_FRY: - return hlo.RngAlgorithmAttr.get("THREE_FRY") - elif algorithm == RandomAlgorithm.RNG_PHILOX: - return hlo.RngAlgorithmAttr.get("PHILOX") - elif algorithm == RandomAlgorithm.RNG_DEFAULT: - return hlo.RngAlgorithmAttr.get("DEFAULT") - else: - assert False - - -def rng_uint_generator( - key, oshape, odtype="uint32", algorithm=RandomAlgorithm.RNG_THREE_FRY -): - - assert np.dtype(odtype) in { - np.dtype("uint8"), - np.dtype("uint16"), - np.dtype("uint32"), - np.dtype("uint64"), - }, f"only unsigned int supported, got {odtype}({type(odtype)})" - assert algorithm == RandomAlgorithm.RNG_THREE_FRY, "only ThreeFry supported now" - assert _shape_equal(key.shape, (2, 2)), f"key shape error, {key.shape}" - assert key.dtype == "int32", f"key dtype error, {key.dtype}" - - # bitcast (2x2,i32) -> (2,u64) - org_key_shape, org_key_dtype = key.shape, key.dtype - key = key.bitcast((2,), "uint64") - - if odtype == "uint32" or odtype == "uint64": - rng_odtype = odtype - else: - rng_odtype = "uint32" - - algorithm_attr = _rng_algorithm(algorithm) - new_key, out_vals = hlo.RngBitGeneratorOp( - ir_utils.make_ir_type_according_meta(key.shape, key.dtype), - ir_utils.make_ir_type_according_meta(oshape, rng_odtype), - algorithm_attr, - key.tensor, - ).results - new_key, out_vals = HLOTensor(new_key), HLOTensor(out_vals) - new_key = new_key.bitcast(org_key_shape, org_key_dtype) - - if rng_odtype != odtype: - out_vals = out_vals.astype(odtype) - return out_vals, new_key - - -@register_lower_rule(mops.Dropout) -def dropout_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(ctx.vars_in) == 2 and len(args) == 2 and len(ctx.vars_out) == 3 - inp, key = args - random_val, new_key = rng_uint_generator(key, inp.shape, "uint32") - mask = random_val > np.array(ctx.op.drop_prob * np.iinfo(np.uint32).max, np.uint32) - multiplier = mask.astype(inp.dtype) - multiplier = multiplier / (1.0 - ctx.op.drop_prob) - out = inp * multiplier - mask = mask.reshape((-1,)).astype("uint8") - return out, mask, new_key - - -@register_lower_rule("DropoutBackward") -def droupout_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 2 and len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 - dy, mask = args[0], args[1] - scale = 1.0 - ctx.param["drop_prob"] - multiplier = mask.reshape(dy.shape).astype(dy.dtype) / scale - return dy * multiplier diff --git a/imperative/python/megengine/xla/rules/reduction.py b/imperative/python/megengine/xla/rules/reduction.py deleted file mode 100644 index acbff3da8..000000000 --- a/imperative/python/megengine/xla/rules/reduction.py +++ /dev/null @@ -1,175 +0,0 @@ -from functools import partial -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir import ir -from ..lib.mlir.dialects import hlo -from .elemwise import div -from .hlotensor import HLOTensor -from .tensor import reshape -from .utils import _check_dtype, _check_shape, _shape_equal, register_lower_rule - - -def _get_sum_identity(dtype) -> np.ndarray: - return np.array(0, dtype) - - -def _get_prod_identity(dtype) -> np.ndarray: - return np.array(1, dtype) - - -def _get_max_identity(dtype) -> np.ndarray: - if dtype == np.float32 or dtype == np.float64 or dtype == np.float16: - return np.array(-np.inf, dtype) - elif ( - dtype == np.int32 or dtype == np.int64 or dtype == np.int16 or dtype == np.int8 - ): - return np.array(np.iinfo(dtype).min, dtype) - else: - assert False, f"unsupported dtype for max: {dtype}" - - -def _get_min_identity(dtype) -> np.ndarray: - if dtype == np.float32 or dtype == np.float64 or dtype == np.float16: - return np.array(np.inf, dtype) - elif ( - dtype == np.int32 or dtype == np.int64 or dtype == np.int16 or dtype == np.int8 - ): - return np.array(np.iinfo(dtype).max, dtype) - else: - assert False, f"unsupported dtype for max: {dtype}" - - -def _get_bitwise_and_identity(dtype) -> np.ndarray: - return np.array(-1).astype(dtype) - - -def _get_bitwise_or_identity(dtype) -> np.ndarray: - return np.array(0, dtype) - - -def _infer_reduce_shape(ishape, axes, keepdims=False): - if axes is None: - axes = list(range(len(ishape))) - - reduced_shape = [] - - for axis, length in enumerate(ishape): - if axis not in axes: - reduced_shape.append(length) - else: - if keepdims: - reduced_shape.append(1) - return tuple(reduced_shape) - - -def _reduce( - reducer, fidentity, inp, axes=None, keepdims=False, oshape=None, odtype=None -): - def _reduce_nokeepdim(reducer, fidentity, inp, axes=None, oshape=None, odtype=None): - axes = [axis if axis >= 0 else axis + inp.ndim for axis in axes] - reduced_shape = _infer_reduce_shape(inp.shape, axes) - - _check_shape(reduced_shape, oshape) - _check_dtype(inp.dtype, odtype) - - reduce_out = ir_utils.make_ir_type_according_meta(reduced_shape, inp.dtype) - init_val = ir_utils.ir_constant_tuple(fidentity(inp.dtype)) - reduce_op = hlo.ReduceOp( - [reduce_out], [inp.tensor], init_val, ir_utils.dense_int_elements(axes) - ) - scalar_type = ir_utils.make_ir_type_according_meta(tuple(), inp.dtype) - reducer_region = reduce_op.regions[0].blocks.append(scalar_type, scalar_type) - with ir.InsertionPoint(reducer_region): - reducer_ret = reducer(*reducer_region.arguments) - hlo.ReturnOp(reducer_ret.results) - - return HLOTensor(reduce_op.result) - - axes = [axes] if isinstance(axes, int) else axes - - maykeepdim_shape = _infer_reduce_shape(inp.shape, axes, keepdims) - _check_shape(maykeepdim_shape, oshape) - - oup = _reduce_nokeepdim(reducer, fidentity, inp, axes, oshape, odtype) - if _shape_equal(oup.shape, maykeepdim_shape): - return oup - else: - return reshape(oup, maykeepdim_shape) - - -sum = partial(_reduce, hlo.AddOp, _get_sum_identity) -prod = partial(_reduce, hlo.MulOp, _get_prod_identity) -max = partial(_reduce, hlo.MaxOp, _get_max_identity) -min = partial(_reduce, hlo.MinOp, _get_min_identity) -all = partial(_reduce, hlo.AndOp, _get_bitwise_and_identity) -any = partial(_reduce, hlo.OrOp, _get_bitwise_or_identity) - - -def mean(inp, axes=None, keepdims=False): - inp_sum = sum(inp, axes, keepdims) - inp_shape = inp.shape - - divider = 1.0 - for ax in axes: - divider *= inp_shape[ax] - - return div(inp_sum, divider) - - -@register_lower_rule(mops.Reduce) -def reduce_lower(ctx, *args: Union[ir.Value, 
Sequence[ir.Value]]): - assert ctx.op.data_type == mops.Reduce.DataType.DEFAULT - - opr = ctx.op - keepdims = opr.keepdim - if len(args) == 1: - assert isinstance(opr.axis, int) - if opr.axis < 0: - axes = opr.axis + args[0].ndim - else: - axes = (opr.axis,) - if opr.axis > 7: - axes = tuple(np.arange(args[0].ndim)) - keepdims = False - else: - assert len(args) == 2 - src_shape = args[0].shape - tgt_shape = list(ctx.module_context.get_value(ctx.vars_in[1])) - tgt_shape = [1,] * (len(src_shape) - len(tgt_shape)) + tgt_shape - src_idx, tgt_idx, axes = 0, 0, [] - while src_idx < len(src_shape) and tgt_idx < len(tgt_shape): - if src_shape[src_idx] != 1 and tgt_shape[tgt_idx] == 1: - axes.append(src_idx) - src_idx = src_idx + 1 - tgt_idx = tgt_idx + 1 - elif src_shape[src_idx] != tgt_shape[tgt_idx]: - axes.append(src_idx) - src_idx = src_idx + 1 - else: - src_idx = src_idx + 1 - tgt_idx = tgt_idx + 1 - assert tgt_idx == len( - tgt_shape - ), f"src_shape: {src_shape}, tgt_shape: {tgt_shape}" - axes = axes + list(range(src_idx, len(src_shape))) - - if opr.mode == mops.Reduce.Mode.SUM: - ret = sum(args[0], axes, keepdims) - elif opr.mode == mops.Reduce.Mode.MEAN: - ret = mean(args[0], axes, keepdims) - elif opr.mode == mops.Reduce.Mode.PRODUCT: - ret = prod(args[0], axes, keepdims) - elif opr.mode == mops.Reduce.Mode.MAX: - ret = max(args[0], axes, keepdims) - elif opr.mode == mops.Reduce.Mode.MIN: - ret = min(args[0], axes, keepdims) - else: - assert False, f"no support reduce mode {opr.mode}" - - if not _shape_equal(ret.shape, ctx.vars_out[0].shape): - ret = ret.reshape(ctx.vars_out[0].shape) - return ret diff --git a/imperative/python/megengine/xla/rules/tensor.py b/imperative/python/megengine/xla/rules/tensor.py deleted file mode 100644 index 4c31f19cc..000000000 --- a/imperative/python/megengine/xla/rules/tensor.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from .. 
import ir_utils -from ..lib.mlir.dialects import hlo -from .hlotensor import HLOTensor -from .utils import ( - _can_broadcast_to, - _check_shape, - _parse_var_as_value, - _shape_equal, - register_lower_rule, -) - - -def broadcast_to(inp, oshape, broadcast_dims=None): - """ - [x, y, z] or [x, 1, z] only can broadcast to [a, b, c, ..., x, y, z], rather than [x, y, z, ..., a, b, c] - but you can realize the latter specify broadcast_dims, for example: - broadcast_to([x, y, z], [x, y, z, ..., a, b, c], broadcast_dims=[0, 1, 2]) - - broadcast_dims specify which dimension in the target shape each dimension of the - operand shape corresponds to, for example: - (1, 64, 1, 1) -> (32, 64, 32, 32), broadcast_dims is [0, 1, 2, 3] - (16, 64, 32) -> (16, 64, 1, 32), broadcast_dims is [0, 1, 3] - """ - inp = HLOTensor(inp) if not isinstance(inp, HLOTensor) else inp - ishape, idtype = inp.shape, inp.dtype - if _shape_equal(ishape, oshape): - return inp - - assert _can_broadcast_to( - ishape, oshape, broadcast_dims - ), f"cannot broadcast {ishape} to {oshape} with broadcast_dims {broadcast_dims}" - - if broadcast_dims is None: - broadcast_dims = list(range(len(oshape) - len(ishape), len(oshape))) - - result = hlo.BroadcastInDimOp( - ir_utils.make_ir_type_according_meta(oshape, idtype), - inp.tensor, - ir_utils.dense_int_elements(broadcast_dims), - ).result - return HLOTensor(result, oshape, idtype) - - -def reshape(inp, oshape): - if -1 in oshape: - assert oshape.count(-1) == 1, f"invalid shape {oshape}" - oshape = list(oshape) - oshape[oshape.index(-1)] = int(np.abs(np.prod(inp.shape) / np.prod(oshape))) - - if _shape_equal(inp.shape, oshape): - return inp - - assert np.prod(inp.shape) == np.prod( - oshape - ), f"cannot reshape {inp.shape} to {oshape}" - - return HLOTensor( - hlo.ReshapeOp( - ir_utils.make_ir_type_according_meta(oshape, inp.dtype), inp.tensor - ).result, - oshape, - inp.dtype, - ) - - -def transpose(inp, permutation): - assert len(inp.shape) == len( - permutation - ), f"incompatible shape and permutation: {inp.shape} vs {permutation}" - return HLOTensor( - hlo.TransposeOp(inp.tensor, ir_utils.dense_int_elements(permutation)).result - ) - - -def expand_dims(inp, axis): - assert isinstance(axis, int), f"only int axis supported, get {axis}" - axis = (axis + inp.ndim) if axis < 0 else axis - assert axis >= 0 and axis <= inp.ndim, f"invalid axis {axis} for {inp.shape}" - - dst_shape = [] - for i in range(inp.ndim): - if i == axis: - dst_shape.append(1) - dst_shape.append(inp.shape[i]) - - return inp.reshape(tuple(dst_shape)) - - -@register_lower_rule(mops.Dimshuffle) -def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 - permutation = ctx.op.pattern - return transpose(args[0], permutation) - - -def concat(inps, axis): - assert len(inps) > 0, f"concat inputs should not be empty" - if axis < 0: - axis = axis + inps[0].ndim[0] - - hlo_inps = [inp.tensor for inp in inps] - - return HLOTensor(hlo.ConcatenateOp(hlo_inps, ir_utils.i64_attr(axis)).results) - - -@register_lower_rule(mops.Concat, "Concat") -def concat_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) > 1 and isinstance(ctx.param["axis"], int) - if ctx.param["axis"] < 0: - axis = ctx.param["axis"] + len(ctx.vars_in[0].shape) - else: - axis = ctx.param["axis"] - return concat(args, axis) - - -# if nsplit_or_sections is int, means divide inputs into nsplit_or_sections parts -# if nsplit_or_sections is Sequence[int], 
means divide inputs into -# len(nsplit_or_sections) parts, and the i-th part has nsplit_or_sections[i] elements -def split(inp, nsplit_or_sections, axis): - from .indexing import index_with_slices - - ishape = inp.shape - if axis < 0: - axis = axis + len(ishape) - - if isinstance(nsplit_or_sections, int): - dimlen = ishape[axis] - assert dimlen % nsplit_or_sections == 0, "not an equal division" - sections = [dimlen // nsplit_or_sections] * nsplit_or_sections - else: - sections = nsplit_or_sections - - assert np.sum(sections) == ishape[axis], "error split param" - - slices = [] - start = 0 - for section in sections: - slices.append( - [ - None if idx != axis else slice(start, start + section, 1) - for idx in range(len(ishape)) - ] - ) - start = start + section - - return [index_with_slices(inp, slices[i]) for i in range(len(sections))] - - -@register_lower_rule(mops.Split, "Split") -def split_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - nr_inp, nr_oup = len(ctx.vars_in), len(ctx.vars_out) - assert len(args) == nr_inp and nr_inp == nr_oup + 1 and len(args) > 1 - - assert isinstance(ctx.param["axis"], int) - axis = ctx.param["axis"] - - sections = [] - for i in range(nr_oup): - section = ctx.vars_out[i].shape[axis] - if ctx.vars_in[i + 1].bound_data is not None: - assert section == _parse_var_as_value(ctx.vars_in[i + 1]) - sections.append(section) - - return split(args[0], sections, axis) - - -def fill(value, shape, dtype): - assert isinstance(value, (int, float, bool)) - value = np.asarray(value, dtype=dtype) - return broadcast_to(HLOTensor(value, dtype=dtype), shape) - - -@register_lower_rule(mops.Fill) -def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 - assert ctx.vars_out[0].dtype == ctx.op.dtype - _check_shape(ctx.vars_out[0].shape, ctx.vars_in[0].bound_data) - value = ctx.op.value - dtype = ctx.vars_out[0].dtype - shape = ctx.vars_out[0].shape - return fill(value, shape, dtype) - - -@register_lower_rule(mops.FillLike) -def fill_like_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 - var_in, var_out, opr = ctx.vars_in[0], ctx.vars_out[0], ctx.op - value = opr.value - - assert _shape_equal(var_in.shape, var_out.shape) and _shape_equal( - var_out.shape, args[0].shape - ) - assert var_in.dtype == var_out.dtype and args[0].dtype == var_out.dtype - shape, dtype = var_out.shape, var_out.dtype - - return fill(value, shape, dtype) - - -def pad(inp, pad_value, padding): - # interior is used as dilated padding if it is not zero - assert isinstance( - pad_value, (int, float, bool, np.ndarray) - ), f"pad_value error {type(pad_value)}" - pad_value = HLOTensor(pad_value, dtype=inp.dtype) - - low, high, interior = [], [], [] - for p in padding: - assert len(p) == 3 - low.append(p[0]) - high.append(p[1]) - interior.append(p[2]) - - return HLOTensor( - hlo.PadOp( - inp.tensor, - pad_value.tensor, - ir_utils.dense_int_elements(low), - ir_utils.dense_int_elements(high), - ir_utils.dense_int_elements(interior), - ).result - ) - - -@register_lower_rule(mops.Reshape) -def reshape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 2 - return args[0].reshape(ctx.vars_out[0].shape) - - -@register_lower_rule(mops.RemoveAxis) -def remove_axis_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1 - return args[0].reshape(ctx.vars_out[0].shape) - - 
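# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original tensor.py): the
# per-section slicing that `split` above builds before handing the slices to
# `index_with_slices` can be sanity-checked in plain NumPy.  The helper name
# `_np_split_reference` and the NumPy usage are assumptions made only for this
# example; the real lowering emits HLO slice ops instead.
# ---------------------------------------------------------------------------
import numpy as np


def _np_split_reference(x, sections, axis):
    # sections [3, 4, 5] -> slice boundaries [0, 3, 7, 12] along `axis`
    starts = np.cumsum([0] + list(sections))
    assert starts[-1] == x.shape[axis], "sections must cover the whole axis"
    index = [slice(None)] * x.ndim
    parts = []
    for lo, hi in zip(starts[:-1], starts[1:]):
        index[axis] = slice(lo, hi)
        parts.append(x[tuple(index)])
    return parts


# usage: a (2, 12) array split into parts of length 3, 4 and 5 along axis 1
_x = np.arange(24).reshape(2, 12)
assert [p.shape[1] for p in _np_split_reference(_x, [3, 4, 5], axis=1)] == [3, 4, 5]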
-@register_lower_rule(mops.AddAxis) -def add_axis_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == 1 - return args[0].reshape(ctx.vars_out[0].shape) - - -@register_lower_rule(mops.TypeCvt) -def typecvt_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - return args[0].astype(ctx.vars_out[0].dtype) - - -@register_lower_rule("TypeCvtV2") -def typecvt_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - return args[0].astype(ctx.vars_out[0].dtype) - - -@register_lower_rule(mops.Broadcast) -def broadcast_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - return args[0].broadcast_to(ctx.vars_out[0].shape) - - -@register_lower_rule(mops.Copy, mops.Identity) -def copy_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - return args diff --git a/imperative/python/megengine/xla/rules/trivial.py b/imperative/python/megengine/xla/rules/trivial.py deleted file mode 100644 index 558ad7a6c..000000000 --- a/imperative/python/megengine/xla/rules/trivial.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Sequence, Union - -import numpy as np - -from ...core._imperative_rt import ops as mops -from ..lib.mlir import ir -from .hlotensor import HLOTensor -from .utils import _check_shape, register_lower_rule - - -@register_lower_rule(mops.GetVarShape) -def get_var_shape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - if len(args) > 1: - assert len(args) == 2, f"{len(args)}" - _check_shape(args[0].shape, args[1].shape) - - shp = args[0].shape - if ctx.op.axis != 7: - shp = (shp[ctx.op.axis],) - - shp = np.array(shp, np.int64) - ctx.module_context.set_value(ctx.vars_out[0], shp) - return HLOTensor(shp) - - -@register_lower_rule("create_tensor") -def create_tensor_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): - assert len(args) == len(ctx.vars_in) == len(ctx.vars_out) == 1 - var_in, var_out = ctx.vars_in[0], ctx.vars_out[0] - if var_in.bound_data is not None: - ctx.module_context.set_value(var_in, var_in.bound_data) - ctx.module_context.set_value(var_out, var_in.bound_data) - assert var_in.shape == var_out.shape - if var_out.bound_data is not None: - data = np.asarray(var_out.bound_data, var_out.dtype) - elif var_in.bound_data is not None: - data = np.asarray(var_in.bound_data, var_out.dtype) - else: - assert False, "only support create tensor from const now" - - return HLOTensor(data) - - -@register_lower_rule("io_mark_var") -def io_mark_var_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]): - assert len(args) == 1 - return args - - -@register_lower_rule("rename") -def rename_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]): - assert len(args) == 1 - return args diff --git a/imperative/python/megengine/xla/rules/utils.py b/imperative/python/megengine/xla/rules/utils.py deleted file mode 100644 index 7634db5a5..000000000 --- a/imperative/python/megengine/xla/rules/utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import numpy as np - -from ..lib.mlir import ir - -lower_rule = {} - - -def register_lower_rule(*ops): - def decorator(rule): - for op in ops: - assert op not in lower_rule, f"{op} lower rule has been registered" - lower_rule[op] = rule - - def wrapper(*args, **kwargs): - return rule(*args, **kwargs) - - return wrapper - - return decorator - - -def get_rule(op): - if isinstance(op, str): - return lower_rule[op] - return lower_rule[type(op)] - - -def _log_mge_opr_attrs(mopr): - print(f"============ {mopr} ============") - for k in dir(mopr): - if not k.startswith("__"): - attr = getattr(mopr, k) - if not 
isinstance(attr, type): - print(f" {k}: {type(attr)} = {attr}") - - -def _shape_equal(lhs_shape, rhs_shape): - lhs_shape = lhs_shape.tolist() if isinstance(lhs_shape, np.ndarray) else lhs_shape - rhs_shape = rhs_shape.tolist() if isinstance(rhs_shape, np.ndarray) else rhs_shape - assert isinstance(lhs_shape, (tuple, list)) and isinstance( - rhs_shape, (tuple, list) - ), f"lhs_shape: {lhs_shape}{type(lhs_shape)}, rhs_shape: {rhs_shape}{type(rhs_shape)}" - if len(lhs_shape) == 0 and len(rhs_shape) == 0: - return True - if len(lhs_shape) != 0: - assert isinstance(lhs_shape[0], int) - if len(rhs_shape) != 0: - assert isinstance(rhs_shape[0], int) - - if len(lhs_shape) != len(rhs_shape): - return False - - for l, r in zip(lhs_shape, rhs_shape): - if l != r: - return False - - return True - - -def _check_shape(actual, ref): - if ref is not None: - assert _shape_equal(actual, ref), f"shape error, actual: {actual}, ref: {ref}" - - -def _check_dtype(actual, ref): - if ref is not None: - assert actual == ref, f"dtype error, actual: {actual}, ref: {ref}" - - -def unwrap_opresult_list(irnode): - if isinstance(irnode, ir.OpResultList): - if len(irnode) == 1: - return irnode[0] - return irnode - - -def _parse_var_as_value(var): - assert isinstance( - var.bound_data, (np.ndarray, int) - ), "cannot parse a non-const var as value" - if tuple(var.bound_data.shape) == (1,): - return int(var.bound_data) - else: - return var.bound_data - - -def _can_broadcast_to(src, dst, broadcast_dims=None): - if len(src) > len(dst): - return False - if broadcast_dims is None: - for i in range(-1, -len(src) - 1, -1): - if src[i] != dst[i] and src[i] != 1: - return False - else: - for idx, dim in enumerate(broadcast_dims): - if not (src[idx] == dst[dim] or src[idx] == 1): - return False - return True diff --git a/imperative/python/megengine/xla/sharding.py b/imperative/python/megengine/xla/sharding.py deleted file mode 100644 index b94fea73a..000000000 --- a/imperative/python/megengine/xla/sharding.py +++ /dev/null @@ -1,454 +0,0 @@ -# sharding annotation helper, but we do not use the sharding in megengine -# so we set all the input sharding as `Replicated` by default -import abc -import functools -import itertools as it -from typing import Optional, Sequence, Set, Tuple, Union - -import numpy as np - -from ..tensor import Parameter as MgeParameter -from ..tensor import Tensor as MgeTensor -from .device import device_put -from .dtype import _np_types, canonicalize_arg -from .lib import xla_client as xc -from .utils import safe_zip, tuple_insert, unzip3, use_cpp_class, use_cpp_method - -pmap_lib = xc._xla.pmap_lib - - -def spec_to_indices(shape, spec): - shape = (1, *shape) - return tuple(spec.indices(shape).flat) - - -@use_cpp_class(xc.Sharding) -class Sharding(metaclass=abc.ABCMeta): - @abc.abstractproperty - def device_set(self): - raise NotImplementedError("should be overrided") - - @abc.abstractmethod - def devices_indices_map(self, global_shape): - raise NotImplementedError("should be overrided") - - @abc.abstractmethod - def shard_shape(self, global_shape): - raise NotImplementedError("should be overrided") - - @abc.abstractmethod - def is_equivalent_to(self, other, ndim) -> bool: - raise NotImplementedError("should be overrided") - - @functools.cached_property - def addressable_devices(self): - return { - d for d in self.device_set if d.process_index == d.client.process_index() - } - - @functools.cached_property - def is_fully_addressable(self) -> bool: - return len(self.device_set) == len(self.addressable_devices) - - 
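    # -----------------------------------------------------------------------
    # Editor's note (illustrative, not from the original sharding.py): since
    # megengine sets every input sharding to `Replicated`, `devices_indices_map`
    # effectively maps each device to a full-slice index tuple.  The string
    # "cpu:0" below is only a stand-in for an `xc.Device`:
    #
    #     >>> global_shape = (8, 4)
    #     >>> {"cpu:0": (slice(None),) * len(global_shape)}
    #     {'cpu:0': (slice(None, None, None), slice(None, None, None))}
    #
    # which matches what `SingleDeviceSharding.devices_indices_map` returns
    # further below.
    # -----------------------------------------------------------------------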
@functools.lru_cache(maxsize=4096) - def addressable_devices_indices_map(self, global_shape): - return { - d: ind - for d, ind in self.devices_indices_map(global_shape).items() - if d.process_index == d.client.process_index() - } - - -@use_cpp_class(xc.XLACompatibleSharding) -class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta): - @abc.abstractproperty - def _device_assignment(self): - raise NotImplementedError("should be overrided") - - @abc.abstractmethod - def _to_xla_op_sharding(self, num_dimensions: int): - raise NotImplementedError("should be overrided") - - @functools.lru_cache(maxsize=4096) - def devices_indices_map(self, global_shape): - op_sharding = self._to_xla_op_sharding(len(global_shape)) - op_sharding_sharding = OpShardingSharding(self._device_assignment, op_sharding) - return op_sharding_sharding.devices_indices_map(global_shape) - - @functools.cached_property - def _addressable_device_assignment(self): - return [ - d - for d in self._device_assignment - if d.process_index == d.client.process_index() - ] - -@use_cpp_class(xc.GSPMDSharding) -class OpShardingSharding(XLACompatibleSharding): - @use_cpp_method - def __init__(self, devices, op_sharding): - self._devices = tuple(devices) - self._op_sharding = op_sharding - - def __reduce__(self): - return type(self), (self._devices, self._op_sharding) - - @functools.cached_property - def _op_sharding_hash(self): - return hash(xc.HloSharding.from_proto(self._op_sharding)) - - def __eq__(self, other): - if not isinstance(other, OpShardingSharding): - return False - if id(self) == id(other): - return True - - def __hash__(self): - if not hasattr(self, "_hash"): - self._hash = hash((self._devices, self._op_sharding_hash)) - return self._hash - - def __repr__(self): - return ( - f"OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})" - ) - - @functools.cached_property - def device_set(self): - return set(self._devices) - - @property - def _device_assignment(self): - return list(self._devices) - - def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding: - return self._op_sharding - - @classmethod - def get_replicated(cls, device_assignment): - proto = _get_replicated_op_sharding() - return cls(device_assignment, proto) - - -@use_cpp_class(xc.SingleDeviceSharding) -class SingleDeviceSharding(XLACompatibleSharding): - @use_cpp_method - def __init__(self, device): - self._device = device - - def __reduce__(self): - return type(self), (self._device,) - - def __repr__(self): - return f"SingleDeviceSharding(device={repr(self._device)})" - - def __hash__(self): - return hash(self._device) - - def __eq__(self, other): - if not isinstance(other, SingleDeviceSharding): - return False - if id(self) == id(other): - return True - return self._device == other._device - - @property - def device_set(self): - return {self._device} - - def devices_indices_map(self, global_shape): - return {self._device: (slice(None),) * len(global_shape)} - - @property - def _device_assignment(self): - return [self._device] - - def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding: - return _get_replicated_op_sharding() - - -@use_cpp_class(xc.PmapSharding) -class PmapSharding(XLACompatibleSharding): - devices: np.ndarray - sharding_spec: pmap_lib.ShardingSpec - - @use_cpp_method - def __init__( - self, - devices: Union[Sequence[xc.Device], np.ndarray], - sharding_spec: pmap_lib.ShardingSpec, - ): - self.devices = np.asarray(devices) - self.sharding_spec = sharding_spec - - def __reduce__(self): - return type(self), 
-
-
-@use_cpp_class(xc.PmapSharding)
-class PmapSharding(XLACompatibleSharding):
-    devices: np.ndarray
-    sharding_spec: pmap_lib.ShardingSpec
-
-    @use_cpp_method
-    def __init__(
-        self,
-        devices: Union[Sequence[xc.Device], np.ndarray],
-        sharding_spec: pmap_lib.ShardingSpec,
-    ):
-        self.devices = np.asarray(devices)
-        self.sharding_spec = sharding_spec
-
-    def __reduce__(self):
-        return type(self), (self.devices, self.sharding_spec)
-
-    def __eq__(self, other):
-        if not isinstance(other, PmapSharding):
-            return False
-        if id(self) == id(other):
-            return True
-        return self.sharding_spec == other.sharding_spec and np.array_equal(
-            self.devices, other.devices
-        )
-
-    def __hash__(self):
-        if not hasattr(self, "_hash"):
-            self._hash = hash((tuple(self.devices.flat), self.sharding_spec))
-        return self._hash
-
-    def __str__(self):
-        device_ids = [d.id for d in self.devices.flat]
-        return (
-            f"PmapSharding(sharding_spec={self.sharding_spec}, "
-            f"{device_ids=}, "
-            f"device_platform={self.devices.flat[0].platform.upper()}, "
-            f"device_shape={self.devices.shape})"
-        )
-
-    def __repr__(self):
-        return (
-            f"PmapSharding(sharding_spec={self.sharding_spec}, "
-            f"devices={self.devices})"
-        )
-
-    def is_equivalent_to(self, other, ndim: int,) -> bool:
-        return self == other
-
-    @functools.cached_property
-    def device_set(self) -> Set[xc.Device]:
-        return set(self.devices.flat)
-
-    @functools.lru_cache(maxsize=4096)
-    def devices_indices_map(self, global_shape):
-        indices = spec_to_indices(global_shape, self.sharding_spec)
-        return dict(safe_zip(self.devices.flat, indices))
-
-    @functools.cached_property
-    def _device_assignment(self):
-        return list(self.devices.flat)
-
-    def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
-        raise NotImplementedError("pmap doesn't use OpSharding.")
-
-    @functools.lru_cache(maxsize=4096)
-    def shard_shape(self, global_shape):
-        sharded_dim = None
-        sharded_dim_size = None
-        for i, s in enumerate(self.sharding_spec.sharding):
-            if isinstance(s, pmap_lib.Unstacked):
-                sharded_dim = i
-                sharded_dim_size = s.size
-                break
-        if sharded_dim is None:
-            return global_shape
-        if global_shape[sharded_dim] != sharded_dim_size:
-            raise ValueError(
-                f"The sharded dimension must be equal to the number of "
-                f"devices passed to PmapSharding. Got sharded dimension {sharded_dim} "
-                f"with value {global_shape[sharded_dim]} in shape {global_shape} and "
-                f"the number of devices={len(self._device_assignment)}"
-            )
-        return global_shape[:sharded_dim] + global_shape[sharded_dim + 1 :]
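PmapSharding.shard_shape walks the sharding spec for the single Unstacked axis, checks it against the device count, and drops it from the per-device shape. A self-contained sketch of that rule, with Unstacked and NoSharding as local stand-ins for the pmap_lib classes:

from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Unstacked:
    size: int


class NoSharding:
    pass


def shard_shape(sharding: Tuple[Union[Unstacked, NoSharding], ...], global_shape: Tuple[int, ...]):
    # Find the axis mapped across devices, validate it, then remove it.
    for i, s in enumerate(sharding):
        if isinstance(s, Unstacked):
            assert global_shape[i] == s.size, "sharded dim must equal device count"
            return global_shape[:i] + global_shape[i + 1:]
    return global_shape


# pmap over axis 0 of a (8, 32, 32) tensor on 8 devices -> each device sees (32, 32)
assert shard_shape((Unstacked(8), NoSharding(), NoSharding()), (8, 32, 32)) == (32, 32)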
-
-
-def _get_op_sharding_shardings_from_executable(
-    xla_executable,
-    device_assignment: Sequence[xc.Device],
-    num_in_avals: int,
-    num_out_avals: int,
-) -> Tuple[
-    Sequence[XLACompatibleSharding], Sequence[XLACompatibleSharding],
-]:
-    assert len(device_assignment) == 1
-    if len(device_assignment) == 1:
-        return (
-            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
-            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)],
-        )
-
-
-def is_op_sharding_replicated(op: xc.OpSharding) -> bool:
-    if len(op.tile_assignment_devices) == 1:
-        return True
-    return xc.HloSharding.from_proto(op).is_replicated()
-
-
-class _UnspecifiedSharding:
-    pass
-
-
-def _is_unspecified(x):
-    return isinstance(x, _UnspecifiedSharding)
-
-
-def make_unspec_sharding(inps):
-    return [_UnspecifiedSharding()] * len(inps)
-
-
-# split the tensor into sharded shape according to the sharding strategy
-def sharded_val(in_val, in_sharding):
-    if in_sharding is None or _is_unspecified(in_sharding):
-        return in_val
-
-    if in_sharding.type == xc.OpSharding.Type.REPLICATED:
-        return in_val
-
-    assert False, "not implemented"
-
-
-def _get_normalized_avals_and_shardings(
-    global_in_avals, in_shardings, in_is_global,
-):
-    avals = []
-    shardings = []
-
-    for gaval, i, is_global in safe_zip(global_in_avals, in_shardings, in_is_global):
-        if is_global:
-            aval = gaval
-            in_sharding = i
-        else:
-            assert False
-        avals.append(aval)
-        shardings.append(in_sharding)
-
-    return avals, shardings
-
-
-shard_arg_handlers = {}
-
-
-def _shard_nparray(x, devices, indices, sharding=None):
-    if x.shape == ():
-        return device_put([x] * len(devices), devices)
-    return device_put([x[i] for i in indices], devices)
-
-
-def _shard_xla_device_array(x: xc._xla.DeviceArray, devices, indices, sharding=None):
-    def _as_slice_indices(arr, idx):
-        start_indices = [0] * arr.ndim
-        limit_indices = list(arr.shape)
-        removed_dims = []
-
-        tuple_idx = idx if isinstance(idx, tuple) else (idx,)
-        for dim, sub_idx in enumerate(tuple_idx):
-            if isinstance(sub_idx, int):
-                start_indices[dim] = sub_idx
-                limit_indices[dim] = sub_idx + 1
-                removed_dims.append(dim)
-            elif sub_idx == slice(None):
-                continue
-            else:
-                assert isinstance(sub_idx, slice), sub_idx
-                assert isinstance(sub_idx.start, int), sub_idx
-                assert isinstance(sub_idx.stop, int), sub_idx
-                start_indices[dim] = sub_idx.start
-                limit_indices[dim] = sub_idx.stop
-
-        return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)
-
-    start_indices, limit_indices, removed_dims = unzip3(
-        _as_slice_indices(x, idx) for idx in indices
-    )
-    shards = x._multi_slice(start_indices, limit_indices, removed_dims)
-    return device_put(shards, devices)
-
-
-def _shard_mge_tensor(x, devices, indices, sharding=None):
-    x_np = x.numpy()
-    if x_np.shape == ():
-        x_np = np.array([x_np])
-    return device_put([x_np[i] for i in indices], devices)
-
-
-for nt in _np_types:
-    shard_arg_handlers[nt] = _shard_nparray
-shard_arg_handlers[xc._xla.DeviceArray] = _shard_xla_device_array
-shard_arg_handlers[MgeTensor] = _shard_mge_tensor
-shard_arg_handlers[MgeParameter] = _shard_mge_tensor
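The registrations above implement dispatch by argument type: shard_arg_handlers maps a host-side type to a function that slices the value per device index and places the pieces. A runnable sketch of the same pattern for plain numpy arrays; fake_shard_arg_handlers and fake_device_put are illustrative names, not the deleted API:

import numpy as np

fake_shard_arg_handlers = {}


def fake_device_put(shards, devices):
    # Stand-in for the real device_put: just pair each shard with its device.
    return list(zip(devices, shards))


def _shard_np(x, devices, indices):
    # Each device gets the slice of x selected by its index tuple.
    return fake_device_put([x[i] for i in indices], devices)


fake_shard_arg_handlers[np.ndarray] = _shard_np

x = np.arange(8).reshape(4, 2)
indices = [(slice(0, 2),), (slice(2, 4),)]  # row-split across two devices
print(fake_shard_arg_handlers[type(x)](x, ["gpu:0", "gpu:1"], indices))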
-
-
-def shard_args(devices, indices, args, shardings=None):
-    def _shard_arg(arg, devices, arg_indices, sharding=None):
-        arg = canonicalize_arg(arg)
-        return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
-
-    if shardings is None:
-        return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
-    else:
-        return [
-            _shard_arg(arg, devices, indices[i], shardings[i])
-            for i, arg in enumerate(args)
-        ]
-
-
-@functools.lru_cache()
-def _get_replicated_op_sharding():
-    proto = xc.OpSharding()
-    proto.type = xc.OpSharding.Type.REPLICATED
-    return proto
-
-
-def partitioned_sharding_spec(
-    num_partitions: int, partitions: Optional[Sequence[int]], arg_shape
-):
-    if partitions is None:
-        maybe_replicate = (
-            () if num_partitions == 1 else (pmap_lib.Replicated(num_partitions),)
-        )
-        return pmap_lib.ShardingSpec(
-            sharding=[pmap_lib.NoSharding()] * len(arg_shape),
-            mesh_mapping=maybe_replicate,
-        )
-    else:
-        assert len(partitions) == len(arg_shape)
-        return pmap_lib.ShardingSpec(
-            sharding=map(pmap_lib.Chunked, [[x] for x in partitions]),
-            mesh_mapping=map(pmap_lib.ShardedAxis, range(len(partitions))),
-        )
-
-
-def _pmap_sharding_spec(
-    nrep, axis_size, npart, parts, arg_shape, map_axis: Optional[int]
-) -> pmap_lib.ShardingSpec:
-    replication_factor, ragged = divmod(nrep, axis_size)
-    assert not ragged
-    # get the sharding spec from inner sharded_jits as if we weren't in a pmap
-    pspec = partitioned_sharding_spec(npart, parts, arg_shape)
-    maybe_replicate = (
-        () if replication_factor == 1 else (pmap_lib.Replicated(replication_factor),)
-    )
-    if map_axis is not None:
-        sharded_in_axis = sum(
-            not isinstance(s, pmap_lib.NoSharding) for s in pspec.sharding[:map_axis]
-        )
-
-        def shift_sharded_axis(a):
-            if isinstance(a, pmap_lib.ShardedAxis) and a.axis >= sharded_in_axis:
-                return pmap_lib.ShardedAxis(a.axis + 1)
-            return a
-
-        # replication_factor represents the product of inner pmaps, so it goes
-        # after the outer pmapped axis at index 0
-        return pmap_lib.ShardingSpec(
-            sharding=tuple_insert(
-                pspec.sharding, map_axis, pmap_lib.Unstacked(axis_size)
-            ),
-            mesh_mapping=it.chain(
-                [pmap_lib.ShardedAxis(sharded_in_axis)],
-                maybe_replicate,
-                map(shift_sharded_axis, pspec.mesh_mapping),
-            ),
-        )
-    else:
-        return pmap_lib.ShardingSpec(
-            sharding=pspec.sharding,
-            mesh_mapping=(pmap_lib.Replicated(axis_size),)
-            + maybe_replicate
-            + pspec.mesh_mapping,
-        )
-
-
-def _get_pmap_sharding(devices, specs):
-    return [PmapSharding(devices, spec) for spec in specs]
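In the common pmap case (nrep == axis_size, npart == 1, no inner partitions), _pmap_sharding_spec above reduces to marking the mapped axis as Unstacked(axis_size) and placing it on mesh axis 0. A simplified sketch of that outcome only, with local stand-ins for the pmap_lib spec classes; arg_ndim is the rank of the per-device (unmapped) value:

from dataclasses import dataclass


@dataclass(frozen=True)
class Unstacked:
    size: int


@dataclass(frozen=True)
class NoSharding:
    pass


@dataclass(frozen=True)
class ShardedAxis:
    axis: int


def simple_pmap_spec(axis_size, arg_ndim, map_axis=0):
    # Per-device dims are unsharded; the mapped axis is inserted as Unstacked
    # and mapped onto mesh axis 0.
    sharding = tuple(NoSharding() for _ in range(arg_ndim))
    sharding = sharding[:map_axis] + (Unstacked(axis_size),) + sharding[map_axis:]
    return sharding, (ShardedAxis(0),)


# pmap over axis 0 of a rank-2 array on 4 devices
print(simple_pmap_spec(4, 2))
# ((Unstacked(size=4), NoSharding(), NoSharding()), (ShardedAxis(axis=0),))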
-# ((x, a), (y, b), (z, c)) -> (x, y, z), (a, b, c)
-def unzip2(xys):
-    xs = []
-    ys = []
-    for x, y in xys:
-        xs.append(x)
-        ys.append(y)
-    return tuple(xs), tuple(ys)
-
-
-def unzip3(xyzs):
-    xs = []
-    ys = []
-    zs = []
-    for x, y, z in xyzs:
-        xs.append(x)
-        ys.append(y)
-        zs.append(z)
-    return tuple(xs), tuple(ys), tuple(zs)
-
-
-def _unwrap_func(f):
-    if isinstance(f, property):
-        return cast(property, f).fget
-    elif isinstance(f, cached_property):
-        return f.func
-    return f
-
-
-def use_cpp_class(cpp_cls):
-    def wrapper(cls):
-        exclude_methods = {"__module__", "__dict__", "__doc__"}
-
-        for attr_name, attr in cls.__dict__.items():
-            if attr_name not in exclude_methods and not hasattr(
-                _unwrap_func(attr), "_use_cpp"
-            ):
-                setattr(cpp_cls, attr_name, attr)
-
-        cpp_cls.__doc__ = cls.__doc__
-
-        return cpp_cls
-
-    return wrapper
-
-
-def use_cpp_method(f):
-    original_func = _unwrap_func(f)
-    original_func._use_cpp = True
-    return f
-
-
-# (a, b, c), 1, x -> (a, x, b, c)
-def tuple_insert(t, idx, val):
-    assert 0 <= idx <= len(t), (idx, len(t))
-    return t[:idx] + (val,) + t[idx:]
--
GitLab
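use_cpp_class and use_cpp_method graft the Python-side methods onto a C++-backed (pybind11) class while leaving @use_cpp_method entries to the C++ implementation. A toy sketch of the pattern with a plain Python stand-in class; it skips all dunder names for brevity, whereas the deleted helper excludes only __module__, __dict__ and __doc__:

def _unwrap_func(f):
    # Simplified: only unwrap plain properties; the deleted helper also handles
    # functools.cached_property.
    return f.fget if isinstance(f, property) else f


def use_cpp_method(f):
    # Mark a method as already provided by the C++ class.
    _unwrap_func(f)._use_cpp = True
    return f


def use_cpp_class(cpp_cls):
    def wrapper(cls):
        for name, attr in cls.__dict__.items():
            if name.startswith("__"):
                continue  # simplification, see lead-in
            if not hasattr(_unwrap_func(attr), "_use_cpp"):
                setattr(cpp_cls, name, attr)
        return cpp_cls

    return wrapper


class CppBacked:  # stand-in for the real pybind11-defined class
    def fast_path(self):
        return "implemented in C++"


@use_cpp_class(CppBacked)
class PythonFacade:
    @use_cpp_method
    def fast_path(self):  # stays on the C++ side
        raise NotImplementedError

    def helper(self):  # copied onto CppBacked
        return "implemented in Python"


obj = PythonFacade()  # PythonFacade is now bound to CppBacked
assert isinstance(obj, CppBacked)
assert obj.fast_path() == "implemented in C++" and obj.helper() == "implemented in Python"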