Commit 4c7905f3 authored by Megvii Engine Team

feat(imperative): add some xla op rules

GitOrigin-RevId: 0650c75dc1e4ec9af8ae7d9ed3eca60e4681e04a
Parent 0d2b4db9
@@ -75,7 +75,9 @@ def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map):
     def check_external(trace_obj):
         for var in trace_obj.vars:
             if var.kind == "external" and not var.inp_mark:
-                raise RuntimeError("have unknown input in trace result")
+                raise RuntimeError(
+                    "have unknown input in trace result, maybe you can set `capture_as_const=True` when trace"
+                )

     check_external(fwd)
     check_external(bwd)
...
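Note: the new message points at the `capture_as_const` option of the trace decorator. A minimal sketch of how a caller might enable it (the tensor, shapes, and function body here are made up for illustration, not part of this commit):

```python
import numpy as np
import megengine.functional as F
from megengine import jit, tensor

# a tensor created outside the traced function; capturing it as a constant
# avoids the "unknown input in trace result" error raised above
w = tensor(np.random.randn(4, 4).astype("float32"))

@jit.trace(capture_as_const=True)
def fwd(x):
    return F.matmul(x, w)

fwd(tensor(np.ones((2, 4), dtype="float32")))
```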
@@ -579,7 +579,7 @@ class trace:
         if not self._trace.compiled():
             outlist, self.outdef = tree_flatten(outputs)
             for i, out in enumerate(outlist):
-                assert isinstance(out, RawTensor)
+                assert isinstance(out, RawTensor), f"get out of type {type(out)}"
                 outlist[i] = get_marked_output_tensor(self.output_num, out)
                 del out
                 self.out_list.append(self.output_num)
...
@@ -101,10 +101,11 @@ class DropoutMaskCanonicalizer(Pass):
             if not isinstance(eqn.op, mops.Dropout):
                 continue

-            outputs = list(eqn.outputs)
+            inputs, outputs = list(eqn.inputs), list(eqn.outputs)
             mask_var = tr.vars[outputs[1]]
+            inp_shape = tr.vars[inputs[0]].shape
             new_mask_var = AbstractVar(
-                mask_var.id, (int(np.prod(mask_var.shape)) * 8,), mask_var.dtype
+                mask_var.id, (int(np.prod(inp_shape)),), mask_var.dtype
             )
             tr.vars[mask_var.id] = new_mask_var
...
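Note: the canonicalized mask variable is now sized by the element count of the dropout input rather than by the traced mask shape times 8. A throwaway NumPy sketch of the two size computations (the shapes are invented purely for illustration):

```python
import numpy as np

inp_shape = (32, 64)   # hypothetical dropout input shape
mask_shape = (256,)    # hypothetical traced mask shape

old_numel = int(np.prod(mask_shape)) * 8  # previous rule: mask elements * 8
new_numel = int(np.prod(inp_shape))       # new rule: one entry per input element
print(old_numel, new_numel)               # 2048 2048 for these made-up shapes
```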
@@ -142,7 +142,7 @@ def lowering_ops(
             vars_out=[trace_result.vars[oup] for oup in eqn.outputs],
             param=eqn.param,
         )
-        rule = get_rule(eqn.op)
+        rule = get_rule(eqn.op, use_fake_rule_for_debug=False)
         in_nodes = read(eqn.inputs)
         hinps = [
...
@@ -18,9 +18,9 @@ def _infer_elemwise_oshape(inp_shapes):
     if len(rhs_shape) == 0:
         return lhs_shape

-    if np.prod(lhs_shape) == 1 and len(rhs_shape) != 0:
+    if np.prod(lhs_shape) == 1 and len(lhs_shape) == 1 and len(rhs_shape) != 0:
         return rhs_shape
-    if np.prod(rhs_shape) == 1 and len(rhs_shape) != 0:
+    if np.prod(rhs_shape) == 1 and len(rhs_shape) == 1 and len(rhs_shape) != 0:
         return lhs_shape

     oshape = []
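Note: one way to see why the added `len(...) == 1` checks matter is that a shape whose product is 1 need not be rank-1, and returning the other operand's shape would then drop leading broadcast dimensions. A quick NumPy check with illustrative shapes:

```python
import numpy as np

lhs_shape, rhs_shape = (1, 1, 1, 1), (4, 5)

print(np.prod(lhs_shape) == 1)                    # True, but lhs is rank 4
print(np.broadcast_shapes(lhs_shape, rhs_shape))  # (1, 1, 4, 5), not (4, 5)

# the shortcut is only safe for a genuine rank-1 scalar shape such as (1,)
print(np.broadcast_shapes((1,), rhs_shape))       # (4, 5)
```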
@@ -62,6 +62,24 @@ def _infer_elemwise_odtype(inp_dtypes):
     return oup_dtype


+def bitcast(inp, oshape, odtype):
+    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
+    return HLOTensor(
+        hlo.BitcastConvertOp(
+            ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor
+        ).result
+    )
+
+
+def typecvt(inp, odtype):
+    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
+    return HLOTensor(
+        hlo.ConvertOp(
+            ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor
+        ).result
+    )
+
+
 def _compare(lhs, rhs, mode, comparison_type=None):
     """
     mod: can be
@@ -126,19 +144,36 @@ def _elemwise_binary(hlo_op, a, b):
     return _elemwise(hlo_op, [a, b])


+def _elemwise_ternary(hlo_op, a, b, c):
+    return _elemwise(hlo_op, [a, b, c])
+
+
 neg = partial(_elemwise_unary, hlo.NegOp)
 abs = partial(_elemwise_unary, hlo.AbsOp)
+sin = partial(_elemwise_unary, hlo.SineOp)
+cos = partial(_elemwise_unary, hlo.CosineOp)
 tanh = partial(_elemwise_unary, hlo.TanhOp)
 exp = partial(_elemwise_unary, hlo.ExpOp)
 sqrt = partial(_elemwise_unary, hlo.SqrtOp)
 log = partial(_elemwise_unary, hlo.LogOp)
+log1p = partial(_elemwise_unary, hlo.Log1pOp)
+expm1 = partial(_elemwise_unary, hlo.Expm1Op)
+floor = partial(_elemwise_unary, hlo.FloorOp)
+ceil = partial(_elemwise_unary, hlo.CeilOp)
+round = partial(_elemwise_unary, hlo.RoundOp)

 add = partial(_elemwise_binary, hlo.AddOp)
 sub = partial(_elemwise_binary, hlo.SubtractOp)
 mul = partial(_elemwise_binary, hlo.MulOp)
 div = partial(_elemwise_binary, hlo.DivOp)
 pow = partial(_elemwise_binary, hlo.PowOp)
+maximum = partial(_elemwise_binary, hlo.MaxOp)
+minimum = partial(_elemwise_binary, hlo.MinOp)
+atan2 = partial(_elemwise_binary, hlo.Atan2Op)
+left_shift = partial(_elemwise_binary, hlo.ShiftLeftOp)
+right_shift = partial(_elemwise_binary, hlo.ShiftRightArithmeticOp)
+
+clip = partial(_elemwise_ternary, hlo.ClampOp)

 equal = partial(_compare, mode="EQ")
 not_equal = partial(_compare, mode="NE")
@@ -147,31 +182,99 @@ greater_equal = partial(_compare, mode="GE")
 less = partial(_compare, mode="LT")
 less_equal = partial(_compare, mode="LE")

+logical_and = partial(_elemwise_binary, hlo.AndOp)
+logical_or = partial(_elemwise_binary, hlo.OrOp)
+logical_not = partial(_elemwise_unary, hlo.NotOp)
+logical_xor = partial(_elemwise_binary, hlo.XorOp)
+
+
+def floor_div(x, y):
+    return floor(div(x, y))
+
+
+def mod(x, y):
+    assert False, "xla not support"
+
+
+def cond_leq_move(x, y, z):
+    mask = (x <= y).astype(x.dtype)
+    return mask * z
+
+
+def cond_lt_move(x, y, z):
+    mask = (x < y).astype(x.dtype)
+    return mask * z
+
+
+def log_add_exp(x, y):
+    min_val = minimum(x, y)
+    max_val = maximum(x, y)
+    return max_val + log1p(exp(min_val - max_val))
+
+
+def square(x):
+    return mul(x, x)
+
+
 def abs_grad(x, dy):
     return (x / abs(x)) * dy


+def tan(x):
+    return sin(x) / cos(x)
+
+
+def tan_grad(x, dy):
+    return (1.0 + tan(x) ** 2.0) * dy
+
+
+def sinh(x):
+    return (exp(x) - exp(-x)) / 2.0
+
+
+def cosh(x):
+    return (exp(x) + exp(-x)) / 2.0
+
+
 def tanh_grad(x, dy):
     return (1.0 - tanh(x) ** 2.0) * dy


-def bitcast(inp, oshape, odtype):
-    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
-    return HLOTensor(
-        hlo.BitcastConvertOp(
-            ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor
-        ).result
-    )
-
-
-def typecvt(inp, odtype):
-    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
-    return HLOTensor(
-        hlo.ConvertOp(
-            ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor
-        ).result
-    )
+def atan(x):
+    return atan2(x, 1.0)
+
+
+def asin(x):
+    return atan(x / sqrt(1.0 - x ** 2.0))
+
+
+def acos(x):
+    assert False, "xla not support"
+    # return atan(sqrt(1.0 - x ** 2.0) / x)
+
+
+def asinh(x):
+    return log(x + sqrt(x ** 2.0 + 1.0))
+
+
+def acosh(x):
+    return log(x + sqrt(x ** 2.0 - 1.0))
+
+
+def atanh(x):
+    return log((1.0 + x) / (1.0 - x)) / 2.0
+
+
+def asinh_grad(x, dy):
+    return dy / sqrt(x ** 2.0 + 1.0)
+
+
+def acosh_grad(x, dy):
+    return dy / sqrt(x ** 2.0 - 1.0)
+
+
+def atanh_grad(x, dy):
+    return dy / (1.0 - x ** 2.0)


 def gelu(inp, approximate: bool = True):
@@ -257,6 +360,86 @@ def relu_grad(x, dy):
     return dy * mask


+def sigmoid(inp):
+    return 1.0 / (1.0 + exp(-inp))
+
+
+def sigmoid_grad(y, dy):
+    return y * (1.0 - y) * dy
+
+
+def hsigmoid(x):
+    from .tensor import where
+
+    return where(x <= -3.0, 0.0, where(x >= 3.0, 1.0, (x + 3.0) / 6.0))
+
+
+def hsigmoid_grad(x, dy):
+    from .tensor import where
+
+    return where(x <= -3.0, 0.0, where(x >= 3.0, 0.0, dy / 6.0))
+
+
+def relu6(x):
+    return clip(x, 0.0, 6.0)
+
+
+def relu6_grad(x, dy):
+    from .tensor import where
+
+    return where(x <= 0.0, 0.0, where(x >= 6.0, 0.0, dy))
+
+
+def hswish(x):
+    return x * minimum(maximum(x + 3.0, 0.0), 6.0) * (1.0 / 6.0)
+
+
+def hswish_grad(x, dy):
+    from .tensor import where
+
+    return where(x < -3.0, 0.0, where(x > 3.0, dy, (2.0 * x + 3.0) / 6.0 * dy))
+
+
+def logsigmoid(x):
+    from .tensor import where
+
+    return -log1p(exp(-abs(x))) + where(x >= 0.0, 0.0, x)
+
+
+def softplus(x):
+    return log1p(exp(-abs(x))) + relu(x)
+
+
+def softplus_grad(x, dy):
+    from .tensor import where
+
+    exp_abs = exp(-abs(x))
+    logg = -dy * exp_abs / (1.0 + exp_abs)
+    grad0 = where(x > 0.0, logg, -logg)
+    relux = relu(x)
+    grad1 = where(relux > 0.0, dy, 0.0)
+    return grad0 + grad1
+
+
+def prelu(inp, alpha):
+    mask = (inp > 0.0).astype(inp.dtype)
+    return inp * mask + alpha * (1.0 - mask) * inp
+
+
+def prelu_grad(x, dy, alpha):
+    mask = (x > 0.0).astype(x.dtype)
+    return dy * mask + alpha * (1.0 - mask) * dy
+
+
+def silu(inp):
+    return inp / (1.0 + exp(-inp))
+
+
+def silu_grad(x, dy):
+    xsig = sigmoid(x)
+    return dy * xsig * (1.0 + x * (1.0 - xsig))
+
+
 # Elemwise.Mode is unhashable, so we convert it to str
 mge_elemwise_to_xla = {
     str(mops.Elemwise.Mode.ADD): add,
@@ -264,22 +447,71 @@ mge_elemwise_to_xla = {
     str(mops.Elemwise.Mode.SUB): sub,
     str(mops.Elemwise.Mode.EXP): exp,
     str(mops.Elemwise.Mode.LOG): log,
+    str(mops.Elemwise.Mode.LOG1P): log1p,
+    str(mops.Elemwise.Mode.LOG_SUM_EXP): log_add_exp,
+    str(mops.Elemwise.Mode.MAX): maximum,
+    str(mops.Elemwise.Mode.MIN): minimum,
+    str(mops.Elemwise.Mode.COND_LEQ_MOV): cond_leq_move,
+    str(mops.Elemwise.Mode.COND_LT_MOV): cond_lt_move,
+    str(mops.Elemwise.Mode.FLOOR): floor,
+    str(mops.Elemwise.Mode.CEIL): ceil,
+    str(mops.Elemwise.Mode.ROUND): round,
+    str(mops.Elemwise.Mode.CLIP): clip,
     str(mops.Elemwise.Mode.GELU): gelu,
     str(mops.Elemwise.Mode.GELU_GRAD): gelu_grad,
     str(mops.Elemwise.Mode.TRUE_DIV): div,
     str(mops.Elemwise.Mode.NEGATE): neg,
+    str(mops.Elemwise.Mode.FLOOR_DIV): floor_div,
+    str(mops.Elemwise.Mode.MOD): mod,
     str(mops.Elemwise.Mode.ABS): abs,
     str(mops.Elemwise.Mode.ABS_GRAD): abs_grad,
+    str(mops.Elemwise.Mode.SIN): sin,
+    str(mops.Elemwise.Mode.COS): cos,
+    str(mops.Elemwise.Mode.TAN): tan,
+    str(mops.Elemwise.Mode.SINH): sinh,
+    str(mops.Elemwise.Mode.COSH): cosh,
     str(mops.Elemwise.Mode.TANH): tanh,
+    str(mops.Elemwise.Mode.ASIN): asin,
+    str(mops.Elemwise.Mode.ACOS): acos,
+    str(mops.Elemwise.Mode.ASINH): asinh,
+    str(mops.Elemwise.Mode.ACOSH): acosh,
+    str(mops.Elemwise.Mode.ATANH): atanh,
+    str(mops.Elemwise.Mode.ATAN2): atan2,
     str(mops.Elemwise.Mode.TANH_GRAD): tanh_grad,
+    str(mops.Elemwise.Mode.ASINH_GRAD): asinh_grad,
+    str(mops.Elemwise.Mode.ACOSH_GRAD): acosh_grad,
+    str(mops.Elemwise.Mode.ATANH_GRAD): atanh_grad,
     str(mops.Elemwise.Mode.SQRT): sqrt,
+    str(mops.Elemwise.Mode.SQUARE): square,
     str(mops.Elemwise.Mode.POW): pow,
+    str(mops.Elemwise.Mode.EXPM1): expm1,
     str(mops.Elemwise.Mode.RELU): relu,
     str(mops.Elemwise.Mode.EQ): equal,
     str(mops.Elemwise.Mode.NEQ): not_equal,
     str(mops.Elemwise.Mode.LT): less,
     str(mops.Elemwise.Mode.LEQ): less_equal,
+    str(mops.Elemwise.Mode.AND): logical_and,
+    str(mops.Elemwise.Mode.OR): logical_or,
+    str(mops.Elemwise.Mode.NOT): logical_not,
+    str(mops.Elemwise.Mode.XOR): logical_xor,
+    str(mops.Elemwise.Mode.SHL): left_shift,
+    str(mops.Elemwise.Mode.SHR): right_shift,
     str(mops.Elemwise.Mode.SWITCH_GT0): relu_grad,
+    str(mops.Elemwise.Mode.SIGMOID): sigmoid,
+    str(mops.Elemwise.Mode.SIGMOID_GRAD): sigmoid_grad,
+    str(mops.Elemwise.Mode.PRELU): prelu,
+    str(mops.Elemwise.Mode.PRELU_GRAD): prelu_grad,
+    str(mops.Elemwise.Mode.SILU): silu,
+    str(mops.Elemwise.Mode.SILU_GRAD): silu_grad,
+    str(mops.Elemwise.Mode.HSIGMOID): hsigmoid,
+    str(mops.Elemwise.Mode.HSIGMOID_GRAD): hsigmoid_grad,
+    str(mops.Elemwise.Mode.H_SWISH): hswish,
+    str(mops.Elemwise.Mode.H_SWISH_GRAD): hswish_grad,
+    str(mops.Elemwise.Mode.RELU6): relu6,
+    str(mops.Elemwise.Mode.RELU6_GRAD): relu6_grad,
+    str(mops.Elemwise.Mode.LOGSIGMOID): logsigmoid,
+    str(mops.Elemwise.Mode.SOFTPLUS): softplus,
+    str(mops.Elemwise.Mode.SOFTPLUS_GRAD): softplus_grad,
 }
...
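Note: several of the composite rules above are written in numerically stable forms, e.g. `softplus` as `log1p(exp(-|x|)) + relu(x)` and `log_add_exp` via the max/`log1p` rewrite. A NumPy sketch of just the math (not the HLO lowering itself):

```python
import numpy as np

x = np.array([-800.0, -5.0, 0.0, 5.0, 800.0])

# softplus(x) = log(1 + exp(x)), rewritten as log1p(exp(-|x|)) + relu(x)
stable = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)
print(stable)          # finite everywhere, ~[0, 0.0067, 0.693, 5.0067, 800]

with np.errstate(over="ignore"):
    naive = np.log1p(np.exp(x))
print(naive[-1])       # inf: the naive form overflows at x = 800

# log_add_exp(a, b) = max(a, b) + log1p(exp(min(a, b) - max(a, b)))
a, b = 1000.0, 999.0
print(max(a, b) + np.log1p(np.exp(min(a, b) - max(a, b))))  # ~1000.3133, no overflow
```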
 from typing import Sequence, Union

+import numpy as np
+
 from ...core._imperative_rt import ops as mops
 from .. import ir_utils
-from ..lib.mlir.dialects import hlo
+from ..ir_utils import i64_attr
+from ..lib.mlir.dialects import chlo, hlo
 from .hlotensor import HLOTensor
 from .utils import _can_broadcast_to, _shape_equal, register_lower_rule
@@ -236,3 +239,7 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
             precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype),
         ).result
     ).transpose(permutation)
+
+
+def topk(inp, k, descending=True, kth_only=False, no_sort=False):
+    return [HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results]
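Note: the new `topk` helper wraps `chlo.TopKOp`, which yields the top-k values and their indices. A NumPy reference model of the expected semantics (illustrative only, and ignoring the `descending`/`kth_only`/`no_sort` flags):

```python
import numpy as np

def topk_reference(x, k):
    # k largest entries along the last axis, plus their indices
    idx = np.argsort(-x, axis=-1)[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx

vals, idx = topk_reference(np.array([[3.0, 1.0, 7.0, 5.0]]), k=2)
print(vals)  # [[7. 5.]]
print(idx)   # [[2 3]]
```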
@@ -51,11 +51,15 @@ def _get_bitwise_or_identity(dtype) -> np.ndarray:
     return np.array(0, dtype)


-def _infer_reduce_shape(ishape, axes, keepdims=False):
+def _normalize_reduce_axes(ishape, axes):
     axes = list(range(len(ishape))) if axes is None else axes
     axes = [axes] if isinstance(axes, int) else axes
     axes = [axis if axis >= 0 else axis + len(ishape) for axis in axes]
+    return axes
+
+
+def _infer_reduce_shape(ishape, axes, keepdims=False):
+    axes = _normalize_reduce_axes(ishape, axes)

     reduced_shape = []
     for axis, length in enumerate(ishape):
@@ -89,8 +93,7 @@ def _reduce(
         return HLOTensor(reduce_op.result)

-    axes = [axes] if isinstance(axes, int) else axes
-    axes = [axis if axis >= 0 else axis + inp.ndim for axis in axes]
+    axes = _normalize_reduce_axes(inp.shape, axes)
     maykeepdim_shape = _infer_reduce_shape(inp.shape, axes, keepdims)
     _check_shape(maykeepdim_shape, oshape)
@@ -110,6 +113,7 @@ any = partial(_reduce, hlo.OrOp, _get_bitwise_or_identity)

 def mean(inp, axes=None, keepdims=False):
+    axes = _normalize_reduce_axes(inp.shape, axes)
     inp_sum = sum(inp, axes, keepdims)
     inp_shape = inp.shape
...
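Note: the refactor factors axis normalization into `_normalize_reduce_axes` so `_infer_reduce_shape`, `_reduce`, and `mean` all share it. A plain-Python sketch of the same normalization logic:

```python
def normalize_reduce_axes(ishape, axes):
    # None -> all axes; a single int -> one-element list; negatives -> positive
    axes = list(range(len(ishape))) if axes is None else axes
    axes = [axes] if isinstance(axes, int) else axes
    return [axis if axis >= 0 else axis + len(ishape) for axis in axes]

print(normalize_reduce_axes((4, 5, 6), None))     # [0, 1, 2]
print(normalize_reduce_axes((4, 5, 6), -1))       # [2]
print(normalize_reduce_axes((4, 5, 6), [0, -2]))  # [0, 1]
```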
@@ -226,6 +226,11 @@ def pad(inp, pad_value, padding):
     )


+def where(mask, x, y):
+    mask = mask.astype("float32")
+    return mask * x + (1.0 - mask) * y
+
+
 @register_lower_rule(mops.Reshape)
 def reshape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     assert len(args) == 2
...
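Note: the `where` helper selects between `x` and `y` by arithmetic masking instead of a dedicated select op. The same trick checked in NumPy with illustrative inputs:

```python
import numpy as np

mask = np.array([True, False, True])
x = np.array([1.0, 2.0, 3.0])
y = np.array([10.0, 20.0, 30.0])

m = mask.astype("float32")
blended = m * x + (1.0 - m) * y
print(blended)                                        # [ 1. 20.  3.]
print(np.array_equal(blended, np.where(mask, x, y)))  # True
```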
@@ -5,6 +5,7 @@ import numpy as np
 from ...core._imperative_rt import ops as mops
 from ..lib.mlir import ir
 from .hlotensor import HLOTensor
+from .tensor import fill
 from .utils import _check_shape, register_lower_rule
@@ -51,3 +52,8 @@ def io_mark_var_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
 def rename_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
     assert len(args) == 1
     return args
+
+
+@register_lower_rule("fake_op_rule_for_debug")
+def fake_op_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
+    return [fill(0.0, out.shape, out.dtype) for out in ctx.vars_out]
+import warnings
+
 import numpy as np

 from ..lib.mlir import ir
@@ -19,10 +21,16 @@ def register_lower_rule(*ops):
     return decorator


-def get_rule(op):
-    if isinstance(op, str):
-        return lower_rule[op]
-    return lower_rule[type(op)]
+def get_rule(op, use_fake_rule_for_debug=False):
+    op_key = op if isinstance(op, str) else type(op)
+    if use_fake_rule_for_debug:
+        if op_key in lower_rule:
+            return lower_rule[op_key]
+        else:
+            warnings.warn(f"op: {op_key} not register, use fake op rule")
+            return lower_rule["fake_op_rule_for_debug"]
+    else:
+        return lower_rule[op_key]


 def _log_mge_opr_attrs(mopr):
...
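Note: `get_rule` now looks up the rule by a single key (the op itself for string-tagged ops, otherwise the op type) and can fall back to the debug rule registered above. A self-contained sketch of that lookup-with-fallback pattern using a toy registry (not the real `lower_rule` table):

```python
import warnings

toy_rules = {"fake_op_rule_for_debug": lambda *args: "zeros-like placeholder"}

def get_rule_sketch(op, use_fake_rule_for_debug=False):
    op_key = op if isinstance(op, str) else type(op)
    if use_fake_rule_for_debug and op_key not in toy_rules:
        warnings.warn(f"op: {op_key} not register, use fake op rule")
        return toy_rules["fake_op_rule_for_debug"]
    return toy_rules[op_key]

# an unregistered op falls back to the debug rule instead of raising KeyError
rule = get_rule_sketch("unregistered_op", use_fake_rule_for_debug=True)
print(rule())  # zeros-like placeholder
```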
@@ -81,7 +81,7 @@ ValueRef make_empty_tensor(
     storage.ensure_size(dtype->size());
     std::memset(storage.ptr(), 0, dtype->size());
     auto t = imperative::apply(
-            CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
+            CreateTensor(CreateTensor::Const, *device, *dtype, ValueShape()),
             HostStorage::make(storage))[0];
     auto res = broadcast_to(t, shape);
     return res;
...
@@ -1321,7 +1321,8 @@ void init_tensor(py::module m) {
                 } else if (self.check_external) {
                     throw std::runtime_error(
                             "have some unknown input tensors in trace "
-                            "result");
+                            "result, maybe you can set "
+                            "`capture_as_const=True` when trace");
                 }
             }
         }
...
@@ -848,7 +848,10 @@ def test_trace_without_error():
         c = tensor([3.0])
         fwd(a, b, c)
     except Exception as e:
-        assert str(e) == "have some unknown input tensors in trace result"
+        assert (
+            str(e)
+            == "have some unknown input tensors in trace result, maybe you can set `capture_as_const=True` when trace"
+        )
     else:
         assert False
...
@@ -18,94 +18,107 @@ def test_elemwise():
     np.random.seed(123)
     mge.random.seed(123)

-    def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5):
+    def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5, **kwargs):
         dtype = dtype or np.float32
-        inps = [
-            tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype)
-            for inp_shape in inp_shapes
-        ]
-        doup = tensor(0.1 * np.random.randn(*felemwise(*inps).shape), dtype=dtype)
+        if dtype in [np.int16, np.int32, np.uint16, np.uint32]:
+            inps = [
+                tensor(np.random.randint(0, 10, size=inp_shape), dtype=dtype)
+                for inp_shape in inp_shapes
+            ]
+        else:
+            inps = [
+                tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype)
+                for inp_shape in inp_shapes
+            ]
+        doup = tensor(
+            0.1 * np.random.randn(*felemwise(*inps, **kwargs).shape), dtype=dtype
+        )

         gm = GradManager()

         @jit.xla_trace(without_host=True)
         def func(inps, doup):
-            gm.attach(inps)
-            with gm:
-                oup = felemwise(*inps)
-                if backward:
-                    gm.backward(oup, doup)
-                    return [oup, *[inp.grad for inp in inps]]
-                else:
-                    return [oup]
+            if backward:
+                gm.attach(inps)
+                with gm:
+                    oup = felemwise(*inps, **kwargs)
+                    gm.backward(oup, doup)
+                return [oup, *[inp.grad for inp in inps]]
+            else:
+                oup = felemwise(*inps, **kwargs)
+                return [oup]

         mge_rsts = func(inps, doup)
         xla_rsts = func(inps, doup)
-        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
+        for _, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)):
             np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol)

     tester(F.neg, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
     tester(F.abs, (2, 32, 16), dtype=np.float32, atol=1e-5)
-    tester(F.tanh, (4, 16, 3, 1), backward=False, dtype=np.float32, atol=1e-5)
+    tester(F.sin, (1, 16, 3, 1), dtype=np.float32, atol=1e-5)
+    tester(F.cos, (4, 16, 3), dtype=np.float32, atol=1e-5)
+    tester(F.tan, (4, 16, 1), dtype=np.float32, atol=1e-5)
+    tester(F.sinh, (4, 16, 1), dtype=np.float32, atol=1e-5)
+    tester(F.cosh, (3, 16, 1), dtype=np.float32, atol=1e-5)
+    tester(F.tanh, (4, 6, 3, 1), dtype=np.float32, atol=5e-4)
+    tester(F.asin, (4, 1, 3, 1), dtype=np.float32, atol=1e-5)
+    # tester(F.acos, (4, 16, 3, 1), dtype=np.float32, atol=1e-5)  # xla compute error
+    tester(F.atan, (4, 16, 3, 1), dtype=np.float32, atol=1e-5)
+    tester(F.asinh, (4, 1, 3, 1), dtype=np.float32, atol=1e-5)
+    tester(F.acosh, (4, 1), dtype=np.float32, atol=1e-5)
+    tester(F.atanh, (1,), dtype=np.float32, atol=1e-5)
     tester(F.exp, (2, 8), dtype=np.float32, atol=1e-5)
     tester(F.sqrt, (32,), dtype=np.float32, atol=1e-5)
+    tester(F.square, (32,), dtype=np.float32, atol=1e-5)
     tester(F.log, (8, 8, 16), dtype=np.float32, atol=1e-5)
+    tester(F.log1p, (8, 1, 16), dtype=np.float32, atol=1e-5)
+    tester(F.expm1, (6, 8, 2), dtype=np.float32, atol=1e-5)
+    tester(F.floor, (4, 16, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
+    tester(F.ceil, (4, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
+    tester(F.round, (1, 4, 1), backward=False, dtype=np.float32, atol=1e-5)
+    tester(F.clip, (4, 16, 1), dtype=np.float32, atol=1e-5, lower=-1.0, upper=1.0)
     tester(F.relu, (1,), dtype=np.float32, atol=1e-5)
     tester(F.gelu, (4, 16, 12, 12), dtype=np.float32, atol=2e-5)
+    tester(F.sigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
+    tester(F.hsigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
+    tester(F.hswish, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
+    tester(F.relu6, (12, 16, 1), dtype=np.float32, atol=1e-5)
+    tester(F.leaky_relu, (1, 16, 1), dtype=np.float32, atol=1e-5)
+    tester(F.leaky_relu, (12, 16, 1), dtype=np.float32, atol=1e-5, negative_slope=0.5)
+    tester(F.silu, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
+    tester(F.logsigmoid, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
+    tester(F.softplus, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
     tester(F.add, (4, 16, 12, 12), (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
     tester(F.sub, (4, 16, 12, 12), (4, 16, 1, 1), dtype=np.float32, atol=1e-5)
     tester(F.mul, (4, 16, 12, 12), (1, 1, 12, 12), dtype=np.float32, atol=1e-5)
-    tester(
-        F.div,
-        (4, 16, 1, 1),
-        (4, 16, 12, 12),
-        backward=False,
-        dtype=np.float32,
-        atol=1e-5,
-    )
-    tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=1e-5)
-    tester(
-        F.equal, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.float32, atol=1e-5
-    )
-    tester(
-        F.not_equal,
-        (4, 16, 12, 12),
-        (4, 16, 1, 1),
-        backward=False,
-        dtype=np.float32,
-        atol=1e-5,
-    )
-    tester(
-        F.greater,
-        (4, 16, 1, 1),
-        (4, 16, 12, 12),
-        backward=False,
-        dtype=np.float32,
-        atol=1e-5,
-    )
-    tester(
-        F.greater_equal,
-        (16, 1, 1),
-        (4, 16, 12, 12),
-        backward=False,
-        dtype=np.float32,
-        atol=1e-5,
-    )
-    tester(
-        F.less,
-        (4, 16, 12, 1),
-        (4, 16, 12, 12),
-        backward=False,
-        dtype=np.float32,
-        atol=1e-5,
-    )
-    tester(
-        F.less_equal,
-        (1, 1, 12, 12),
-        (4, 16, 12, 12),
-        backward=False,
-        dtype=np.float32,
-        atol=1e-5,
-    )
+    tester(F.div, (4, 16, 1, 1), (4, 16, 12, 12), atol=5e-4)
+    tester(F.floor_div, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, atol=5e-5)
+    # tester(F.mod, (8, 1, 4), (8, 1, 1), backward=False, dtype=np.int32, atol=1e-5)  # xla not support
+    tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=5e-5)
+    tester(F.prelu, (4, 16, 12, 12), (1,), dtype=np.float32, atol=1e-5)
+    tester(F.prelu, (16, 5, 12), (1, 5, 1), dtype=np.float32, atol=1e-5)
+    tester(F.logaddexp, (16, 5, 12), (1, 5, 12), dtype=np.float32, atol=1e-5)
+    tester(F.maximum, (1, 5, 1), (1, 5, 12), dtype=np.float32, atol=1e-5)
+    tester(F.minimum, (1, 5, 12), (16, 5, 12), dtype=np.float32, atol=1e-5)
+    tester(
+        F.left_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32
+    )
+    tester(
+        F.right_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32
+    )
+    tester(F.equal, (4, 16, 12, 12), (1, 1), backward=False)
+    tester(F.not_equal, (4, 16, 12, 12), (4, 16, 1, 1), backward=False)
+    tester(F.greater, (4, 16, 1, 1), (4, 16, 12, 12), backward=False)
+    tester(F.greater_equal, (16, 1, 1), (4, 16, 12, 12), backward=False)
+    tester(F.less, (4, 16, 12, 1), (4, 16, 12, 12), backward=False)
+    tester(F.less_equal, (1, 1, 12, 12), (4, 16, 12, 12), backward=False)
+    # bool is not support in dlpack now
+    # tester(F.logical_and, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.bool8)
+    # tester(F.logical_or, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, dtype=np.bool8)
+    # tester(
+    #     F.logical_xor, (4, 16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.bool8
+    # )
+    # tester(F.logical_not, (16, 1, 1), backward=False, dtype=np.bool8)
@@ -258,3 +258,70 @@ def test_softmax():
     tester((32, 16, 5), 0)
     tester((1, 16, 5), -1)
     tester((14, 1, 13, 5), 1)
+
+
+@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_loss():
+    def tester(
+        loss_fn,
+        pred_shape,
+        label_shape,
+        label_type="default",
+        atol=1e-5,
+        dtype=None,
+        **kwargs
+    ):
+        dtype = dtype or np.float32
+        pred = tensor(np.random.randn(*pred_shape), dtype=dtype)
+        if label_type == "default":
+            label = tensor(np.random.randn(*label_shape), dtype=dtype)
+        elif label_type == "classes":
+            label = tensor(np.random.randint(0, 10, size=label_shape), dtype=dtype)
+        dout = tensor(np.random.randn(1,), dtype=dtype)
+
+        gm = autodiff.GradManager()
+
+        @jit.xla_trace(without_host=True)
+        def func(pred, label, dout):
+            gm.attach([pred])
+            with gm:
+                out = loss_fn(pred, label, **kwargs)
+                gm.backward(out, dout)
+            return out, pred.grad
+
+        mge_rsts = func(pred, label, dout)
+        xla_rsts = func(pred, label, dout)
+        for idx, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)):
+            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol)
+
+    from megengine.functional import loss
+
+    tester(loss.l1_loss, (32, 16, 8, 8), (32, 16, 8, 8))
+    tester(loss.l1_loss, (1, 16), (1, 16))
+    tester(loss.square_loss, (32, 16, 8, 8), (32, 16, 8, 8))
+    tester(loss.square_loss, (16, 1), (16, 1))
+    tester(
+        loss.cross_entropy,
+        (16, 32),
+        (16,),
+        label_type="classes",
+        axis=1,
+        with_logits=True,
+        label_smooth=0.0,
+    )
+    tester(
+        loss.cross_entropy,
+        (16, 32),
+        (32,),
+        label_type="classes",
+        axis=0,
+        with_logits=False,
+        label_smooth=0.5,
+    )
+    tester(loss.binary_cross_entropy, (16, 32, 4, 8), (16, 32, 4, 8), with_logits=True)
+    tester(loss.binary_cross_entropy, (1, 32, 1), (1, 32, 1), with_logits=False)
+    tester(loss.hinge_loss, (32, 16, 8, 8), (32, 16, 8, 8), norm="L1")
+    tester(loss.hinge_loss, (1, 16, 1, 1), (1, 16, 1, 1), norm="L2")
+import platform
+
 import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.tensor as tensor
+from megengine import is_cuda_available, jit
+from megengine.autodiff import GradManager
+from megengine.optimizer import Adam
+
+
+@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_elemwise_activation():
+    def tester(TestMod, ishape, dtype=None, atol=1e-5, **kwargs):
+        dtype = dtype or np.float32
+        inp = tensor(0.1 * np.random.randn(*ishape), dtype=dtype)
+        doup = tensor(0.1 * np.random.randn(*ishape), dtype=dtype)
+
+        gm = GradManager()
+        mod = TestMod(**kwargs)
+
+        @jit.xla_trace(without_host=True)
+        def func(mod, inp, doup):
+            gm.attach(inp)
+            with gm:
+                oup = mod(inp)
+                gm.backward(oup, doup)
+            return oup, inp.grad
+
+        mge_rsts = func(mod, inp, doup)
+        xla_rsts = func(mod, inp, doup)
+        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
+            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol)
+
+    tester(M.Sigmoid, (2, 3, 4, 5))
+    tester(M.ReLU, (2, 3,))
+    tester(M.LeakyReLU, (4, 5))
+    tester(M.LeakyReLU, (4, 5), negative_slope=0.3)
+    tester(M.PReLU, (8, 6, 5))
+    tester(M.PReLU, (8, 6, 5, 7), num_parameters=6, init=0.1)
+    tester(M.PReLU, (1,))
+    tester(M.SiLU, (4, 8, 3, 2))
+    tester(M.SiLU, (1, 1,))
+    tester(M.GELU, (1, 1, 2))
@@ -564,6 +564,7 @@ REGISTE_PARAM_JSON_FUNC(LayerNormBackward)
 REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward)
 REGISTE_PARAM_JSON_FUNC(DropoutBackward)
 REGISTE_PARAM_JSON_FUNC(SoftmaxBackward)
+REGISTE_PARAM_JSON_FUNC(ArgsortBackward)

 std::shared_ptr<json::Value> dimshuffle_param2json(
         const opr::Dimshuffle::Param& param) {
@@ -862,6 +863,7 @@ void OprFootprint::init_all_footprints() {
     add_single_param_json<opr::AdaptivePoolingBackward>();
     add_single_param_json<opr::DropoutBackward>();
     add_single_param_json<opr::SoftmaxBackward>();
+    add_single_param_json<opr::ArgsortBackward>();
 #endif
 }
...