Commit cf139400 authored by Megvii Engine Team

Revert "feat(imperative): add xla support"

This reverts commit 9dd518c76cde32e81d29562be4a4bc176354ee54.

GitOrigin-RevId: 5b49dab0570842c01019d333d242d36b7517fee4
Parent 2d9eb3e5
import os
from ..distributed import get_rank, get_world_size, is_distributed
from .compile import MeshComputation, PmapComputation
from .device import get_xla_backend_and_device
from .distribute import initialize
from .ir_utils import DropoutMaskCanonicalizer, RngKeyAdder, TraceResult
from .lib import xla_client as xc
from .lower import lower
from .sharding import OpShardingSharding, _is_unspecified, make_unspec_sharding
xla_extension = xc._xla
xe = xla_extension
Backend = xe.Client
def build_xla(
mge_traced,
func_name=None,
device=None,
keep_unused=True,
donate_invars=None,
verbose=int(os.environ.get("MGE_VERBOSE_XLA_IR", "0")),
return_with_io=False,
return_device_array=False,
):
assert device is None, "cannot specify device now"
assert keep_unused is True, "keep_unused must be True for now"
assert donate_invars is None, "donate_invars is not supported yet"
# normalize megengine trace result for lowering
tr = TraceResult(mge_traced, func_name)
tr = RngKeyAdder()(tr)
tr = DropoutMaskCanonicalizer()(tr)
if verbose and get_rank() == 0:
print("================ Mge Trace Result ================")
print(tr)
in_is_global = (True,) * len(tr.inputs)
kept_var_idx = set(range(len(tr.inputs))) if keep_unused else set()
# init for xla distributed and setup device
if is_distributed():
initialize("127.0.0.1:12345", get_world_size(), get_rank(), [get_rank()])
backend, device_assignment, platform = get_xla_backend_and_device(device)
module, keepalive, host_callbacks = lower(
tr, backend, platform, None, None, donate_invars,
)
if not is_distributed():
# setup sharding information
in_shardings = make_unspec_sharding(tr.inputs)
out_shardings = make_unspec_sharding(tr.outputs)
in_shardings = tuple(
OpShardingSharding.get_replicated(device_assignment)
if _is_unspecified(i)
else i
for i in in_shardings
)
computation = MeshComputation(
tr.func_name,
module,
donated_invars=donate_invars,
trace_result=tr,
mesh=None,
in_shardings=in_shardings,
out_shardings=out_shardings,
spmd_lowering=False,
tuple_args=False, # for tpu
in_is_global=in_is_global,
auto_spmd_lowering=False,
unordered_effects=[],
ordered_effects=[],
host_callbacks=host_callbacks,
keepalive=keepalive,
kept_var_idx=kept_var_idx,
backend=backend,
device_assignment=device_assignment,
committed=False, # unknown
pmap_nreps=1,
return_device_array=return_device_array,
)
else:
computation = PmapComputation(
tr.func_name,
module,
trace_result=tr,
unordered_effects=[],
ordered_effects=[],
tuple_args=False, # for tpu
in_is_global=in_is_global,
host_callbacks=host_callbacks,
keepalive=keepalive,
kept_var_idx=kept_var_idx,
backend=backend,
devices=None,
return_device_array=return_device_array,
world_size=get_world_size(),
rank=get_rank(),
)
if verbose and get_rank() == 0:
print("================ XLA HLO IR ================")
print(computation.as_text())
compiled = computation.compile()
if verbose and get_rank() == 0:
print("================ XLA Execute Plan ================")
print(compiled.as_text())
ret = compiled.unsafe_call
if return_with_io:
return ret, tr.inputs, tr.outputs
return ret
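# Minimal usage sketch (illustration only), assuming `traced` is the internal
# trace object produced by megengine's tracing machinery and `np_args` are host
# arrays matching the traced inputs; how `traced` is obtained is out of scope here.
def _demo_build_and_run(traced, np_args):
    # lower the megengine trace to XLA, compile it, and get the raw callable
    # back together with the trace inputs/outputs for bookkeeping
    run, inputs, outputs = build_xla(traced, func_name="demo_fn", return_with_io=True)
    # `run` wraps compiled.unsafe_call: it consumes values in the order of
    # `inputs` and returns the values of `outputs`
    return run(*np_args)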
This diff is collapsed.
import itertools as it
from typing import Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt.common import CompNode
from ..tensor import Parameter as MgeParameter
from ..tensor import Tensor as MgeTensor
from .dtype import (
_np_types,
_python_scalar_dtypes,
_scalar_type_to_dtype,
canonicalize_arg,
)
from .lib import xla_bridge as xb
from .lib import xla_client as xc
from .utils import safe_zip
xla_extension = xc._xla
xe = xla_extension
Backend = xe.Client
device_put_handlers = {}
def _device_put_nparray(x, device):
backend = xb.get_device_backend(device)
return (backend.buffer_from_pyval(x, device),)
def _device_put_scalar(x, device):
def cvt_scalar_to_nparray(x, dtype=None):
if dtype is None and type(x) in _python_scalar_dtypes:
dtype = _scalar_type_to_dtype(type(x), x)
return np.asarray(x, dtype)
return _device_put_nparray(cvt_scalar_to_nparray(x), device)
def _device_put_device_array(x, device):
assert False
def _device_put_mge_tensor(x, device):
x = x.numpy()
return _device_put_nparray(x, device)
for nt in _np_types:
device_put_handlers[nt] = _device_put_nparray
for sc in _python_scalar_dtypes:
device_put_handlers[sc] = _device_put_scalar
device_put_handlers[xc._xla.DeviceArray] = _device_put_device_array
device_put_handlers[MgeTensor] = _device_put_mge_tensor
device_put_handlers[MgeParameter] = _device_put_mge_tensor
def _device_put_impl(x, device):
x = canonicalize_arg(x)
return device_put_handlers[type(x)](x, device)
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool = False):
if replicate:
return list(
it.chain.from_iterable(_device_put_impl(x, device) for device in devices)
)
else:
return list(
it.chain.from_iterable(
_device_put_impl(val, device) for val, device in safe_zip(x, devices)
)
)
def get_xla_backend_and_device(device=None) -> Tuple[Backend, Sequence[xc.Device]]:
assert device is None, "device assignment is not supported yet"
device_assignment = [xb.local_devices()[0]]
backend = xb.get_device_backend(device_assignment[0])
platform = backend.platform
platform = xb.canonicalize_platform(platform)
assert xb.is_known_platform(platform), f"{platform} is not known yet"
assert platform == "cuda", f"only cuda platfrom is supportted, but get {platform}"
return backend, device_assignment, platform
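# Usage sketch (illustration only): pick the cuda backend with its single-device
# assignment, then place a host numpy array on every device in that assignment.
def _demo_device_put(np_array):
    backend, device_assignment, platform = get_xla_backend_and_device()
    # replicate=True copies the same value onto each device; with replicate=False
    # the values and devices are zipped one-to-one instead
    return device_put(np_array, device_assignment, replicate=True)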
import atexit
from typing import Any, Optional, Sequence, Union
from .lib import xla_client as xc
xla_extension = xc._xla
xe = xla_extension
class State:
process_id: int = 0
service: Optional[Any] = None
client: Optional[Any] = None
preemption_sync_manager: Optional[Any] = None
visible_devices: Optional[str] = "all"
def initialize(
self,
coordinator_address: str,
num_processes: int,
process_id: int,
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
):
if local_device_ids is None:
local_device_ids = [process_id]
elif isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]
else:
local_device_ids = list(local_device_ids)
assert local_device_ids == [process_id], f"{local_device_ids} .vs {process_id}"
self.visible_devices = ",".join(str(x) for x in local_device_ids)
self.process_id = process_id
if process_id == 0:
if self.service is not None:
raise RuntimeError("distributed.initialize should only be called once.")
self.service = xe.get_distributed_runtime_service(
coordinator_address, num_processes, use_coordination_service=True
)
if self.client is not None:
raise RuntimeError("distributed.initialize should only be called once.")
# Set init_timeout to 5 min to leave time for all the processes to connect
self.client = xe.get_distributed_runtime_client(
coordinator_address,
process_id,
use_coordination_service=True,
init_timeout=300,
)
self.client.connect()
self.initialize_preemption_sync_manager()
def shutdown(self):
if self.client:
self.client.shutdown()
self.client = None
if self.service:
self.service.shutdown()
self.service = None
if self.preemption_sync_manager:
self.preemption_sync_manager = None
def initialize_preemption_sync_manager(self):
if self.preemption_sync_manager is not None:
raise RuntimeError(
"Preemption sync manager should only be initialized once."
)
self.preemption_sync_manager = xe.create_preemption_sync_manager()
self.preemption_sync_manager.initialize(self.client)
global_state = State()
def initialize(
coordinator_address: str,
num_processes: int,
process_id: int,
local_device_ids: Optional[Union[int, Sequence[int]]] = None,
):
global_state.initialize(
coordinator_address, num_processes, process_id, local_device_ids
)
atexit.register(shutdown)
def shutdown():
global_state.shutdown()
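# Usage sketch (illustration only), assuming each process is launched with its
# own `rank` and all processes share the same coordinator address; `shutdown` is
# registered automatically via atexit inside `initialize`.
def _demo_distributed_setup(rank: int, world_size: int):
    initialize(
        coordinator_address="127.0.0.1:12345",
        num_processes=world_size,
        process_id=rank,
        local_device_ids=rank,  # one visible device per process
    )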
from functools import lru_cache, partial
import numpy as np
from ..tensor import Parameter as MgeParameter
from ..tensor import Tensor as MgeTensor
from .lib import xla_client as xc
_python_scalar_dtype_to_npdtypes = {
bool: np.dtype("bool"),
int: np.dtype("int64"),
float: np.dtype("float64"),
complex: np.dtype('complex128'),
}
_python_scalar_dtypes = list(_python_scalar_dtype_to_npdtypes.keys())
bfloat16 = xc.bfloat16
_bfloat16_dtype = np.dtype(bfloat16)
_float_types = [
_bfloat16_dtype,
np.dtype("float16"),
np.dtype("float32"),
np.dtype("float64"),
]
_numpy_scalar_types = {
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.complex64,
np.complex128,
np.bool_,
np.longlong,
np.intc,
} | set(np.dtype(dt).type for dt in _float_types)
_np_types = {np.ndarray} | _numpy_scalar_types
_dtype_to_32bit_dtype = {
np.dtype("int64"): np.dtype("int32"),
np.dtype("uint64"): np.dtype("uint32"),
np.dtype("float64"): np.dtype("float32"),
np.dtype('complex128'): np.dtype('complex64'),
}
def _scalar_type_to_dtype(typ, value):
dtype = canonicalize_dtype(_python_scalar_dtype_to_npdtypes[typ])
if typ is int and value is not None:
if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
raise OverflowError(f"Python int {value} too large to convert to {dtype}")
return dtype
# do not enable x64 because megengine only supports 32-bit dtypes
@lru_cache(maxsize=None)
def canonicalize_dtype(dtype, x64_enabled=False, allow_opaque_dtype=False):
assert allow_opaque_dtype == False and x64_enabled == False
try:
dtype_ = np.dtype(dtype)
except TypeError as e:
raise TypeError(f"dtype {dtype!r} not understood") from e
if x64_enabled:
return dtype_
else:
return _dtype_to_32bit_dtype.get(dtype_, dtype_)
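# Self-check sketch (illustration only): with x64 disabled, 64-bit dtypes are
# narrowed to their 32-bit counterparts while other dtypes pass through unchanged.
def _demo_canonicalize_dtype():
    assert canonicalize_dtype(np.dtype("float64")) == np.dtype("float32")
    assert canonicalize_dtype(np.dtype("int64")) == np.dtype("int32")
    assert canonicalize_dtype(np.dtype("float32")) == np.dtype("float32")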
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, canonicalize_dtype(x.dtype))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(x, canonicalize_dtype(_scalar_type_to_dtype(typ, x)))
def _canonicalize_mgetensor_dtype(x: MgeTensor):
canonicalized = canonicalize_dtype(x.dtype)
if canonicalized != x.dtype:
return x.astype(canonicalized)
return x
canonicalize_args_handlers = {}
canonicalize_args_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in _numpy_scalar_types
)
canonicalize_args_handlers[np.ndarray] = _canonicalize_ndarray_dtype
canonicalize_args_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _python_scalar_dtypes
)
canonicalize_args_handlers[MgeTensor] = _canonicalize_mgetensor_dtype
canonicalize_args_handlers[MgeParameter] = _canonicalize_mgetensor_dtype
def canonicalize_arg(x):
typ = type(x)
handler = canonicalize_args_handlers.get(typ)
if handler:
return handler(x)
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
import io
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Callable, Dict, Sequence, Tuple
import numpy as np
from ..core._imperative_rt import ops as mops
from ..core._imperative_rt.core2 import OpInfo, VarInfo
from . import dtype
from .lib.mlir import ir
from .lib.mlir.dialects import hlo
func_id = 0
def _default_func_name():
global func_id
func_id += 1
return f"please_realize_func_name_system_{func_id}"
def _is_rng_op(opr):
return isinstance(
opr,
(
mops.Dropout,
mops.BetaRNG,
mops.GammaRNG,
mops.GaussianRNG,
mops.PermutationRNG,
mops.PoissonRNG,
mops.ShuffleRNG,
mops.UniformRNG,
),
)
class AbstractVar:
def __init__(self, _id, _shape, _dtype) -> None:
self.id = _id
self.shape = _shape
self.dtype = _dtype
self.bound_data = None
class Pass(ABC):
def __init__(self) -> None:
pass
@abstractmethod
def __call__(self, tr) -> Any:
pass
# xla passes the rng key as a tensor, while mge passes the key as an op param, so we need
# to add a rng-key tensor to the graph and feed it as an input of the graph and of each rng op
class RngKeyAdder(Pass):
def __call__(self, tr) -> Any:
has_rng_opr = False
for eqn in tr.eqns:
if _is_rng_op(eqn.op):
has_rng_opr = True
break
if not has_rng_opr:
return tr
# it should be [2, np.uint64]; however, megengine does not support np.uint64/np.int64/np.uint32
inp_rng_state_var = AbstractVar(tr.next_vid, [2, 2], np.dtype(np.int32))
tr.add_input(inp_rng_state_var)
new_eqns = []
for eqn in tr.eqns:
if not _is_rng_op(eqn.op):
new_eqns.append(eqn)
continue
oup_rng_state_var = AbstractVar(tr.next_vid, [2, 2], np.dtype(np.int32))
tr.add_var(oup_rng_state_var)
inputs, outputs = list(eqn.inputs), list(eqn.outputs)
inputs.append(inp_rng_state_var.id)
outputs.append(oup_rng_state_var.id)
new_eqn = OpInfo(eqn.op, inputs, outputs, eqn.id, eqn.kind)
new_eqns.append(new_eqn)
inp_rng_state_var = oup_rng_state_var
tr.eqns = new_eqns
tr.set_var_as_oup(inp_rng_state_var)
return tr
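# For example, after this pass an rng equation such as
#   out, mask = Dropout(x)
# becomes
#   out, mask, new_state = Dropout(x, rng_state)
# where rng_state is a fresh [2, 2] int32 graph input threaded through every rng
# op, and the final state is additionally marked as a graph output.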
# in megengine, dropout returns a bit-mask, which is hard to represent in xla, so we let xla
# return a uint8 mask, which means the mask is 8 times larger than the megengine one
class DropoutMaskCanonicalizer(Pass):
def __call__(self, tr) -> Any:
for eqn in tr.eqns:
if not isinstance(eqn.op, mops.Dropout):
continue
outputs = list(eqn.outputs)
mask_var = tr.vars[outputs[1]]
new_mask_var = AbstractVar(
mask_var.id, (int(np.prod(mask_var.shape)) * 8,), mask_var.dtype
)
tr.vars[mask_var.id] = new_mask_var
return tr
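# For example, a megengine dropout bit-mask var of shape (125,) (125 bytes
# covering 1000 elements) is rewritten to a flat 1000-element var of the same
# byte dtype, matching the one-uint8-per-element mask produced by the xla lowering.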
class TraceResult:
def __init__(self, traced, func_name=None) -> None:
self.func_name = func_name if func_name is not None else _default_func_name()
self.traced = traced
self.eqns = []
self.vars = {}
self.inputs = []
self.outputs = []
self.consts = []
self.custom_vid = 0
self.effects = []
for var in self.traced.vars:
self.add_var(var)
self.custom_vid = max(var.id + 1, self.custom_vid)
if var.kind == "external" and var.inp_mark:
self.inputs.append(var.id)
if var.data_required:
self.outputs.append(var.id)
if var.kind == "const":
self.consts.append(var.id)
for op in self.traced.ops:
self.eqns.append(op)
@property
def _var_inputs(self):
return [self.vars[i] for i in self.inputs]
@property
def _var_outputs(self):
return [self.vars[i] for i in self.outputs]
@property
def _var_consts(self):
return [self.vars[i] for i in self.consts]
@property
def next_vid(self):
ret = self.custom_vid
self.custom_vid += 1
return ret
def add_var(self, var):
assert var.id not in self.vars
self.vars[var.id] = var
def add_input(self, inp_var):
self.add_var(inp_var)
self.inputs.append(inp_var.id)
def set_var_as_oup(self, oup_var):
assert oup_var.id in self.vars
self.outputs.append(oup_var.id)
def get_var(self, idx):
assert isinstance(idx, int)
return self.vars[idx]
def is_input(self, var):
if isinstance(var, int):
var = self.vars[var]
return var.kind == "external"
def is_output(self, var):
if isinstance(var, int):
var = self.vars[var]
return var.data_required
def _str_var(self, var):
def _str_shape(shp):
return "x".join([str(d) for d in shp])
dtype_to_str = {
"float16": "f16",
"float32": "f32",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
"uint32": "u32",
"uint64": "u64",
"bool": "i1-bool",
}
if isinstance(var, int):
var = self.vars[var]
var_dtype = None
try:
var_dtype = dtype_to_str[str(var.dtype)]
except KeyError:
var_dtype = "unknown"
var_bound_data = (
("," + ",".join(str(var.bound_data).split()))
if var.bound_data is not None and var.bound_data.size < 5
else ""
)
return f"{var.id}%:<{_str_shape(var.shape)},{var_dtype}{var_bound_data}>"
def _str_eqn(self, eqn):
inps = ", ".join(map(self._str_var, eqn.inputs))
oups = ", ".join(map(self._str_var, eqn.outputs))
str_op = str(eqn.op)
if isinstance(eqn.op, mops.Reduce):
assert str(eqn.op.mode).startswith("Reduce.Mode.")
str_op = str_op + str(eqn.op.mode)[len("Reduce.Mode.") :]
ret = f"{oups} = {str_op}({inps})"
return ret
def __str__(self) -> str:
func_inps_str = ", ".join(map(self._str_var, self.inputs))
func_oups_str = ", ".join(map(self._str_var, self.outputs))
func_const_str = "\n ".join(map(self._str_var, self.consts))
ret = f"{self.func_name}({func_inps_str}) -> ({func_oups_str}) {{\n "
if len(self.consts) > 0:
ret += f"const:\n {func_const_str}\n "
ret += "\n ".join(map(self._str_eqn, self.eqns))
ret += "\n}"
return ret
_dtype_to_ir_type: Dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
np.dtype(dtype.bfloat16): ir.BF16Type.get,
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
np.dtype(np.float64): ir.F64Type.get,
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
def mge_dtype_to_ir_type(mge_dtype):
mge_dtype = np.dtype(mge_dtype)
assert isinstance(
mge_dtype, np.dtype
), f"arg should be numpy dtype, but is {mge_dtype}"
ir_type_factory = _dtype_to_ir_type[mge_dtype]
return ir_type_factory()
def mge_varinfo_to_ir_type(mge_varinfo):
assert isinstance(mge_varinfo, (VarInfo, AbstractVar)), "args should be VarInfo"
shape = mge_varinfo.shape
return ir.RankedTensorType.get(shape, mge_dtype_to_ir_type(mge_varinfo.dtype))
def mge_varinfo_to_ir_type_tuple(mge_varinfo):
return (mge_varinfo_to_ir_type(mge_varinfo),)
def make_ir_type_according_meta(src_shape: Tuple, src_dtype: np.dtype):
return ir.RankedTensorType.get(src_shape, mge_dtype_to_ir_type(src_dtype))
def make_ir_type_according_meta_tuple(src_shape: Tuple, src_dtype: np.dtype):
return (make_ir_type_according_meta(src_shape, src_dtype),)
_constant_handlers = {}
def _numpy_array_constant(x: np.ndarray, canonicalize_types) -> Sequence[ir.Value]:
if canonicalize_types:
x = np.asarray(x, dtype.canonicalize_dtype(x.dtype))
element_type = mge_dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
nelems = x.size
x = np.packbits(x, bitorder="little")
if nelems == 1:
x = np.array(0 if x.item() == 0 else 0xFF, np.uint8)
elif x.dtype == dtype.bfloat16:
x = x.view(np.uint16)
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (hlo.ConstantOp(attr).result,)
def _ndarray_constant_handler(
val: np.ndarray, canonicalize_types
) -> Sequence[ir.Value]:
if np.any(np.equal(0, val.strides)) and val.size > 0:
(zero_stride_axes,) = np.where(np.equal(0, val.strides))
(other_axes,) = np.where(np.not_equal(0, val.strides))
collapsed_val = val[
tuple(
0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim)
)
]
if canonicalize_types:
collapsed_val = np.asarray(
collapsed_val, dtype.canonicalize_dtype(collapsed_val.dtype)
)
out = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(
val.shape, mge_dtype_to_ir_type(collapsed_val.dtype)
),
_numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
dense_int_elements(other_axes),
).result
return (out,)
else:
return _numpy_array_constant(val, canonicalize_types)
_constant_handlers[np.ndarray] = _ndarray_constant_handler
for _scalar_type in [
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128,
np.bool_,
np.longlong,
dtype.bfloat16,
]:
_constant_handlers[_scalar_type] = _ndarray_constant_handler
def _python_scalar_constant_handler(dtype, val, canonicalize_dtypes):
return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)
for pt, dt in dtype._python_scalar_dtype_to_npdtypes.items():
_constant_handlers[pt] = partial(_python_scalar_constant_handler, dt)
def _mge_varinfo_constant_handler(val, canonicalize_dtypes):
assert isinstance(val, VarInfo)
assert val.bound_data is not None and val.kind == "const"
assert isinstance(val.bound_data, np.ndarray)
return _numpy_array_constant(
np.asarray(val.bound_data, val.dtype), canonicalize_dtypes
)
_constant_handlers[VarInfo] = _mge_varinfo_constant_handler
def ir_constant_tuple(val: Any, canonicalize_types: bool = True) -> Sequence[ir.Value]:
for t in type(val).__mro__:
handler = _constant_handlers.get(t)
if handler:
out = handler(val, canonicalize_types)
assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
return out
assert False
def ir_constant(val: Any, canonicalize_types: bool = True) -> Sequence[ir.Value]:
values = ir_constant_tuple(val, canonicalize_types=canonicalize_types)
assert len(values) == 1
return values[0]
def token_type() -> Sequence[ir.Type]:
return [hlo.TokenType.get()]
def dummy_token_type_tuple() -> Sequence[ir.Type]:
return make_ir_type_according_meta_tuple((0,), np.bool_)
def dummy_token() -> Sequence[ir.Value]:
return ir_constant_tuple(np.zeros(0, np.bool_))
def i32_attr(i):
return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
def i64_attr(i):
return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
def ui64_attr(i):
return ir.IntegerAttr.get(ir.IntegerType.get_unsigned(64), i)
def f32_attr(i):
return ir.FloatAttr.get(ir.F32Type.get(), i)
def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr:
lhs_prec = str(lhs_prec)
rhs_prec = str(rhs_prec)
assert lhs_prec == "float32"
assert rhs_prec == "float32"
dtype_to_precision = {
"float32": "DEFAULT",
}
precision = (dtype_to_precision[lhs_prec], dtype_to_precision[rhs_prec])
return ir.ArrayAttr.get([hlo.PrecisionAttr.get(p) for p in precision])
def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
a = np.packbits(np.array(xs, np.bool_), bitorder="little")
if len(xs) == 1:
a = np.array(0 if a.item() == 0 else 0xFF, np.uint8)
return ir.DenseElementsAttr.get(
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)]
)
def get_irnode_shape(irnode):
if isinstance(irnode, (list, tuple, ir.OpResultList)):
assert len(irnode) == 1
irnode = irnode[0]
assert isinstance(irnode, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult))
if not isinstance(irnode, ir.RankedTensorType):
irnode = ir.RankedTensorType(irnode.type)
return tuple(irnode.shape)
def get_irnode_dtype(irnode):
if isinstance(irnode, (list, tuple, ir.OpResultList)):
assert len(irnode) == 1
irnode = irnode[0]
assert isinstance(
irnode, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult)
), type(irnode)
if not isinstance(irnode, ir.RankedTensorType):
irnode = ir.RankedTensorType(irnode.type)
etype = irnode.element_type
for k, v in _dtype_to_ir_type.items():
if etype == v():
return k
assert False, f"unknown irnode {irnode}"
def module_to_string(module: ir.Module) -> str:
output = io.StringIO()
module.operation.print(
file=output, enable_debug_info=True, print_generic_op_form=False
)
return output.getvalue()
def module_to_bytecode(module: ir.Module) -> bytes:
output = io.BytesIO()
module.operation.write_bytecode(file=output)
return output.getvalue()
import os
import platform
import re
import warnings
from typing import Optional, Tuple
import jaxlib.cpu_feature_guard as cpu_feature_guard
import jaxlib.ducc_fft as ducc_fft
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error
import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error
import jaxlib.lapack as lapack
import jaxlib.xla_client as xla_client
try:
import jaxlib as jaxlib
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"megengine with xla requires jaxlib to be installed."
) from err
# some version check code
"""
import jax.version
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
import jaxlib.version
except Exception as err:
# jaxlib is too old to have version number.
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
raise ImportError(msg) from err
# Checks the jaxlib version before importing anything else from jaxlib.
# Returns the jaxlib version string.
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
minimum_jaxlib_version: str):
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
# PEP440 allows a number of non-numeric suffixes, which we allow also.
# We currently do not allow an epoch.
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")
def _parse_version(v: str) -> Tuple[int, ...]:
m = version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse jaxlib version '{v}'")
return tuple(int(x) for x in m.group(0).split('.'))
_jax_version = _parse_version(jax_version)
_minimum_jaxlib_version = _parse_version(minimum_jaxlib_version)
_jaxlib_version = _parse_version(jaxlib_version)
if _jaxlib_version < _minimum_jaxlib_version:
msg = (f'jaxlib is version {jaxlib_version}, but this version '
f'of jax requires version >= {minimum_jaxlib_version}.')
raise RuntimeError(msg)
if _jaxlib_version > _jax_version:
msg = (f'jaxlib version {jaxlib_version} is newer than and '
f'incompatible with jax version {jax_version}. Please '
'update your jax and/or jaxlib packages.')
raise RuntimeError(msg)
return _jaxlib_version
version_str = jaxlib.version.__version__
version = check_jaxlib_version(
jax_version=jax.version.__version__,
jaxlib_version=jaxlib.version.__version__,
minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
"""
# Before importing any C compiled modules from jaxlib, first import the CPU
# feature guard module to verify that jaxlib was compiled in a way that only
# uses instructions that are present on this machine.
cpu_feature_guard.check_cpu_features()
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
# branch on the Jax github.
xla_extension_version = getattr(xla_client, "_version", 0)
# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version
try:
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
except ImportError:
tpu_driver_client = None # type: ignore
# TODO: check if we need the same for rocm.
cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):
cuda_path = None
transfer_guard_lib = xla_client._xla.transfer_guard_lib
This diff is collapsed.
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.ml_program as ml_program
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
import jaxlib.mlir.dialects.stablehlo as stablehlo
import jaxlib.xla_client as xla_client
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
use_stablehlo = xla_client.mlir_api_version >= 42
if use_stablehlo:
import jaxlib.mlir.dialects.stablehlo as hlo
else:
import jaxlib.mlir.dialects.mhlo as hlo # type: ignore[no-redef]
import logging
import os
import platform as py_platform
import threading
import warnings
from functools import lru_cache, partial
from typing import Any, Dict, List, Optional, Union
import numpy as np
from jaxlib import xla_client
from ..lib import cuda_path
from .config import bool_env, config, flags, int_env
XlaBackend = xla_client._xla.Client
ShardedBuffer = Any
FLAGS = flags.FLAGS
logger = logging.getLogger(__name__)
flags.DEFINE_string(
"jax_xla_backend", "", "Deprecated, please use --jax_platforms instead."
)
flags.DEFINE_string(
"jax_backend_target",
os.getenv("JAX_BACKEND_TARGET", "").lower(),
'Either "local" or "rpc:address" to connect to a remote service target.',
)
# TODO: warn when this is used once we test out --jax_platforms a bit
flags.DEFINE_string(
"jax_platform_name",
os.getenv("JAX_PLATFORM_NAME", "").lower(),
"Deprecated, please use --jax_platforms instead.",
)
flags.DEFINE_bool(
"jax_disable_most_optimizations",
bool_env("JAX_DISABLE_MOST_OPTIMIZATIONS", False),
"Try not to do much optimization work. This can be useful if the cost of "
"optimization is greater than that of running a less-optimized program.",
)
flags.DEFINE_integer(
"jax_xla_profile_version",
int_env("JAX_XLA_PROFILE_VERSION", 0),
"Optional profile version for XLA compilation. "
"This is meaningful only when XLA is configured to "
"support the remote compilation profile feature.",
)
flags.DEFINE_string(
"jax_cuda_visible_devices",
"all",
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
"comma-separate list of integer device IDs.",
)
flags.DEFINE_string(
"jax_rocm_visible_devices",
"all",
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
"comma-separate list of integer device IDs.",
)
def get_compile_options(
num_replicas: int,
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape=[],
auto_spmd_partitioning_mesh_ids=[],
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
if use_auto_spmd_partitioning:
build_options.auto_spmd_partitioning_mesh_shape = (
auto_spmd_partitioning_mesh_shape
)
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
if device_assignment is not None:
logger.debug(
"get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s",
num_replicas,
num_partitions,
device_assignment,
)
device_assignment = np.array(device_assignment)
# Allow 1D device assignment if num_partitions is 1.
if (device_assignment.ndim == 1) and (num_partitions == 1):
device_assignment = device_assignment[:, None]
if num_replicas != device_assignment.shape[0]:
msg = "device_assignment does not match num_replicas: {} vs {}."
raise ValueError(msg.format(device_assignment, num_replicas))
if num_partitions != device_assignment.shape[1]:
msg = "device_assignment does not match num_partitions: {} vs {}."
raise ValueError(msg.format(device_assignment, num_partitions))
if device_assignment.dtype == object:
device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
device_assignment
)
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
assert device_assignment.replica_count() == num_replicas
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
debug_options = compile_options.executable_build_options.debug_options
if cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = cuda_path
if FLAGS.jax_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = FLAGS.jax_xla_profile_version
return compile_options
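# Usage sketch (illustration only): the plain single-device case compiles with
# one replica and one partition, leaving every other option at its default.
def _demo_compile_options():
    return get_compile_options(num_replicas=1, num_partitions=1)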
# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}
_default_backend = None
_backends: Dict[str, Any] = {}
_backends_errors: Dict[str, str] = {}
_backend_lock = threading.Lock()
def register_backend_factory(name, factory, *, priority=0):
with _backend_lock:
if name in _backends:
raise RuntimeError(f"Backend {name} already initialized")
_backend_factories[name] = (factory, priority)
register_backend_factory(
"interpreter", xla_client.make_interpreter_client, priority=-100
)
register_backend_factory(
"cpu", partial(xla_client.make_cpu_client, use_tfrt=True), priority=0
)
def make_gpu_client(*, platform_name, visible_devices_flag):
from ..distribute import global_state
visible_devices = global_state.visible_devices
if visible_devices != "all":
allowed_devices = {int(x) for x in visible_devices.split(",")}
else:
allowed_devices = None
return xla_client.make_gpu_client(
distributed_client=global_state.client,
node_id=global_state.process_id,
platform_name=platform_name,
allowed_devices=allowed_devices,
)
if hasattr(xla_client, "make_gpu_client"):
register_backend_factory(
"cuda",
partial(
make_gpu_client,
platform_name="cuda",
visible_devices_flag="jax_cuda_visible_devices",
),
priority=200,
)
register_backend_factory(
"rocm",
partial(
make_gpu_client,
platform_name="rocm",
visible_devices_flag="jax_rocm_visible_devices",
),
priority=200,
)
if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if jax has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory(
"plugin", xla_client.make_plugin_device_client, priority=400
)
_platform_aliases = {
"cuda": "gpu",
"rocm": "gpu",
}
_alias_to_platforms: Dict[str, List[str]] = {}
for _platform, _alias in _platform_aliases.items():
_alias_to_platforms.setdefault(_alias, []).append(_platform)
def is_known_platform(platform: str):
# A platform is valid if there is a registered factory for it. It does not
# matter if we were unable to initialize that platform; we only care that
# we've heard of it and it isn't, e.g., a typo.
return platform in _backend_factories.keys() or platform in _platform_aliases.keys()
def canonicalize_platform(platform: str) -> str:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms = _alias_to_platforms.get(platform, None)
if platforms is None:
return platform
b = backends()
for p in platforms:
if p in b.keys():
return p
raise RuntimeError(
f"Unknown backend: '{platform}' requested, but no "
f"platforms that are instances of {platform} are present. "
"Platforms present are: " + ",".join(b.keys())
)
def expand_platform_alias(platform: str) -> List[str]:
"""Expands, e.g., "gpu" to ["cuda", "rocm"].
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
return _alias_to_platforms.get(platform, [platform])
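# For example, with the alias table above:
#   expand_platform_alias("gpu") -> ["cuda", "rocm"]
#   expand_platform_alias("cpu") -> ["cpu"]
# and canonicalize_platform("gpu") returns whichever of "cuda"/"rocm" has an
# initialized backend, raising a RuntimeError if neither is present.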
def is_gpu(platform):
return platform in ("cuda", "rocm")
def backends():
global _backends
global _backends_errors
global _default_backend
with _backend_lock:
if _backends:
return _backends
if config.jax_platforms:
jax_platforms = config.jax_platforms.split(",")
platforms = []
# Allow platform aliases in the list of platforms.
for platform in jax_platforms:
platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1)
platforms_and_priorites = zip(platforms, priorities)
else:
platforms_and_priorites = (
(platform, priority)
for platform, (_, priority) in _backend_factories.items()
)
default_priority = -1000
if hasattr(xla_client, "maybe_load_pjrt_plugins"):
xla_client.maybe_load_pjrt_plugins()
for platform, priority in platforms_and_priorites:
try:
backend = _init_backend(platform)
_backends[platform] = backend
if priority > default_priority:
_default_backend = backend
default_priority = priority
except Exception as err:
if platform in ("cpu", "interpreter"):
# We always expect the CPU and interpreter backends to initialize
# successfully.
raise
else:
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
err_msg = f"Unable to initialize backend '{platform}': {err}"
if config.jax_platforms:
err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
raise RuntimeError(err_msg)
else:
_backends_errors[platform] = str(err)
logger.info(err_msg)
continue
# We don't warn about falling back to CPU on Mac OS, because we don't
# support anything else there at the moment and warning would be pointless.
if (
py_platform.system() != "Darwin"
and _default_backend.platform == "cpu"
and FLAGS.jax_platform_name != "cpu"
):
logger.warning(
"No GPU/TPU found, falling back to CPU. "
"(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
)
return _backends
def _clear_backends():
global _backends
global _backends_errors
global _default_backend
logger.info("Clearing JAX backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
_default_backend = None
get_backend.cache_clear()
def _init_backend(platform):
factory, unused_priority = _backend_factories.get(platform, (None, None))
if factory is None:
raise RuntimeError(f"Unknown backend '{platform}'")
logger.debug("Initializing backend '%s'", platform)
backend = factory()
if backend is None:
raise RuntimeError(f"Could not initialize backend '{platform}'")
if backend.device_count() == 0:
raise RuntimeError(f"Backend '{platform}' provides no devices.")
logger.debug("Backend '%s' initialized", platform)
return backend
def _get_backend_uncached(platform=None):
if not isinstance(platform, (type(None), str)):
return platform
platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None
bs = backends()
if platform is not None:
platform = canonicalize_platform(platform)
backend = bs.get(platform, None)
if backend is None:
if platform in _backends_errors:
raise RuntimeError(
f"Backend '{platform}' failed to initialize: "
f"{_backends_errors[platform]}"
)
raise RuntimeError(f"Unknown backend {platform}")
return backend
else:
return _default_backend
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
return _get_backend_uncached(platform)
def get_device_backend(device=None):
"""Returns the Backend associated with `device`, or the default Backend."""
if device is not None:
return device.client
return get_backend()
def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Number of devices.
"""
return int(get_backend(backend).device_count())
def local_device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the number of devices addressable by this process."""
return int(get_backend(backend).local_device_count())
def devices(
backend: Optional[Union[str, XlaBackend]] = None
) -> List[xla_client.Device]:
"""Returns a list of all devices for a given backend.
.. currentmodule:: jaxlib.xla_extension
Each device is represented by a subclass of :class:`Device` (e.g.
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by
:py:func:`jax.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
otherwise ``'cpu'``.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
return get_backend(backend).devices()
def default_backend() -> str:
"""Returns the platform name of the default XLA backend."""
return get_backend(None).platform
def local_devices(
process_index: Optional[int] = None,
backend: Optional[Union[str, XlaBackend]] = None,
host_id: Optional[int] = None,
) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``range(jax.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
if host_id is not None:
warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code."
)
process_index = host_id
if process_index is None:
process_index = get_backend(backend).process_index()
if not (0 <= process_index < process_count()):
raise ValueError(f"Unknown process_index {process_index}")
return [d for d in devices(backend) if d.process_index == process_index]
def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the integer process index of this process.
On most platforms, this will always be 0. This will vary on multi-process
platforms though.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Integer process index.
"""
return get_backend(backend).process_index()
# returns the number of mge processes associated with the backend
def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
return max(d.process_index for d in devices(backend)) + 1
import dataclasses
import itertools
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt.core2 import OpInfo, VarInfo
from . import utils
from .device import xb
from .ir_utils import (
TraceResult,
ir_constant_tuple,
mge_varinfo_to_ir_type_tuple,
)
from .lib import xla_client as xc
from .lib.mlir import dialects, ir
from .lib.mlir.dialects import func as func_dialect
from .rules import get_rule
from .rules.hlotensor import HLOTensor
from .rules.utils import _shape_equal
from .sharding import sharded_val
def make_ir_context() -> ir.Context:
context = ir.Context()
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.stablehlo.register_dialect(context)
return context
@dataclasses.dataclass
class ModuleContext:
context: ir.Context
module: ir.Module
ip: ir.InsertionPoint
symbol_table: ir.SymbolTable
backend_or_name: Optional[Union[str, xb.XlaBackend]]
platform: str
keepalives: List[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
# Stores the values of varinfos that can be inferred during the lowering process
inferred_values: Dict[VarInfo, np.ndarray]
def __init__(
self,
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
keepalives: Optional[List[Any]] = None,
host_callbacks: Optional[List[Any]] = None,
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None,
):
assert platform is not None
self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.body)
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
self.backend_or_name = backend_or_name
self.platform = platform
self.keepalives = keepalives if keepalives is not None else []
self.host_callbacks = host_callbacks if host_callbacks is not None else []
self.inferred_values = {}
@property
def backend(self) -> xb.XlaBackend:
if self.backend_or_name is None or isinstance(self.backend_or_name, str):
return xb.get_backend(self.backend_or_name)
return self.backend_or_name
def replace(self, **kw):
return dataclasses.replace(self, **kw)
def get_value(self, varinfo):
assert varinfo in self.inferred_values
return self.inferred_values[varinfo]
def set_value(self, varinfo, value):
self.inferred_values[varinfo] = value
@dataclasses.dataclass
class LoweringRuleContext:
module_context: ModuleContext
op: OpInfo
vars_in: Sequence[VarInfo]
vars_out: Sequence[VarInfo]
param: Dict = None
def replace(self, **kw):
return dataclasses.replace(self, **kw)
def _unwrap_singleton_ir_values(x):
return x[0] if len(x) == 1 else x
def _wrap_singleton_ir_values(
x: Union[ir.Value, Sequence[ir.Value]]
) -> Sequence[ir.Value]:
return (x,) if isinstance(x, ir.Value) else tuple(x)
def lowering_ops(
ctx: ModuleContext, trace_result: TraceResult, *args: Sequence[ir.Value],
):
# var_id -> ir.Value
env: Dict[int, Tuple[ir.Value, ...]] = {}
consts = list(map(ir_constant_tuple, trace_result._var_consts))
# read ir.Values from env according to var_ids
def read(var_ids):
assert isinstance(var_ids, (list, tuple))
ret = []
for vid in var_ids:
assert isinstance(vid, int)
ret.append(env[vid])
return ret
# update env with var_ids and ir.Values
def write(var_ids, hlo_nodes):
assert isinstance(var_ids, (list, tuple))
assert isinstance(hlo_nodes, (map, list, tuple))
hlo_nodes = list(hlo_nodes)
assert len(var_ids) == len(hlo_nodes), (len(var_ids), len(hlo_nodes))
for vid, node in zip(var_ids, hlo_nodes):
assert vid not in env
env[vid] = node
assert len(args) == len(trace_result.inputs)
assert len(consts) == len(trace_result.consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs)
# initialize env with inputs and consts
write(trace_result.inputs, args)
write(trace_result.consts, consts)
for eqn in trace_result.eqns:
rule_ctx = LoweringRuleContext(
module_context=ctx,
op=eqn.op,
vars_in=[trace_result.vars[inp] for inp in eqn.inputs],
vars_out=[trace_result.vars[oup] for oup in eqn.outputs],
param=eqn.param,
)
rule = get_rule(eqn.op)
in_nodes = read(eqn.inputs)
hinps = [
HLOTensor(irval, var.shape, var.dtype)
for var, irval in zip(
rule_ctx.vars_in, map(_unwrap_singleton_ir_values, in_nodes)
)
]
houps = rule(rule_ctx, *hinps)
if isinstance(houps, HLOTensor):
houps = [houps]
out_nodes = []
for out_id, hlo_out in zip(eqn.outputs, houps):
var_out = trace_result.vars[out_id]
assert _shape_equal(
var_out.shape, hlo_out.shape
), f"{eqn.op}: {var_out.shape} != {hlo_out.shape}"
out_nodes.append(hlo_out.tensor)
out_nodes = tuple(map(_wrap_singleton_ir_values, out_nodes))
write(eqn.outputs, out_nodes)
return read(trace_result.outputs)
def make_xla_graph(
ctx: ModuleContext,
name: str,
trace_result: TraceResult,
public: bool = True,
in_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
out_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
) -> func_dialect.FuncOp:
assert public is True, "do not process the visibility of the function"
assert (
in_shardings is None and out_shardings is None
), "sharding when lowering is not supported yet"
assert (
input_output_aliases is None or input_output_aliases == []
), "donated inputs are not supported yet"
input_types = [
mge_varinfo_to_ir_type_tuple(trace_result.vars[idx])
for idx in trace_result.inputs
]
output_types = [
mge_varinfo_to_ir_type_tuple(trace_result.vars[idx])
for idx in trace_result.outputs
]
flat_input_types = utils.flatten_list(input_types)
flat_output_types = utils.flatten_list(output_types)
assert len(flat_input_types) == len(trace_result.inputs)
assert len(flat_output_types) == len(trace_result.outputs)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private"
)
ctx.symbol_table.insert(func_op)
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
unflattened_args = utils.unflatten_list(flat_args, map(len, input_types))
outs = lowering_ops(ctx, trace_result, *unflattened_args)
flat_oups = utils.flatten_list(outs)
func_dialect.ReturnOp(flat_oups)
return func_op
def lower(
trace_result: TraceResult,
backend,
platform,
in_shardings=None,
out_shardings=None,
donated_invars=None,
):
assert donated_invars is None, "donated inputs are not supported yet"
assert trace_result.effects == [], "effect of trace is not supported"
if in_shardings is not None:
trace_result.inputs = [
sharded_val(inp, in_sharding)
for inp, in_sharding in zip(trace_result.inputs, in_shardings)
]
if out_shardings is not None:
trace_result.outputs = [
sharded_val(outp, out_sharding)
for outp, out_sharding in zip(trace_result.outputs, out_shardings)
]
ctx = ModuleContext(backend, platform)
with ctx.context, ir.Location.unknown(ctx.context):
module_name = trace_result.func_name
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
assert trace_result.effects == [], "effect of trace is not supported"
make_xla_graph(
ctx,
"main",
trace_result,
public=True,
in_shardings=None,
out_shardings=None,
input_output_aliases=[],
)
return ctx.module, ctx.keepalives, ctx.host_callbacks
from . import (
communicate,
elemwise,
indexing,
math,
nn,
normalize,
random,
reduction,
tensor,
trivial,
)
from .utils import get_rule
import itertools
from functools import partial
from typing import Sequence, Union
import numpy as np
from ...core._imperative_rt import ops as mops
from .. import ir_utils
from ..lib.mlir import ir
from ..lib.mlir.dialects import hlo
from .hlotensor import HLOTensor
from .tensor import concat, split
from .utils import register_lower_rule
@register_lower_rule(mops.ParamPackConcat)
def parampack_concat_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
flattened = []
for arg, var_in in zip(args[:-1], ctx.vars_in[:-1]):
ishape_1d = (int(np.prod(var_in.shape)),)
flattened.append(arg.reshape(ishape_1d))
concated = concat(flattened, 0)
return concated
@register_lower_rule(mops.ParamPackSplit)
def parampack_split_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
offsets, shapes, var_outs = ctx.op.offsets, ctx.op.shapes, ctx.vars_out
assert (len(offsets) // 2) == len(shapes) == len(var_outs), "error params"
for var_out, shape in zip(var_outs, shapes):
assert tuple(var_out.shape) == tuple(shape), f"{var_out.shape} .vs {shape}"
sections = [np.prod(shape) for shape in shapes]
for i, section in enumerate(sections):
assert section == offsets[2 * i + 1] - offsets[2 * i], "error offsets"
pieces = split(args[0], sections, axis=0)
outputs = [piece.reshape(var_out.shape) for piece, var_out in zip(pieces, var_outs)]
return outputs
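# Note: ParamPackSplit's `offsets` stores a (start, end) pair per output in
# flattened-element units, so offsets[2 * i + 1] - offsets[2 * i] must equal
# prod(shapes[i]); the packed 1-D input is split into these sections along axis 0
# and each piece is reshaped back to its recorded output shape.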
def _all_reduce(reducer, inp, world_size):
def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]):
groups = np.array(
list(itertools.zip_longest(*replica_groups, fillvalue=-1)), dtype=np.int64
).T
return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups))
replica_groups = _replica_groups_hlo([[i for i in range(world_size)]])
hlo_cfgs = {}
all_reduce_op = hlo.AllReduceOp(
inp.tensor.type, inp.tensor, replica_groups=replica_groups, **hlo_cfgs
)
scalar_type = ir_utils.make_ir_type_according_meta(tuple(), inp.dtype)
reducer_region = all_reduce_op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
reducer_ret = reducer(*reducer_region.arguments)
hlo.ReturnOp(reducer_ret.results)
return HLOTensor(all_reduce_op.results)
all_reduce_sum = partial(_all_reduce, hlo.AddOp)
all_reduce_prod = partial(_all_reduce, hlo.MulOp)
all_reduce_min = partial(_all_reduce, hlo.MinOp)
all_reduce_max = partial(_all_reduce, hlo.MaxOp)
@register_lower_rule(mops.CollectiveComm)
def collective_comm_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 1, "collective comm only supports one input"
if ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_SUM:
ret = all_reduce_sum(args[0], ctx.op.nr_devices)
elif ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_PROD:
ret = all_reduce_prod(args[0], ctx.op.nr_devices)
elif ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_MIN:
ret = all_reduce_min(args[0], ctx.op.nr_devices)
elif ctx.op.mode == mops.CollectiveComm.Mode.ALL_REDUCE_MAX:
ret = all_reduce_max(args[0], ctx.op.nr_devices)
else:
assert False, f"not support mode{ctx.op.mode}"
return ret
import math
from functools import partial
from typing import Sequence, Union
import numpy as np
from ...core._imperative_rt import ops as mops
from .. import ir_utils
from ..lib.mlir.dialects import hlo
from .hlotensor import HLOTensor
from .utils import register_lower_rule
def _infer_elemwise_oshape(inp_shapes):
def _infer_binary_elemwise_oshape(lhs_shape, rhs_shape):
if len(lhs_shape) == 0:
return rhs_shape
if len(rhs_shape) == 0:
return lhs_shape
if np.prod(lhs_shape) == 1 and len(rhs_shape) != 0:
return rhs_shape
if np.prod(rhs_shape) == 1 and len(lhs_shape) != 0:
return lhs_shape
oshape = []
if len(lhs_shape) == len(rhs_shape):
for l, r in zip(lhs_shape, rhs_shape):
if l == r:
oshape.append(l)
elif l == 1:
oshape.append(r)
elif r == 1:
oshape.append(l)
else:
assert False, f"infer elemwise shape error: {lhs_shape} {rhs_shape}"
else:
shorter = lhs_shape if len(lhs_shape) < len(rhs_shape) else rhs_shape
longer = lhs_shape if len(lhs_shape) > len(rhs_shape) else rhs_shape
right_part = longer[-len(shorter) :]
for l, s in zip(right_part, shorter):
assert (
l == s or s == 1
), f"infer elemwise shape error: {lhs_shape} {rhs_shape}"
oshape = longer
return oshape
oshape = tuple()
for ishape in inp_shapes:
oshape = _infer_binary_elemwise_oshape(ishape, oshape)
return oshape
def _infer_elemwise_odtype(inp_dtypes):
oup_dtype = inp_dtypes[0]
for inp_dtype in inp_dtypes:
assert (
inp_dtype == oup_dtype
), f"elemwise inputs has different dtype {inp_dtypes}"
return oup_dtype
def _compare(lhs, rhs, mode, comparison_type=None):
"""
mode: can be
'EQ' (equal-to),
'NE' (not equal-to),
'GE' (greater-or-equal-than),
'GT' (greater-than),
'LE' (less-or-equal-than),
'LT' (less-than)
comparison_type: can be 'UNSIGNED', 'SIGNED', 'FLOAT'
"""
lhs = HLOTensor(lhs) if not isinstance(lhs, HLOTensor) else lhs
rhs = HLOTensor(rhs) if not isinstance(rhs, HLOTensor) else rhs
oshape = _infer_elemwise_oshape([lhs.shape, rhs.shape])
lhs = lhs.broadcast_to(oshape)
rhs = rhs.broadcast_to(oshape)
if comparison_type is None:
if lhs.dtype in [np.int64, np.int32, np.int16, np.int8]:
assert rhs.dtype in [np.int64, np.int32, np.int16, np.int8]
comparison_type = "SIGNED"
elif lhs.dtype in [np.uint64, np.uint32, np.uint16, np.uint8]:
assert rhs.dtype in [np.uint64, np.uint32, np.uint16, np.uint8]
comparison_type = "UNSIGNED"
elif lhs.dtype in [np.float64, np.float32, np.float16]:
assert rhs.dtype in [np.float64, np.float32, np.float16]
comparison_type = "FLOAT"
else:
assert False, f"invalid dtype for compare {lhs.dtype} .vs {rhs.dtype}"
return HLOTensor(
hlo.CompareOp(
lhs.tensor,
rhs.tensor,
hlo.ComparisonDirectionAttr.get(mode),
compare_type=hlo.ComparisonTypeAttr.get(comparison_type),
).result
)
def _elemwise(hlo_op, inps):
hinps = [HLOTensor(inp) if not isinstance(inp, HLOTensor) else inp for inp in inps]
ishapes = [inp.shape for inp in hinps]
idtypes = [inp.dtype for inp in hinps]
oshape = _infer_elemwise_oshape(ishapes)
odtype = _infer_elemwise_odtype(idtypes)
broadcasted_inps = [hinp.broadcast_to(oshape) for hinp in hinps]
results = hlo_op(*[binp.tensor for binp in broadcasted_inps]).results
assert len(results) == 1, f"elemwise op {hlo_op} should have only one output"
return HLOTensor(results[0], oshape, odtype)
def _elemwise_unary(hlo_op, a):
return _elemwise(hlo_op, [a])
def _elemwise_binary(hlo_op, a, b):
return _elemwise(hlo_op, [a, b])
neg = partial(_elemwise_unary, hlo.NegOp)
abs = partial(_elemwise_unary, hlo.AbsOp)
tanh = partial(_elemwise_unary, hlo.TanhOp)
exp = partial(_elemwise_unary, hlo.ExpOp)
sqrt = partial(_elemwise_unary, hlo.SqrtOp)
log = partial(_elemwise_unary, hlo.LogOp)
add = partial(_elemwise_binary, hlo.AddOp)
sub = partial(_elemwise_binary, hlo.SubtractOp)
mul = partial(_elemwise_binary, hlo.MulOp)
div = partial(_elemwise_binary, hlo.DivOp)
pow = partial(_elemwise_binary, hlo.PowOp)
equal = partial(_compare, mode="EQ")
not_equal = partial(_compare, mode="NE")
greater = partial(_compare, mode="GT")
greater_equal = partial(_compare, mode="GE")
less = partial(_compare, mode="LT")
less_equal = partial(_compare, mode="LE")
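# Minimal sketch (illustrative; `_elemwise_wrappers_example` is hypothetical and not
# part of the module): the wrappers above also accept python scalars and numpy arrays,
# which are lifted to constants and broadcast to a common shape.
def _elemwise_wrappers_example(x):
    y = add(mul(x, 2.0), 1.0)  # computes 2 * x + 1 element-wise
    return less_equal(y, 0.0)  # element-wise predicate with y's shape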
def abs_grad(x, dy):
return (x / abs(x)) * dy
def tanh_grad(x, dy):
return (1.0 - tanh(x) ** 2.0) * dy
def bitcast(inp, oshape, odtype):
odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
return HLOTensor(
hlo.BitcastConvertOp(
ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor
).result
)
def typecvt(inp, odtype):
odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
return HLOTensor(
hlo.ConvertOp(
ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor
).result
)
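# Note (illustrative): `bitcast` reinterprets the raw bits, so the shape may change when
# the element width changes (e.g. a (2, 2) int32 buffer viewed as (2,) uint64), while
# `typecvt` is a value-preserving dtype conversion that keeps the shape.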
def gelu(inp, approximate: bool = True):
    if approximate:
        # tanh approximation:
        #   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
        sqrt_2_over_pi = np.sqrt(2.0 / np.pi)
        a = inp ** 3.0
        b = 0.044715 * a
        c = inp + b
        d = sqrt_2_over_pi * c
        e = tanh(d)
        f = 1.0 + e
        g = 0.5 * f
        h = inp * g
    else:
        assert False, "only approximate gelu is supported"
    return h
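# Pure-numpy cross-check of the tanh approximation above, added for illustration only
# (the helper name, evaluation grid and tolerance are assumptions, not part of the rules).
def _gelu_tanh_approx_check():
    import math

    xs = np.linspace(-4.0, 4.0, 101)
    exact = 0.5 * xs * (1.0 + np.vectorize(math.erf)(xs / math.sqrt(2.0)))
    approx = 0.5 * xs * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (xs + 0.044715 * xs ** 3)))
    assert np.max(np.abs(exact - approx)) < 5e-3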
def erfcc(inp):
    # rational approximation adapted from the Numerical Recipes "erfcc" routine:
    # it evaluates ans ~= erfc(|x|) and then returns sign(x) * (1 - ans) ~= erf(x)
    _a = abs(inp)
    _b = 0.5 * _a
    _c = 1.0 + _b
    _d = 1.0 / _c
_e = _d * 0.17087277
_f = -0.82215223 + _e
_g = _d * _f
_h = 1.48851587 + _g
_i = _d * _h
_j = -1.13520398 + _i
_k = _d * _j
_l = 0.27886807 + _k
_m = _d * _l
_n = -0.18628806 + _m
_o = _d * _n
_p = 0.09678418 + _o
_q = _d * _p
_r = 0.37409196 + _q
_s = _d * _r
_t = 1.00002368 + _s
_u = _d * _t
_v = inp * inp
_w = -_v
_x = _w - 1.26551223
_y = _x + _u
_z = exp(_y)
_aa = _d * _z
_ab = 1.0 - _aa
_ac = -_ab
_ad = (inp >= 0.0).astype(inp.dtype)
_ae = (inp < 0.0).astype(inp.dtype)
_af = _ad * _ab
_ag = _ae * _ac
ret = _af + _ag
return ret
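# Hedged numerical sanity check of the rational approximation above, rewritten in plain
# numpy (the helper name, evaluation grid and tolerance are illustrative assumptions).
def _erfcc_is_erf_check():
    import math

    xs = np.linspace(-3.0, 3.0, 61)
    t = 1.0 / (1.0 + 0.5 * np.abs(xs))
    poly = -1.26551223 + t * (1.00002368 + t * (0.37409196 + t * (0.09678418 + t * (
        -0.18628806 + t * (0.27886807 + t * (-1.13520398 + t * (1.48851587 + t * (
            -0.82215223 + t * 0.17087277))))))))
    ans = t * np.exp(-xs * xs + poly)       # ~= erfc(|x|)
    approx_erf = np.sign(xs) * (1.0 - ans)  # what erfcc() above computes
    exact_erf = np.vectorize(math.erf)(xs)
    assert np.max(np.abs(approx_erf - exact_erf)) < 1e-6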
def gelu_grad(x, dy, approximate: bool = True):
    if approximate:
        # d/dx gelu(x) = normcdf(x) + x * normpdf(x),
        # with normpdf(x) = exp(-x^2 / 2) / sqrt(2 * pi) and normcdf built from erfcc
        _a = x * x
        _b = -0.5 * _a
        _c = exp(_b)
        phi = 0.3989422804014327 * _c
        _d = x / math.sqrt(2.0)
        _e = erfcc(_d)
        _f = 1.0 + _e
        normcdf_v = 0.5 * _f
        _g = x * phi
        _h = normcdf_v + _g
        ret = dy * _h
    else:
        assert False, "only approximate gelu_grad is supported"
    return ret
def relu(inp):
mask = (inp > 0.0).astype(inp.dtype)
return inp * mask
def relu_grad(x, dy):
mask = (x > 0.0).astype(x.dtype)
return dy * mask
# Elemwise.Mode is unhashable, so we convert it to str
mge_elemwise_to_xla = {
str(mops.Elemwise.Mode.ADD): add,
str(mops.Elemwise.Mode.MUL): mul,
str(mops.Elemwise.Mode.SUB): sub,
str(mops.Elemwise.Mode.EXP): exp,
str(mops.Elemwise.Mode.LOG): log,
str(mops.Elemwise.Mode.GELU): gelu,
str(mops.Elemwise.Mode.GELU_GRAD): gelu_grad,
str(mops.Elemwise.Mode.TRUE_DIV): div,
str(mops.Elemwise.Mode.NEGATE): neg,
str(mops.Elemwise.Mode.ABS): abs,
str(mops.Elemwise.Mode.ABS_GRAD): abs_grad,
str(mops.Elemwise.Mode.TANH): tanh,
str(mops.Elemwise.Mode.TANH_GRAD): tanh_grad,
str(mops.Elemwise.Mode.SQRT): sqrt,
str(mops.Elemwise.Mode.POW): pow,
str(mops.Elemwise.Mode.RELU): relu,
str(mops.Elemwise.Mode.EQ): equal,
str(mops.Elemwise.Mode.NEQ): not_equal,
str(mops.Elemwise.Mode.LT): less,
str(mops.Elemwise.Mode.LEQ): less_equal,
str(mops.Elemwise.Mode.SWITCH_GT0): relu_grad,
}
@register_lower_rule(mops.Elemwise)
def elemwise_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert isinstance(ctx.op, mops.Elemwise), "op should be elemwise here"
assert (
len(ctx.vars_out) == 1
), f"elemwise output num should be 1, got {len(ctx.vars_out)}"
handle = mge_elemwise_to_xla[str(ctx.op.mode)]
oup = handle(*args)
return oup
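# Dispatch sketch (illustrative): for an ADD op the table above resolves to the `add`
# wrapper, so lowering an elemwise node boils down to
#     handle = mge_elemwise_to_xla[str(mops.Elemwise.Mode.ADD)]  # -> add
#     out = handle(lhs_hlo, rhs_hlo)
# where lhs_hlo / rhs_hlo are hypothetical HLOTensor operands of the node.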
@register_lower_rule(mops.ElemwiseMultiType)
def elemwise_multi_type_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
opr = ctx.op
mode = "Elemwise.Mode." + str(opr.mode).split(".")[-1]
handle = mge_elemwise_to_xla[mode]
oup = handle(*args).astype(opr.dtype)
return oup
from typing import Sequence
import numpy as np
from .. import ir_utils
from ..ir_utils import get_irnode_dtype, get_irnode_shape
from ..lib.mlir import ir
from .utils import _check_dtype, _check_shape
class HLOTensor:
def __init__(self, tensor, shape=None, dtype=None) -> None:
if isinstance(tensor, Sequence):
assert len(tensor) > 0, "cannot create HLOTensor from empty sequence"
if isinstance(tensor[0], int):
tensor = np.array(tensor)
else:
assert len(tensor) == 1, f"cannot create HLOTensor from {tensor}"
tensor = tensor[0]
if isinstance(tensor, ir.OpResultList):
assert len(tensor) == 1, f"cannot create HLOTensor from {tensor}"
tensor = tensor[0]
if isinstance(
tensor, (int, float, np.int_, np.float16, np.float32, np.float64)
):
tensor = ir_utils.ir_constant(tensor)
elif isinstance(tensor, np.ndarray):
tensor = ir_utils.ir_constant(tensor)
else:
pass
assert isinstance(
tensor, (ir.RankedTensorType, ir.BlockArgument, ir.OpResult)
), type(tensor)
infered_shape = get_irnode_shape(tensor)
infered_dtype = get_irnode_dtype(tensor)
_check_shape(infered_shape, shape)
_check_dtype(infered_dtype, dtype)
self._tensor = tensor
self._shape = infered_shape
self._dtype = infered_dtype
@property
def shape(self):
return tuple(self._shape)
@property
def dtype(self):
return self._dtype
@property
def ndim(self):
return len(self.shape)
@property
def tensor(self):
return self._tensor
def __str__(self):
return f"HLOTensor(shape={self.shape}, dtype={self.dtype})"
def __eq__(self, rhs):
from .elemwise import equal
return equal(self, rhs)
def __ne__(self, rhs):
from .elemwise import not_equal
return not_equal(self, rhs)
def __gt__(self, rhs):
from .elemwise import greater
return greater(self, rhs)
def __ge__(self, rhs):
from .elemwise import greater_equal
return greater_equal(self, rhs)
def __lt__(self, rhs):
from .elemwise import less
return less(self, rhs)
def __le__(self, rhs):
from .elemwise import less_equal
return less_equal(self, rhs)
def __neg__(self):
from .elemwise import neg
return neg(self)
def __add__(self, rhs):
from .elemwise import add
return add(self, rhs)
def __radd__(self, rhs):
from .elemwise import add
return add(rhs, self)
def __sub__(self, rhs):
from .elemwise import sub
return sub(self, rhs)
def __rsub__(self, rhs):
from .elemwise import sub
return sub(rhs, self)
def __mul__(self, rhs):
from .elemwise import mul
return mul(self, rhs)
def __rmul__(self, rhs):
from .elemwise import mul
return mul(rhs, self)
def __truediv__(self, rhs):
from .elemwise import div
return div(self, rhs)
def __rtruediv__(self, rhs):
from .elemwise import div
return div(rhs, self)
def __pow__(self, rhs):
from .elemwise import pow
return pow(self, rhs)
def reshape(self, shape):
from .tensor import reshape
return reshape(self, shape)
def transpose(self, permutation):
from .tensor import transpose
return transpose(self, permutation)
def broadcast_to(self, shape, broadcast_dims=None):
from .tensor import broadcast_to
return broadcast_to(self, shape, broadcast_dims)
def bitcast(self, shape, dtype):
from .elemwise import bitcast
return bitcast(self, shape, dtype)
def astype(self, dtype):
from .elemwise import typecvt
return typecvt(self, dtype)
def sum(self, axis, keepdims=False):
from .reduction import sum
return sum(self, axis, keepdims)
def mean(self, axis, keepdims=False):
from .reduction import mean
return mean(self, axis, keepdims)
def prod(self, axis, keepdims=False):
from .reduction import prod
return prod(self, axis, keepdims)
def max(self, axis, keepdims=False):
from .reduction import max
return max(self, axis, keepdims)
def min(self, axis, keepdims=False):
from .reduction import min
return min(self, axis, keepdims)
def all(self, axis, keepdims=False):
from .reduction import all
return all(self, axis, keepdims)
def any(self, axis, keepdims=False):
from .reduction import any
return any(self, axis, keepdims)
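# Illustrative usage sketch (the helper name is hypothetical and the snippet assumes it
# runs inside an active MLIR lowering context): lowering rules usually combine
# HLOTensors through the operator overloads and conversion helpers defined above.
def _hlotensor_usage_example(x, y):
    z = (x * 2.0 + y).astype("float32")  # elemwise ops broadcast and lift scalars
    return z.sum(axis=0, keepdims=False)  # reductions are forwarded to .reduction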
from typing import Sequence, Union
from ...core._imperative_rt import ops as mops
from .. import ir_utils
from ..lib.mlir.dialects import hlo
from .hlotensor import HLOTensor
from .utils import _can_broadcast_to, _shape_equal, register_lower_rule
@register_lower_rule(mops.Dot)
def dot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 and len(args) == 2
), f"{len(ctx.vars_in)}, {len(ctx.vars_out)}, {len(args)}"
assert args[0].ndim == 1 and args[1].ndim == 1, f"{args[0].shape}, {args[1].shape}"
assert args[0].shape[0] == args[1].shape[0], f"{args[0].shape}, {args[1].shape}"
dot_dnums = hlo.DotDimensionNumbers.get(
lhs_batching_dimensions=tuple(),
rhs_batching_dimensions=tuple(),
lhs_contracting_dimensions=(0,),
rhs_contracting_dimensions=(0,),
)
return [
HLOTensor(
hlo.DotGeneralOp(
ir_utils.make_ir_type_according_meta((), ctx.vars_out[0].dtype),
args[0].tensor,
args[1].tensor,
dot_dnums,
precision_config=ir_utils.precision_attr(args[0].dtype, args[1].dtype),
).result
).reshape(ctx.vars_out[0].shape)
]
@register_lower_rule(mops.MatrixMul)
def matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 and len(args) == 2
assert (
ctx.op.compute_mode == mops.BatchedMatrixMul.ComputeMode.DEFAULT
), f"{ctx.op.compute_mode}"
assert ctx.op.format == mops.BatchedMatrixMul.Format.DEFAULT, f"{ctx.op.format}"
assert ctx.op.dimA == len(args[0].shape) and ctx.op.dimB == len(
args[1].shape
), f"{ctx.op.dimA}, {ctx.op.dimB}, {args[0].shape}, {args[1].shape}"
assert args[0].ndim >= 2 and args[1].ndim >= 2, f"{args[0].shape}, {args[1].shape}"
lhs, rhs = args[0], args[1]
    # in mge batched matmul, [a, b, c, d] * [a, b, c, f] -> [a, b, f, d]
    # but in mge matmul, the leading dims [:-1] are flattened into one edge of the matrix,
    # i.e. [a, b, c, d] * [a, b, c, f] -> [a*b*c, d] * [a*b*c, f] -> [f, d]
if lhs.ndim > 2 and rhs.ndim > 2:
lhs = lhs.reshape(shape=(-1, lhs.shape[-1]))
rhs = rhs.reshape(shape=(-1, rhs.shape[-1]))
lhs_reduce_axis = lhs.ndim - 2 if ctx.op.transposeA else lhs.ndim - 1
rhs_reduce_axis = rhs.ndim - 1 if ctx.op.transposeB else rhs.ndim - 2
assert (
lhs.shape[lhs_reduce_axis] == rhs.shape[rhs_reduce_axis]
), f"reduce axis length mismatch: {lhs.shape}, {rhs.shape}, {lhs_reduce_axis}, {rhs_reduce_axis}"
dot_dnums = hlo.DotDimensionNumbers.get(
lhs_batching_dimensions=tuple(),
rhs_batching_dimensions=tuple(),
lhs_contracting_dimensions=(lhs_reduce_axis,),
rhs_contracting_dimensions=(rhs_reduce_axis,),
)
return [
HLOTensor(
hlo.DotGeneralOp(
ir_utils.mge_varinfo_to_ir_type(ctx.vars_out[0]),
lhs.tensor,
rhs.tensor,
dot_dnums,
precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype),
).result
)
]
def _bmm_shape_helper(lhs_shape, rhs_shape, lhs_transpose, rhs_transpose):
lhs_reduce_axis = len(lhs_shape) - 2 if lhs_transpose else len(lhs_shape) - 1
rhs_reduce_axis = len(rhs_shape) - 1 if rhs_transpose else len(rhs_shape) - 2
# get the shape of inputs after transpose
lhs_shape, rhs_shape = list(lhs_shape), list(rhs_shape)
if lhs_transpose:
lhs_shape[-2], lhs_shape[-1] = lhs_shape[-1], lhs_shape[-2]
if rhs_transpose:
rhs_shape[-2], rhs_shape[-1] = rhs_shape[-1], rhs_shape[-2]
# get the batch info of inputs
lhs_prefix, rhs_prefix = lhs_shape[:-2], rhs_shape[:-2]
    # only the case where one input's batch dims can broadcast to the other's is supported
assert _can_broadcast_to(lhs_prefix, rhs_prefix) or _can_broadcast_to(
rhs_prefix, lhs_prefix
), f"{lhs_shape}, {rhs_shape}"
# get the batch axis of input_a and input_b, for example:
# (3, 4, 5) * (3, 5, 6), the batch axis is (0,) and (0,)
# (3, 4, 5) * (2, 3, 5, 6), the batch axis is (0,) and (1,)
# (2, 3, 4, 5) * (2, 3, 5, 6), the batch axis is (0, 1) and (0, 1)
lhs_batch_axis, rhs_batch_axis = [], []
min_len = min(len(lhs_shape), len(rhs_shape))
for i in range(-3, -min_len - 1, -1):
if lhs_shape[i] == rhs_shape[i]:
lhs_batch_axis.append(i)
rhs_batch_axis.append(i)
elif lhs_shape[i] == 1 or rhs_shape[i] == 1:
lhs_batch_axis.append(i)
rhs_batch_axis.append(i)
else:
break
lhs_batch_axis = [val + len(lhs_shape) for val in lhs_batch_axis]
rhs_batch_axis = [val + len(rhs_shape) for val in rhs_batch_axis]
lhs_batch_axis.sort()
rhs_batch_axis.sort()
assert len(lhs_batch_axis) == len(lhs_prefix) or len(rhs_batch_axis) == len(
rhs_prefix
), f"{lhs_batch_axis}, {rhs_batch_axis}, {lhs_prefix}, {rhs_prefix}, {lhs_shape}, {rhs_shape}"
# for case [m, ... , n, a, b] * [i, ..., j, m, ..., n, b, c]
if _can_broadcast_to(lhs_prefix, rhs_prefix):
# [m, ..., n]
batched_part = [rhs_prefix[ax] for ax in rhs_batch_axis]
# [i, ..., j]
nonbatched_part = rhs_prefix[0 : len(rhs_prefix) - len(rhs_batch_axis)]
# in xla, [m, ... , n, a, b] * [i, ..., j, m, ..., n, b, c] -> [m, ..., n, a, i, ..., j, c]
# in mge, [m, ... , n, a, b] * [i, ..., j, m, ..., n, b, c] -> [i, ..., j, m, ..., n, a, c]
# so we need permute
xla_oshape = [*batched_part, lhs_shape[-2], *nonbatched_part, rhs_shape[-1]]
nonbatched_perm = [
idx + 1 + len(batched_part) for idx in range(len(nonbatched_part))
]
batched_perm = [idx for idx in range(len(batched_part))]
permutation = [
*nonbatched_perm,
*batched_perm,
len(batched_part),
len(xla_oshape) - 1,
]
# for case [i, ..., j, m, ..., n, a, b] * [m, ..., n, b, c]
else:
# [m, ..., n]
batched_part = [lhs_prefix[ax] for ax in lhs_batch_axis]
# [i, ..., j]
nonbatched_part = lhs_prefix[0 : len(lhs_prefix) - len(lhs_batch_axis)]
# in xla, [i, ..., j, m, ... , n, a, b] * [m, ..., n, b, c] -> [m, ..., n, i, ..., j, a, c]
# in mge, [i, ..., j, m, ... , n, a, b] * [m, ..., n, b, c] -> [i, ..., j, m, ..., n, a, c]
# so we need permute
xla_oshape = [*batched_part, *nonbatched_part, lhs_shape[-2], rhs_shape[-1]]
nonbatched_perm = [
idx + len(batched_part) for idx in range(len(nonbatched_part))
]
batched_perm = [idx for idx in range(len(batched_part))]
permutation = [
*nonbatched_perm,
*batched_perm,
len(xla_oshape) - 2,
len(xla_oshape) - 1,
]
return (
lhs_reduce_axis,
rhs_reduce_axis,
lhs_batch_axis,
rhs_batch_axis,
xla_oshape,
permutation,
)
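# Worked example (illustrative): lhs (2, 3, 4, 5) @ rhs (3, 5, 6) with no transpose gives
#   lhs_reduce_axis = 3, rhs_reduce_axis = 1, lhs_batch_axis = [1], rhs_batch_axis = [0],
#   xla_oshape = [3, 2, 4, 6] and permutation = [1, 0, 2, 3],
# so transposing the xla result recovers megengine's expected (2, 3, 4, 6) output.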
@register_lower_rule(mops.BatchedMatrixMul)
def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1 and len(args) == 2
assert (
ctx.op.compute_mode == mops.BatchedMatrixMul.ComputeMode.DEFAULT
), f"{ctx.op.compute_mode}"
assert ctx.op.format == mops.BatchedMatrixMul.Format.DEFAULT, f"{ctx.op.format}"
assert ctx.op.dimA == len(args[0].shape) and ctx.op.dimB == len(
args[1].shape
), f"{ctx.op.dimA}, {ctx.op.dimB}, {args[0].shape}, {args[1].shape}"
assert args[0].ndim >= 2 and args[1].ndim >= 2, f"{args[0].shape}, {args[1].shape}"
lhs, rhs = args[0], args[1]
(
lhs_reduce_axis,
rhs_reduce_axis,
lhs_batch_axis,
rhs_batch_axis,
xla_oshape,
permutation,
) = _bmm_shape_helper(lhs.shape, rhs.shape, ctx.op.transposeA, ctx.op.transposeB)
# in xla, [3, 4, 5, 6] * [3, 1, 6, 7] is illegal, so we broadcast [3, 1, 6, 7] -> [3, 4, 6, 7]
if _can_broadcast_to(lhs.shape[:-2], rhs.shape[:-2]):
lshape = [
rhs.shape[r] if lhs.shape[l] == 1 else lhs.shape[l]
for l, r in zip(lhs_batch_axis, rhs_batch_axis)
]
lshape = [*lshape, *lhs.shape[-2:]]
if not _shape_equal(lshape, lhs.shape):
lhs = lhs.broadcast_to(lshape)
else:
assert _can_broadcast_to(rhs.shape[:-2], lhs.shape[:-2])
rshape = [
lhs.shape[l] if rhs.shape[r] == 1 else rhs.shape[r]
for l, r in zip(lhs_batch_axis, rhs_batch_axis)
]
rshape = [*rshape, *rhs.shape[-2:]]
if not _shape_equal(rshape, rhs.shape):
rhs = rhs.broadcast_to(rshape)
dot_dnums = hlo.DotDimensionNumbers.get(
lhs_batching_dimensions=list(lhs_batch_axis),
rhs_batching_dimensions=list(rhs_batch_axis),
lhs_contracting_dimensions=(lhs_reduce_axis,), # the reduce axis in lhs
rhs_contracting_dimensions=(rhs_reduce_axis,), # the reduce axis in rhs
)
return HLOTensor(
hlo.DotGeneralOp(
ir_utils.make_ir_type_according_meta(xla_oshape, ctx.vars_out[0].dtype),
lhs.tensor,
rhs.tensor,
dot_dnums,
precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype),
).result
).transpose(permutation)
This diff has been collapsed.
from typing import Sequence, Union
import numpy as np
from ...core._imperative_rt import ops as mops
from .. import ir_utils
from ..lib import xla_client as xc
from ..lib.mlir.dialects import hlo
from .hlotensor import HLOTensor
from .utils import _shape_equal, register_lower_rule
RandomAlgorithm = xc.ops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name
def _rng_algorithm(algorithm: RandomAlgorithm):
    # only ThreeFry is supported at the moment, so the remaining branches are effectively unreachable
    assert algorithm == RandomAlgorithm.RNG_THREE_FRY
if algorithm == RandomAlgorithm.RNG_THREE_FRY:
return hlo.RngAlgorithmAttr.get("THREE_FRY")
elif algorithm == RandomAlgorithm.RNG_PHILOX:
return hlo.RngAlgorithmAttr.get("PHILOX")
elif algorithm == RandomAlgorithm.RNG_DEFAULT:
return hlo.RngAlgorithmAttr.get("DEFAULT")
else:
assert False
def rng_uint_generator(
key, oshape, odtype="uint32", algorithm=RandomAlgorithm.RNG_THREE_FRY
):
assert np.dtype(odtype) in {
np.dtype("uint8"),
np.dtype("uint16"),
np.dtype("uint32"),
np.dtype("uint64"),
}, f"only unsigned int supported, got {odtype}({type(odtype)})"
assert algorithm == RandomAlgorithm.RNG_THREE_FRY, "only ThreeFry supported now"
assert _shape_equal(key.shape, (2, 2)), f"key shape error, {key.shape}"
assert key.dtype == "int32", f"key dtype error, {key.dtype}"
# bitcast (2x2,i32) -> (2,u64)
org_key_shape, org_key_dtype = key.shape, key.dtype
key = key.bitcast((2,), "uint64")
if odtype == "uint32" or odtype == "uint64":
rng_odtype = odtype
else:
rng_odtype = "uint32"
algorithm_attr = _rng_algorithm(algorithm)
new_key, out_vals = hlo.RngBitGeneratorOp(
ir_utils.make_ir_type_according_meta(key.shape, key.dtype),
ir_utils.make_ir_type_according_meta(oshape, rng_odtype),
algorithm_attr,
key.tensor,
).results
new_key, out_vals = HLOTensor(new_key), HLOTensor(out_vals)
new_key = new_key.bitcast(org_key_shape, org_key_dtype)
if rng_odtype != odtype:
out_vals = out_vals.astype(odtype)
return out_vals, new_key
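# Shape/dtype sketch (illustrative): the (2, 2) int32 megengine RNG key is bitcast to a
# (2,) uint64 XLA key, consumed by RngBitGeneratorOp, and the refreshed key is bitcast
# back to (2, 2) int32 so it can be threaded into the next random op.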
@register_lower_rule(mops.Dropout)
def dropout_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(ctx.vars_in) == 2 and len(args) == 2 and len(ctx.vars_out) == 3
inp, key = args
random_val, new_key = rng_uint_generator(key, inp.shape, "uint32")
mask = random_val > np.array(ctx.op.drop_prob * np.iinfo(np.uint32).max, np.uint32)
multiplier = mask.astype(inp.dtype)
multiplier = multiplier / (1.0 - ctx.op.drop_prob)
out = inp * multiplier
mask = mask.reshape((-1,)).astype("uint8")
return out, mask, new_key
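# Pure-numpy sketch of the inverted-dropout math used above, for illustration only
# (the generator and seed are assumptions, not the ThreeFry path taken in XLA).
def _dropout_numpy_sketch(x, drop_prob, seed=0):
    rng = np.random.default_rng(seed)
    rand_u32 = rng.integers(0, np.iinfo(np.uint32).max, size=x.shape, dtype=np.uint32)
    keep = rand_u32 > np.uint32(drop_prob * np.iinfo(np.uint32).max)
    out = x * keep.astype(x.dtype) / (1.0 - drop_prob)
    return out, keep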
@register_lower_rule("DropoutBackward")
def dropout_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 2 and len(ctx.vars_in) == 2 and len(ctx.vars_out) == 1
dy, mask = args[0], args[1]
scale = 1.0 - ctx.param["drop_prob"]
multiplier = mask.reshape(dy.shape).astype(dy.dtype) / scale
return dy * multiplier
from typing import Sequence, Union
import numpy as np
from ...core._imperative_rt import ops as mops
from ..lib.mlir import ir
from .hlotensor import HLOTensor
from .utils import _check_shape, register_lower_rule
@register_lower_rule(mops.GetVarShape)
def get_var_shape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
if len(args) > 1:
assert len(args) == 2, f"{len(args)}"
_check_shape(args[0].shape, args[1].shape)
shp = args[0].shape
    # axis == 7 (the ndim upper bound) is the sentinel meaning "return the full shape"
    if ctx.op.axis != 7:
        shp = (shp[ctx.op.axis],)
shp = np.array(shp, np.int64)
ctx.module_context.set_value(ctx.vars_out[0], shp)
return HLOTensor(shp)
@register_lower_rule("create_tensor")
def create_tensor_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == len(ctx.vars_in) == len(ctx.vars_out) == 1
var_in, var_out = ctx.vars_in[0], ctx.vars_out[0]
if var_in.bound_data is not None:
ctx.module_context.set_value(var_in, var_in.bound_data)
ctx.module_context.set_value(var_out, var_in.bound_data)
assert var_in.shape == var_out.shape
if var_out.bound_data is not None:
data = np.asarray(var_out.bound_data, var_out.dtype)
elif var_in.bound_data is not None:
data = np.asarray(var_in.bound_data, var_out.dtype)
else:
assert False, "only support create tensor from const now"
return HLOTensor(data)
@register_lower_rule("io_mark_var")
def io_mark_var_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
assert len(args) == 1
return args
@register_lower_rule("rename")
def rename_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
assert len(args) == 1
return args
This diff has been collapsed.
This diff has been collapsed.