From 4c7905f3d40535cefb6f948b6809af965dc647fc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 6 Jul 2023 18:55:45 +0800 Subject: [PATCH] feat(imperative): add some xla op rules GitOrigin-RevId: 0650c75dc1e4ec9af8ae7d9ed3eca60e4681e04a --- .../python/megengine/jit/partial_tracing.py | 4 +- imperative/python/megengine/jit/tracing.py | 2 +- imperative/python/megengine/xla/ir_utils.py | 5 +- imperative/python/megengine/xla/lower.py | 2 +- .../python/megengine/xla/rules/elemwise.py | 264 ++++++++++++++++-- imperative/python/megengine/xla/rules/math.py | 9 +- .../python/megengine/xla/rules/reduction.py | 10 +- .../python/megengine/xla/rules/tensor.py | 5 + .../python/megengine/xla/rules/trivial.py | 6 + .../python/megengine/xla/rules/utils.py | 16 +- imperative/python/src/grad_override.cpp | 2 +- imperative/python/src/tensor.cpp | 3 +- .../python/test/unit/jit/test_tracing.py | 5 +- .../unit/xla/functional/test_xla_elemwise.py | 139 ++++----- .../test/unit/xla/functional/test_xla_nn.py | 67 +++++ .../test/unit/xla/module/test_elemwise.py | 49 ++++ src/plugin/impl/opr_footprint.cpp | 2 + 17 files changed, 495 insertions(+), 95 deletions(-) create mode 100644 imperative/python/test/unit/xla/module/test_elemwise.py diff --git a/imperative/python/megengine/jit/partial_tracing.py b/imperative/python/megengine/jit/partial_tracing.py index 1c5240cec..3753c2511 100644 --- a/imperative/python/megengine/jit/partial_tracing.py +++ b/imperative/python/megengine/jit/partial_tracing.py @@ -75,7 +75,9 @@ def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map): def check_external(trace_obj): for var in trace_obj.vars: if var.kind == "external" and not var.inp_mark: - raise RuntimeError("have unknown input in trace result") + raise RuntimeError( + "have unknown input in trace result, maybe you can set `capture_as_const=True` when trace" + ) check_external(fwd) check_external(bwd) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 6f2b22d52..114d6fca0 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -579,7 +579,7 @@ class trace: if not self._trace.compiled(): outlist, self.outdef = tree_flatten(outputs) for i, out in enumerate(outlist): - assert isinstance(out, RawTensor) + assert isinstance(out, RawTensor), f"get out of type {type(out)}" outlist[i] = get_marked_output_tensor(self.output_num, out) del out self.out_list.append(self.output_num) diff --git a/imperative/python/megengine/xla/ir_utils.py b/imperative/python/megengine/xla/ir_utils.py index 70d909462..a8e80b287 100644 --- a/imperative/python/megengine/xla/ir_utils.py +++ b/imperative/python/megengine/xla/ir_utils.py @@ -101,10 +101,11 @@ class DropoutMaskCanonicalizer(Pass): if not isinstance(eqn.op, mops.Dropout): continue - outputs = list(eqn.outputs) + inputs, outputs = list(eqn.inputs), list(eqn.outputs) mask_var = tr.vars[outputs[1]] + inp_shape = tr.vars[inputs[0]].shape new_mask_var = AbstractVar( - mask_var.id, (int(np.prod(mask_var.shape)) * 8,), mask_var.dtype + mask_var.id, (int(np.prod(inp_shape)),), mask_var.dtype ) tr.vars[mask_var.id] = new_mask_var diff --git a/imperative/python/megengine/xla/lower.py b/imperative/python/megengine/xla/lower.py index a43717571..847771029 100644 --- a/imperative/python/megengine/xla/lower.py +++ b/imperative/python/megengine/xla/lower.py @@ -142,7 +142,7 @@ def lowering_ops( vars_out=[trace_result.vars[oup] for oup in eqn.outputs], param=eqn.param, ) - rule = 
get_rule(eqn.op) + rule = get_rule(eqn.op, use_fake_rule_for_debug=False) in_nodes = read(eqn.inputs) hinps = [ diff --git a/imperative/python/megengine/xla/rules/elemwise.py b/imperative/python/megengine/xla/rules/elemwise.py index dc8803120..82c5121dc 100644 --- a/imperative/python/megengine/xla/rules/elemwise.py +++ b/imperative/python/megengine/xla/rules/elemwise.py @@ -18,9 +18,9 @@ def _infer_elemwise_oshape(inp_shapes): if len(rhs_shape) == 0: return lhs_shape - if np.prod(lhs_shape) == 1 and len(rhs_shape) != 0: + if np.prod(lhs_shape) == 1 and len(lhs_shape) == 1 and len(rhs_shape) != 0: return rhs_shape - if np.prod(rhs_shape) == 1 and len(rhs_shape) != 0: + if np.prod(rhs_shape) == 1 and len(rhs_shape) == 1 and len(rhs_shape) != 0: return lhs_shape oshape = [] @@ -62,6 +62,24 @@ def _infer_elemwise_odtype(inp_dtypes): return oup_dtype +def bitcast(inp, oshape, odtype): + odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype + return HLOTensor( + hlo.BitcastConvertOp( + ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor + ).result + ) + + +def typecvt(inp, odtype): + odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype + return HLOTensor( + hlo.ConvertOp( + ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor + ).result + ) + + def _compare(lhs, rhs, mode, comparison_type=None): """ mod: can be @@ -126,19 +144,36 @@ def _elemwise_binary(hlo_op, a, b): return _elemwise(hlo_op, [a, b]) +def _elemwise_ternary(hlo_op, a, b, c): + return _elemwise(hlo_op, [a, b, c]) + + neg = partial(_elemwise_unary, hlo.NegOp) abs = partial(_elemwise_unary, hlo.AbsOp) +sin = partial(_elemwise_unary, hlo.SineOp) +cos = partial(_elemwise_unary, hlo.CosineOp) tanh = partial(_elemwise_unary, hlo.TanhOp) exp = partial(_elemwise_unary, hlo.ExpOp) sqrt = partial(_elemwise_unary, hlo.SqrtOp) log = partial(_elemwise_unary, hlo.LogOp) +log1p = partial(_elemwise_unary, hlo.Log1pOp) +expm1 = partial(_elemwise_unary, hlo.Expm1Op) +floor = partial(_elemwise_unary, hlo.FloorOp) +ceil = partial(_elemwise_unary, hlo.CeilOp) +round = partial(_elemwise_unary, hlo.RoundOp) add = partial(_elemwise_binary, hlo.AddOp) sub = partial(_elemwise_binary, hlo.SubtractOp) mul = partial(_elemwise_binary, hlo.MulOp) div = partial(_elemwise_binary, hlo.DivOp) pow = partial(_elemwise_binary, hlo.PowOp) +maximum = partial(_elemwise_binary, hlo.MaxOp) +minimum = partial(_elemwise_binary, hlo.MinOp) +atan2 = partial(_elemwise_binary, hlo.Atan2Op) +left_shift = partial(_elemwise_binary, hlo.ShiftLeftOp) +right_shift = partial(_elemwise_binary, hlo.ShiftRightArithmeticOp) +clip = partial(_elemwise_ternary, hlo.ClampOp) equal = partial(_compare, mode="EQ") not_equal = partial(_compare, mode="NE") @@ -147,31 +182,99 @@ greater_equal = partial(_compare, mode="GE") less = partial(_compare, mode="LT") less_equal = partial(_compare, mode="LE") +logical_and = partial(_elemwise_binary, hlo.AndOp) +logical_or = partial(_elemwise_binary, hlo.OrOp) +logical_not = partial(_elemwise_unary, hlo.NotOp) +logical_xor = partial(_elemwise_binary, hlo.XorOp) + + +def floor_div(x, y): + return floor(div(x, y)) + + +def mod(x, y): + assert False, "xla not support" + + +def cond_leq_move(x, y, z): + mask = (x <= y).astype(x.dtype) + return mask * z + + +def cond_lt_move(x, y, z): + mask = (x < y).astype(x.dtype) + return mask * z + + +def log_add_exp(x, y): + min_val = minimum(x, y) + max_val = maximum(x, y) + return max_val + log1p(exp(min_val - max_val)) + + +def square(x): + return mul(x, x) + def 
abs_grad(x, dy): return (x / abs(x)) * dy +def tan(x): + return sin(x) / cos(x) + + +def tan_grad(x, dy): + return (1.0 + tan(x) ** 2.0) * dy + + +def sinh(x): + return (exp(x) - exp(-x)) / 2.0 + + +def cosh(x): + return (exp(x) + exp(-x)) / 2.0 + + def tanh_grad(x, dy): return (1.0 - tanh(x) ** 2.0) * dy -def bitcast(inp, oshape, odtype): - odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype - return HLOTensor( - hlo.BitcastConvertOp( - ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor - ).result - ) +def atan(x): + return atan2(x, 1.0) -def typecvt(inp, odtype): - odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype - return HLOTensor( - hlo.ConvertOp( - ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor - ).result - ) +def asin(x): + return atan(x / sqrt(1.0 - x ** 2.0)) + + +def acos(x): + assert False, "xla not support" + # return atan(sqrt(1.0 - x ** 2.0) / x) + + +def asinh(x): + return log(x + sqrt(x ** 2.0 + 1.0)) + + +def acosh(x): + return log(x + sqrt(x ** 2.0 - 1.0)) + + +def atanh(x): + return log((1.0 + x) / (1.0 - x)) / 2.0 + + +def asinh_grad(x, dy): + return dy / sqrt(x ** 2.0 + 1.0) + + +def acosh_grad(x, dy): + return dy / sqrt(x ** 2.0 - 1.0) + + +def atanh_grad(x, dy): + return dy / (1.0 - x ** 2.0) def gelu(inp, approximate: bool = True): @@ -257,6 +360,86 @@ def relu_grad(x, dy): return dy * mask +def sigmoid(inp): + return 1.0 / (1.0 + exp(-inp)) + + +def sigmoid_grad(y, dy): + return y * (1.0 - y) * dy + + +def hsigmoid(x): + from .tensor import where + + return where(x <= -3.0, 0.0, where(x >= 3.0, 1.0, (x + 3.0) / 6.0)) + + +def hsigmoid_grad(x, dy): + from .tensor import where + + return where(x <= -3.0, 0.0, where(x >= 3.0, 0.0, dy / 6.0)) + + +def relu6(x): + return clip(x, 0.0, 6.0) + + +def relu6_grad(x, dy): + from .tensor import where + + return where(x <= 0.0, 0.0, where(x >= 6.0, 0.0, dy)) + + +def hswish(x): + return x * minimum(maximum(x + 3.0, 0.0), 6.0) * (1.0 / 6.0) + + +def hswish_grad(x, dy): + from .tensor import where + + return where(x < -3.0, 0.0, where(x > 3.0, dy, (2.0 * x + 3.0) / 6.0 * dy)) + + +def logsigmoid(x): + from .tensor import where + + return -log1p(exp(-abs(x))) + where(x >= 0.0, 0.0, x) + + +def softplus(x): + return log1p(exp(-abs(x))) + relu(x) + + +def softplus_grad(x, dy): + from .tensor import where + + exp_abs = exp(-abs(x)) + logg = -dy * exp_abs / (1.0 + exp_abs) + grad0 = where(x > 0.0, logg, -logg) + relux = relu(x) + grad1 = where(relux > 0.0, dy, 0.0) + return grad0 + grad1 + + +def prelu(inp, alpha): + mask = (inp > 0.0).astype(inp.dtype) + return inp * mask + alpha * (1.0 - mask) * inp + + +def prelu_grad(x, dy, alpha): + mask = (x > 0.0).astype(x.dtype) + return dy * mask + alpha * (1.0 - mask) * dy + + +def silu(inp): + return inp / (1.0 + exp(-inp)) + + +def silu_grad(x, dy): + xsig = sigmoid(x) + return dy * xsig * (1.0 + x * (1.0 - xsig)) + + # Elemwise.Mode is unhashable, so we convert it to str mge_elemwise_to_xla = { str(mops.Elemwise.Mode.ADD): add, @@ -264,22 +447,71 @@ mge_elemwise_to_xla = { str(mops.Elemwise.Mode.SUB): sub, str(mops.Elemwise.Mode.EXP): exp, str(mops.Elemwise.Mode.LOG): log, + str(mops.Elemwise.Mode.LOG1P): log1p, + str(mops.Elemwise.Mode.LOG_SUM_EXP): log_add_exp, + str(mops.Elemwise.Mode.MAX): maximum, + str(mops.Elemwise.Mode.MIN): minimum, + str(mops.Elemwise.Mode.COND_LEQ_MOV): cond_leq_move, + str(mops.Elemwise.Mode.COND_LT_MOV): cond_lt_move, + str(mops.Elemwise.Mode.FLOOR): floor, + str(mops.Elemwise.Mode.CEIL): ceil, + 
str(mops.Elemwise.Mode.ROUND): round, + str(mops.Elemwise.Mode.CLIP): clip, str(mops.Elemwise.Mode.GELU): gelu, str(mops.Elemwise.Mode.GELU_GRAD): gelu_grad, str(mops.Elemwise.Mode.TRUE_DIV): div, str(mops.Elemwise.Mode.NEGATE): neg, + str(mops.Elemwise.Mode.FLOOR_DIV): floor_div, + str(mops.Elemwise.Mode.MOD): mod, str(mops.Elemwise.Mode.ABS): abs, str(mops.Elemwise.Mode.ABS_GRAD): abs_grad, + str(mops.Elemwise.Mode.SIN): sin, + str(mops.Elemwise.Mode.COS): cos, + str(mops.Elemwise.Mode.TAN): tan, + str(mops.Elemwise.Mode.SINH): sinh, + str(mops.Elemwise.Mode.COSH): cosh, str(mops.Elemwise.Mode.TANH): tanh, + str(mops.Elemwise.Mode.ASIN): asin, + str(mops.Elemwise.Mode.ACOS): acos, + str(mops.Elemwise.Mode.ASINH): asinh, + str(mops.Elemwise.Mode.ACOSH): acosh, + str(mops.Elemwise.Mode.ATANH): atanh, + str(mops.Elemwise.Mode.ATAN2): atan2, str(mops.Elemwise.Mode.TANH_GRAD): tanh_grad, + str(mops.Elemwise.Mode.ASINH_GRAD): asinh_grad, + str(mops.Elemwise.Mode.ACOSH_GRAD): acosh_grad, + str(mops.Elemwise.Mode.ATANH_GRAD): atanh_grad, str(mops.Elemwise.Mode.SQRT): sqrt, + str(mops.Elemwise.Mode.SQUARE): square, str(mops.Elemwise.Mode.POW): pow, + str(mops.Elemwise.Mode.EXPM1): expm1, str(mops.Elemwise.Mode.RELU): relu, str(mops.Elemwise.Mode.EQ): equal, str(mops.Elemwise.Mode.NEQ): not_equal, str(mops.Elemwise.Mode.LT): less, str(mops.Elemwise.Mode.LEQ): less_equal, + str(mops.Elemwise.Mode.AND): logical_and, + str(mops.Elemwise.Mode.OR): logical_or, + str(mops.Elemwise.Mode.NOT): logical_not, + str(mops.Elemwise.Mode.XOR): logical_xor, + str(mops.Elemwise.Mode.SHL): left_shift, + str(mops.Elemwise.Mode.SHR): right_shift, str(mops.Elemwise.Mode.SWITCH_GT0): relu_grad, + str(mops.Elemwise.Mode.SIGMOID): sigmoid, + str(mops.Elemwise.Mode.SIGMOID_GRAD): sigmoid_grad, + str(mops.Elemwise.Mode.PRELU): prelu, + str(mops.Elemwise.Mode.PRELU_GRAD): prelu_grad, + str(mops.Elemwise.Mode.SILU): silu, + str(mops.Elemwise.Mode.SILU_GRAD): silu_grad, + str(mops.Elemwise.Mode.HSIGMOID): hsigmoid, + str(mops.Elemwise.Mode.HSIGMOID_GRAD): hsigmoid_grad, + str(mops.Elemwise.Mode.H_SWISH): hswish, + str(mops.Elemwise.Mode.H_SWISH_GRAD): hswish_grad, + str(mops.Elemwise.Mode.RELU6): relu6, + str(mops.Elemwise.Mode.RELU6_GRAD): relu6_grad, + str(mops.Elemwise.Mode.LOGSIGMOID): logsigmoid, + str(mops.Elemwise.Mode.SOFTPLUS): softplus, + str(mops.Elemwise.Mode.SOFTPLUS_GRAD): softplus_grad, } diff --git a/imperative/python/megengine/xla/rules/math.py b/imperative/python/megengine/xla/rules/math.py index 818498da8..7e8dd2044 100644 --- a/imperative/python/megengine/xla/rules/math.py +++ b/imperative/python/megengine/xla/rules/math.py @@ -1,8 +1,11 @@ from typing import Sequence, Union +import numpy as np + from ...core._imperative_rt import ops as mops from .. 
import ir_utils -from ..lib.mlir.dialects import hlo +from ..ir_utils import i64_attr +from ..lib.mlir.dialects import chlo, hlo from .hlotensor import HLOTensor from .utils import _can_broadcast_to, _shape_equal, register_lower_rule @@ -236,3 +239,7 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype), ).result ).transpose(permutation) + + +def topk(inp, k, descending=True, kth_only=False, no_sort=False): + return [HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results] diff --git a/imperative/python/megengine/xla/rules/reduction.py b/imperative/python/megengine/xla/rules/reduction.py index 05f4022a4..a30c22dc1 100644 --- a/imperative/python/megengine/xla/rules/reduction.py +++ b/imperative/python/megengine/xla/rules/reduction.py @@ -51,11 +51,15 @@ def _get_bitwise_or_identity(dtype) -> np.ndarray: return np.array(0, dtype) -def _infer_reduce_shape(ishape, axes, keepdims=False): +def _normalize_reduce_axes(ishape, axes): axes = list(range(len(ishape))) if axes is None else axes axes = [axes] if isinstance(axes, int) else axes axes = [axis if axis >= 0 else axis + len(ishape) for axis in axes] + return axes + +def _infer_reduce_shape(ishape, axes, keepdims=False): + axes = _normalize_reduce_axes(ishape, axes) reduced_shape = [] for axis, length in enumerate(ishape): @@ -89,8 +93,7 @@ def _reduce( return HLOTensor(reduce_op.result) - axes = [axes] if isinstance(axes, int) else axes - axes = [axis if axis >= 0 else axis + inp.ndim for axis in axes] + axes = _normalize_reduce_axes(inp.shape, axes) maykeepdim_shape = _infer_reduce_shape(inp.shape, axes, keepdims) _check_shape(maykeepdim_shape, oshape) @@ -110,6 +113,7 @@ any = partial(_reduce, hlo.OrOp, _get_bitwise_or_identity) def mean(inp, axes=None, keepdims=False): + axes = _normalize_reduce_axes(inp.shape, axes) inp_sum = sum(inp, axes, keepdims) inp_shape = inp.shape diff --git a/imperative/python/megengine/xla/rules/tensor.py b/imperative/python/megengine/xla/rules/tensor.py index 4c31f19cc..64e21510c 100644 --- a/imperative/python/megengine/xla/rules/tensor.py +++ b/imperative/python/megengine/xla/rules/tensor.py @@ -226,6 +226,11 @@ def pad(inp, pad_value, padding): ) +def where(mask, x, y): + mask = mask.astype("float32") + return mask * x + (1.0 - mask) * y + + @register_lower_rule(mops.Reshape) def reshape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): assert len(args) == 2 diff --git a/imperative/python/megengine/xla/rules/trivial.py b/imperative/python/megengine/xla/rules/trivial.py index 558ad7a6c..b8b52fbaf 100644 --- a/imperative/python/megengine/xla/rules/trivial.py +++ b/imperative/python/megengine/xla/rules/trivial.py @@ -5,6 +5,7 @@ import numpy as np from ...core._imperative_rt import ops as mops from ..lib.mlir import ir from .hlotensor import HLOTensor +from .tensor import fill from .utils import _check_shape, register_lower_rule @@ -51,3 +52,8 @@ def io_mark_var_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]): def rename_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]): assert len(args) == 1 return args + + +@register_lower_rule("fake_op_rule_for_debug") +def fake_op_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]): + return [fill(0.0, out.shape, out.dtype) for out in ctx.vars_out] diff --git a/imperative/python/megengine/xla/rules/utils.py b/imperative/python/megengine/xla/rules/utils.py index 7634db5a5..3622325aa 100644 --- a/imperative/python/megengine/xla/rules/utils.py +++ 
b/imperative/python/megengine/xla/rules/utils.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from ..lib.mlir import ir @@ -19,10 +21,16 @@ def register_lower_rule(*ops): return decorator -def get_rule(op): - if isinstance(op, str): - return lower_rule[op] - return lower_rule[type(op)] +def get_rule(op, use_fake_rule_for_debug=False): + op_key = op if isinstance(op, str) else type(op) + if use_fake_rule_for_debug: + if op_key in lower_rule: + return lower_rule[op_key] + else: + warnings.warn(f"op: {op_key} not register, use fake op rule") + return lower_rule["fake_op_rule_for_debug"] + else: + return lower_rule[op_key] def _log_mge_opr_attrs(mopr): diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 9579dca3f..23d2f36bf 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -81,7 +81,7 @@ ValueRef make_empty_tensor( storage.ensure_size(dtype->size()); std::memset(storage.ptr(), 0, dtype->size()); auto t = imperative::apply( - CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()), + CreateTensor(CreateTensor::Const, *device, *dtype, ValueShape()), HostStorage::make(storage))[0]; auto res = broadcast_to(t, shape); return res; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 0ec0b3f40..556466515 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -1321,7 +1321,8 @@ void init_tensor(py::module m) { } else if (self.check_external) { throw std::runtime_error( "have some unknown input tensors in trace " - "result"); + "result, maybe you can set " + "`capture_as_const=True` when trace"); } } } diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 6c68d52c2..12f46edaf 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -848,7 +848,10 @@ def test_trace_without_error(): c = tensor([3.0]) fwd(a, b, c) except Exception as e: - assert str(e) == "have some unknown input tensors in trace result" + assert ( + str(e) + == "have some unknown input tensors in trace result, maybe you can set `capture_as_const=True` when trace" + ) else: assert False diff --git a/imperative/python/test/unit/xla/functional/test_xla_elemwise.py b/imperative/python/test/unit/xla/functional/test_xla_elemwise.py index 30a0a7b09..af54b832c 100644 --- a/imperative/python/test/unit/xla/functional/test_xla_elemwise.py +++ b/imperative/python/test/unit/xla/functional/test_xla_elemwise.py @@ -18,94 +18,107 @@ def test_elemwise(): np.random.seed(123) mge.random.seed(123) - def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5): + def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5, **kwargs): dtype = dtype or np.float32 - inps = [ - tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype) - for inp_shape in inp_shapes - ] - doup = tensor(0.1 * np.random.randn(*felemwise(*inps).shape), dtype=dtype) + if dtype in [np.int16, np.int32, np.uint16, np.uint32]: + inps = [ + tensor(np.random.randint(0, 10, size=inp_shape), dtype=dtype) + for inp_shape in inp_shapes + ] + else: + inps = [ + tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype) + for inp_shape in inp_shapes + ] + doup = tensor( + 0.1 * np.random.randn(*felemwise(*inps, **kwargs).shape), dtype=dtype + ) gm = GradManager() @jit.xla_trace(without_host=True) def func(inps, doup): - gm.attach(inps) - with gm: - oup = felemwise(*inps) - 
if backward: + if backward: + gm.attach(inps) + with gm: + oup = felemwise(*inps, **kwargs) gm.backward(oup, doup) return [oup, *[inp.grad for inp in inps]] - else: - return [oup] + else: + oup = felemwise(*inps, **kwargs) + return [oup] mge_rsts = func(inps, doup) xla_rsts = func(inps, doup) - for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + for _, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)): np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol) tester(F.neg, (4, 16, 12, 12), dtype=np.float32, atol=1e-5) tester(F.abs, (2, 32, 16), dtype=np.float32, atol=1e-5) - tester(F.tanh, (4, 16, 3, 1), backward=False, dtype=np.float32, atol=1e-5) + tester(F.sin, (1, 16, 3, 1), dtype=np.float32, atol=1e-5) + tester(F.cos, (4, 16, 3), dtype=np.float32, atol=1e-5) + tester(F.tan, (4, 16, 1), dtype=np.float32, atol=1e-5) + tester(F.sinh, (4, 16, 1), dtype=np.float32, atol=1e-5) + tester(F.cosh, (3, 16, 1), dtype=np.float32, atol=1e-5) + tester(F.tanh, (4, 6, 3, 1), dtype=np.float32, atol=5e-4) + tester(F.asin, (4, 1, 3, 1), dtype=np.float32, atol=1e-5) + # tester(F.acos, (4, 16, 3, 1), dtype=np.float32, atol=1e-5) # xla compute error + tester(F.atan, (4, 16, 3, 1), dtype=np.float32, atol=1e-5) + tester(F.asinh, (4, 1, 3, 1), dtype=np.float32, atol=1e-5) + tester(F.acosh, (4, 1), dtype=np.float32, atol=1e-5) + tester(F.atanh, (1,), dtype=np.float32, atol=1e-5) tester(F.exp, (2, 8), dtype=np.float32, atol=1e-5) tester(F.sqrt, (32,), dtype=np.float32, atol=1e-5) + tester(F.square, (32,), dtype=np.float32, atol=1e-5) tester(F.log, (8, 8, 16), dtype=np.float32, atol=1e-5) + tester(F.log1p, (8, 1, 16), dtype=np.float32, atol=1e-5) + tester(F.expm1, (6, 8, 2), dtype=np.float32, atol=1e-5) + tester(F.floor, (4, 16, 1, 1), backward=False, dtype=np.float32, atol=1e-5) + tester(F.ceil, (4, 1, 1), backward=False, dtype=np.float32, atol=1e-5) + tester(F.round, (1, 4, 1), backward=False, dtype=np.float32, atol=1e-5) + tester(F.clip, (4, 16, 1), dtype=np.float32, atol=1e-5, lower=-1.0, upper=1.0) tester(F.relu, (1,), dtype=np.float32, atol=1e-5) tester(F.gelu, (4, 16, 12, 12), dtype=np.float32, atol=2e-5) - + tester(F.sigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5) + tester(F.hsigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5) + tester(F.hswish, (4, 16, 16, 12), dtype=np.float32, atol=1e-5) + tester(F.relu6, (12, 16, 1), dtype=np.float32, atol=1e-5) + tester(F.leaky_relu, (1, 16, 1), dtype=np.float32, atol=1e-5) + tester(F.leaky_relu, (12, 16, 1), dtype=np.float32, atol=1e-5, negative_slope=0.5) + tester(F.silu, (4, 16, 12, 12), dtype=np.float32, atol=1e-5) + tester(F.logsigmoid, (4, 16, 12, 12), dtype=np.float32, atol=1e-5) + tester(F.softplus, (4, 16, 12, 12), dtype=np.float32, atol=1e-5) tester(F.add, (4, 16, 12, 12), (4, 16, 12, 12), dtype=np.float32, atol=1e-5) tester(F.sub, (4, 16, 12, 12), (4, 16, 1, 1), dtype=np.float32, atol=1e-5) tester(F.mul, (4, 16, 12, 12), (1, 1, 12, 12), dtype=np.float32, atol=1e-5) - tester( - F.div, - (4, 16, 1, 1), - (4, 16, 12, 12), - backward=False, - dtype=np.float32, - atol=1e-5, - ) - tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=1e-5) + tester(F.div, (4, 16, 1, 1), (4, 16, 12, 12), atol=5e-4) + tester(F.floor_div, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, atol=5e-5) + # tester(F.mod, (8, 1, 4), (8, 1, 1), backward=False, dtype=np.int32, atol=1e-5) # xla not support + tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=5e-5) + tester(F.prelu, (4, 16, 12, 12), (1,), dtype=np.float32, 
atol=1e-5) + tester(F.prelu, (16, 5, 12), (1, 5, 1), dtype=np.float32, atol=1e-5) + tester(F.logaddexp, (16, 5, 12), (1, 5, 12), dtype=np.float32, atol=1e-5) + tester(F.maximum, (1, 5, 1), (1, 5, 12), dtype=np.float32, atol=1e-5) + tester(F.minimum, (1, 5, 12), (16, 5, 12), dtype=np.float32, atol=1e-5) tester( - F.equal, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.float32, atol=1e-5 - ) - tester( - F.not_equal, - (4, 16, 12, 12), - (4, 16, 1, 1), - backward=False, - dtype=np.float32, - atol=1e-5, - ) - tester( - F.greater, - (4, 16, 1, 1), - (4, 16, 12, 12), - backward=False, - dtype=np.float32, - atol=1e-5, + F.left_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32 ) tester( - F.greater_equal, - (16, 1, 1), - (4, 16, 12, 12), - backward=False, - dtype=np.float32, - atol=1e-5, - ) - tester( - F.less, - (4, 16, 12, 1), - (4, 16, 12, 12), - backward=False, - dtype=np.float32, - atol=1e-5, - ) - tester( - F.less_equal, - (1, 1, 12, 12), - (4, 16, 12, 12), - backward=False, - dtype=np.float32, - atol=1e-5, + F.right_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32 ) + + tester(F.equal, (4, 16, 12, 12), (1, 1), backward=False) + tester(F.not_equal, (4, 16, 12, 12), (4, 16, 1, 1), backward=False) + tester(F.greater, (4, 16, 1, 1), (4, 16, 12, 12), backward=False) + tester(F.greater_equal, (16, 1, 1), (4, 16, 12, 12), backward=False) + tester(F.less, (4, 16, 12, 1), (4, 16, 12, 12), backward=False) + tester(F.less_equal, (1, 1, 12, 12), (4, 16, 12, 12), backward=False) + + # bool is not support in dlpack now + # tester(F.logical_and, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.bool8) + # tester(F.logical_or, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, dtype=np.bool8) + # tester( + # F.logical_xor, (4, 16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.bool8 + # ) + # tester(F.logical_not, (16, 1, 1), backward=False, dtype=np.bool8) diff --git a/imperative/python/test/unit/xla/functional/test_xla_nn.py b/imperative/python/test/unit/xla/functional/test_xla_nn.py index 0188d82cb..2dba08a32 100644 --- a/imperative/python/test/unit/xla/functional/test_xla_nn.py +++ b/imperative/python/test/unit/xla/functional/test_xla_nn.py @@ -258,3 +258,70 @@ def test_softmax(): tester((32, 16, 5), 0) tester((1, 16, 5), -1) tester((14, 1, 13, 5), 1) + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_loss(): + def tester( + loss_fn, + pred_shape, + label_shape, + label_type="default", + atol=1e-5, + dtype=None, + **kwargs + ): + dtype = dtype or np.float32 + pred = tensor(np.random.randn(*pred_shape), dtype=dtype) + if label_type == "default": + label = tensor(np.random.randn(*label_shape), dtype=dtype) + elif label_type == "classes": + label = tensor(np.random.randint(0, 10, size=label_shape), dtype=dtype) + dout = tensor(np.random.randn(1,), dtype=dtype) + + gm = autodiff.GradManager() + + @jit.xla_trace(without_host=True) + def func(pred, label, dout): + gm.attach([pred]) + with gm: + out = loss_fn(pred, label, **kwargs) + gm.backward(out, dout) + return out, pred.grad + + mge_rsts = func(pred, label, dout) + xla_rsts = func(pred, label, dout) + + for idx, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol) + + from megengine.functional import loss + + 
tester(loss.l1_loss, (32, 16, 8, 8), (32, 16, 8, 8)) + tester(loss.l1_loss, (1, 16), (1, 16)) + tester(loss.square_loss, (32, 16, 8, 8), (32, 16, 8, 8)) + tester(loss.square_loss, (16, 1), (16, 1)) + tester( + loss.cross_entropy, + (16, 32), + (16,), + label_type="classes", + axis=1, + with_logits=True, + label_smooth=0.0, + ) + tester( + loss.cross_entropy, + (16, 32), + (32,), + label_type="classes", + axis=0, + with_logits=False, + label_smooth=0.5, + ) + tester(loss.binary_cross_entropy, (16, 32, 4, 8), (16, 32, 4, 8), with_logits=True) + tester(loss.binary_cross_entropy, (1, 32, 1), (1, 32, 1), with_logits=False) + tester(loss.hinge_loss, (32, 16, 8, 8), (32, 16, 8, 8), norm="L1") + tester(loss.hinge_loss, (1, 16, 1, 1), (1, 16, 1, 1), norm="L2") diff --git a/imperative/python/test/unit/xla/module/test_elemwise.py b/imperative/python/test/unit/xla/module/test_elemwise.py new file mode 100644 index 000000000..585cc8e6d --- /dev/null +++ b/imperative/python/test/unit/xla/module/test_elemwise.py @@ -0,0 +1,49 @@ +import platform + +import numpy as np +import pytest + +import megengine as mge +import megengine.functional as F +import megengine.module as M +import megengine.tensor as tensor +from megengine import is_cuda_available, jit +from megengine.autodiff import GradManager +from megengine.optimizer import Adam + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_elemwise_activation(): + def tester(TestMod, ishape, dtype=None, atol=1e-5, **kwargs): + dtype = dtype or np.float32 + inp = tensor(0.1 * np.random.randn(*ishape), dtype=dtype) + doup = tensor(0.1 * np.random.randn(*ishape), dtype=dtype) + + gm = GradManager() + mod = TestMod(**kwargs) + + @jit.xla_trace(without_host=True) + def func(mod, inp, doup): + gm.attach(inp) + with gm: + oup = mod(inp) + gm.backward(oup, doup) + return oup, inp.grad + + mge_rsts = func(mod, inp, doup) + xla_rsts = func(mod, inp, doup) + for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol) + + tester(M.Sigmoid, (2, 3, 4, 5)) + tester(M.ReLU, (2, 3,)) + tester(M.LeakyReLU, (4, 5)) + tester(M.LeakyReLU, (4, 5), negative_slope=0.3) + tester(M.PReLU, (8, 6, 5)) + tester(M.PReLU, (8, 6, 5, 7), num_parameters=6, init=0.1) + tester(M.PReLU, (1,)) + tester(M.SiLU, (4, 8, 3, 2)) + tester(M.SiLU, (1, 1,)) + tester(M.GELU, (1, 1, 2)) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 45ab698f6..cb5077ab4 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -564,6 +564,7 @@ REGISTE_PARAM_JSON_FUNC(LayerNormBackward) REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward) REGISTE_PARAM_JSON_FUNC(DropoutBackward) REGISTE_PARAM_JSON_FUNC(SoftmaxBackward) +REGISTE_PARAM_JSON_FUNC(ArgsortBackward) std::shared_ptr dimshuffle_param2json( const opr::Dimshuffle::Param& param) { @@ -862,6 +863,7 @@ void OprFootprint::init_all_footprints() { add_single_param_json(); add_single_param_json(); add_single_param_json(); + add_single_param_json(); #endif } -- GitLab
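
Several of the new composite rules in `rules/elemwise.py` rely on numerically stable rewrites rather than the textbook formulas, e.g. `softplus(x) = log1p(exp(-|x|)) + relu(x)` and `LOG_SUM_EXP(x, y) = max(x, y) + log1p(exp(min(x, y) - max(x, y)))`. Below is a small NumPy check of those identities outside the XLA pipeline; NumPy is used purely for illustration here and is not part of the patch.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)) factored as log1p(exp(-|x|)) + relu(x); the exp argument
    # is always <= 0, so it cannot overflow for large positive x.
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

def log_add_exp(x, y):
    # log(exp(x) + exp(y)) factored around the larger operand, matching the
    # LOG_SUM_EXP rule in rules/elemwise.py.
    hi, lo = np.maximum(x, y), np.minimum(x, y)
    return hi + np.log1p(np.exp(lo - hi))

x = np.linspace(-20.0, 20.0, 11)
np.testing.assert_allclose(softplus(x), np.log1p(np.exp(x)), rtol=1e-6)
np.testing.assert_allclose(log_add_exp(x, x[::-1]), np.logaddexp(x, x[::-1]), rtol=1e-6)
print("composite elemwise formulas match the reference results")
```

The factored forms keep every `exp` argument non-positive, which is why the lowered graphs stay finite for large-magnitude inputs while agreeing with the naive formulas in the well-conditioned range checked above.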
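
The `use_fake_rule_for_debug` path added in `rules/utils.py` and `rules/trivial.py` lets the lowering loop in `lower.py` keep going when an op has no XLA rule yet, by substituting a rule that emits zero-filled placeholder outputs. A minimal standalone sketch of that registry-with-fallback pattern follows; it is plain Python with the MegEngine op classes and `HLOTensor` plumbing omitted, so the names are illustrative only.

```python
import warnings

# Simplified registry mirroring get_rule()/fake_op_lower() from this patch.
lower_rule = {}

def register_lower_rule(*ops):
    def decorator(rule):
        for op in ops:
            lower_rule[op] = rule
        return rule
    return decorator

@register_lower_rule("fake_op_rule_for_debug")
def fake_op_lower(ctx, *args):
    # Emit placeholder (zero) outputs so the rest of the graph still lowers.
    return [0.0 for _ in getattr(ctx, "vars_out", [])]

def get_rule(op, use_fake_rule_for_debug=False):
    op_key = op if isinstance(op, str) else type(op)
    if use_fake_rule_for_debug and op_key not in lower_rule:
        warnings.warn(f"op: {op_key} not registered, using fake op rule")
        return lower_rule["fake_op_rule_for_debug"]
    return lower_rule[op_key]  # unknown ops raise KeyError in normal mode
```

With the flag left at its default of `False`, as `lower.py` calls it, an unregistered op still fails fast; the fallback is only intended for bringing up new models while some rules are still missing.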