Commit fccb2510 authored by Megvii Engine Team

perf(mge): add opdef for elemwise and batchnorm

GitOrigin-RevId: d51fad98678098651e83ebdd2a3c1ff6196cb7a2
Parent c008cf37
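In short, Elemwise (and, further below, BatchNorm) stops going through the generic OprAttr/param-string path and becomes a dedicated OpDef whose mode is a Python-visible enum. A minimal sketch of the new construction style, assuming the megengine.core.ops.builtin re-exports touched by this commit:

    # sketch only: import path and enum names are taken from the diff below
    from megengine.core.ops.builtin import Elemwise

    op = Elemwise(Elemwise.Mode.ADD)           # previously: builtin.Elemwise(mode="add")
    assert op.mode == Elemwise.Mode.ADD        # mode is exposed as a read-write attribute
    assert op == Elemwise(Elemwise.Mode.ADD)   # OpDefs with equal modes compare equal (see test_opr_attr)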
@@ -30,7 +30,6 @@ from ..tensor.core import apply
from ..tensor.function import Function
from ..tensor.tensor_wrapper import TensorWrapper

-_elemwise_add_param = Elemwise(mode="add").to_c().param
_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
@@ -44,12 +43,12 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
    if isinstance(op, OprAttr):
        grad_fn = _oprAttr_grad_fn.get(op.type, None)
        if grad_fn is None:
-            if op.type == Elemwise.name and op.param == _elemwise_add_param:
-                grad_fn = elemwise_add_grad_fn
-            elif op.type == Reduce.name and op.param[0] == _reduce_sum_param:
+            if op.type == Reduce.name and op.param[0] == _reduce_sum_param:
                grad_fn = reduce_sum_grad_fn
            else:
                grad_fn = default_grad_fn
+    elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD:
+        grad_fn = elemwise_add_grad_fn
    else:
        grad_fn = default_grad_fn
    return grad_fn(op, inputs, outputs, input_requires_grad)
@@ -158,11 +157,8 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
    params = inputs[1:]

    def make_grad(grad_op, dy):
-        grad = (
-            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            ._broadcast(TensorWrapper(input_shape))
-            .__wrapped__
-        )
+        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
+        (grad,) = apply(Broadcast(), _z, input_shape)
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx
@@ -184,11 +180,8 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
    params = inputs[1:]

    def make_grad(grad_op, dy):
-        grad = (
-            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
-            ._broadcast(TensorWrapper(input_shape))
-            .__wrapped__
-        )
+        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
+        (grad,) = apply(Broadcast(), _z, input_shape)
        (dx,) = apply(grad_op, grad, dy, *params)
        return dx
......
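Both gradient helpers above now build the zero-filled gradient buffer from a scalar Const followed by an explicit Broadcast to the input shape, instead of going through TensorWrapper. The intended semantics, illustrated with plain numpy (the real code issues MegEngine ops):

    import numpy as np

    dy = np.ones((2, 3), dtype="float32")
    input_shape = (4, 2, 3)
    # a scalar zero with dy's dtype, broadcast to the shape of the forward input
    grad = np.broadcast_to(np.zeros((), dtype=dy.dtype), input_shape)
    assert grad.shape == input_shape and grad.dtype == dy.dtype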
@@ -47,7 +47,7 @@ def get_grad_managers():
def add(a, b):
-    (c,) = apply(Elemwise(mode="add"), a, b)
+    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
    return c
......
@@ -13,7 +13,7 @@ import numpy as np
from .._trace_option import use_symbolic_shape
from ..ops import builtin
-from ..ops.builtin import GetVarShape
+from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
from . import utils
from .core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -23,10 +23,12 @@ from .raw_tensor import RawTensor, as_raw_tensor
from .tensor import Tensor
from .utils import make_shape_tuple as _make_shape_tuple

+_ElwMod = Elemwise.Mode
+

def _elwise(*args, mode):
-    op = builtin.Elemwise(mode=mode)
-    if mode in ("TRUE_DIV", "POW"):
+    op = builtin.Elemwise(mode)
+    if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW):
        args = tuple(
            map(
                lambda x: x.astype("float32")
@@ -272,53 +274,53 @@ class ArrayMethodMixin(abc.ABC):

    __hash__ = None  # due to __eq__ diviates from python convention

-    __lt__ = lambda self, value: _elwise(self, value, mode="LT").astype("bool")
+    __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool")
-    __le__ = lambda self, value: _elwise(self, value, mode="LEQ").astype("bool")
+    __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool")
-    __gt__ = lambda self, value: _elwise(value, self, mode="LT").astype("bool")
+    __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool")
-    __ge__ = lambda self, value: _elwise(value, self, mode="LEQ").astype("bool")
+    __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool")
-    __eq__ = lambda self, value: _elwise(self, value, mode="EQ").astype("bool")
+    __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool")
    __ne__ = lambda self, value: _elwise(
-        _elwise(self, value, mode="EQ").astype("bool"), mode="NOT"
+        _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT,
    )

-    __neg__ = _unary_elwise("NEGATE")
+    __neg__ = _unary_elwise(_ElwMod.NEGATE)
    __pos__ = lambda self: self
-    __abs__ = _unary_elwise("ABS")
+    __abs__ = _unary_elwise(_ElwMod.ABS)
-    __invert__ = _logical_unary_elwise("NOT")
+    __invert__ = _logical_unary_elwise(_ElwMod.NOT)
-    __round__ = _unary_elwise("ROUND")
+    __round__ = _unary_elwise(_ElwMod.ROUND)
    __trunc__ = _todo
-    __floor__ = _unary_elwise("FLOOR")
+    __floor__ = _unary_elwise(_ElwMod.FLOOR)
-    __ceil__ = _unary_elwise("CEIL")
+    __ceil__ = _unary_elwise(_ElwMod.CEIL)

-    __add__ = _binary_elwise("ADD")
+    __add__ = _binary_elwise(_ElwMod.ADD)
-    __sub__ = _binary_elwise("SUB")
+    __sub__ = _binary_elwise(_ElwMod.SUB)
-    __mul__ = _binary_elwise("MUL")
+    __mul__ = _binary_elwise(_ElwMod.MUL)
    __matmul__ = lambda self, other: _matmul(self, other)
-    __truediv__ = _binary_elwise("TRUE_DIV")
+    __truediv__ = _binary_elwise(_ElwMod.TRUE_DIV)
-    __floordiv__ = _binary_elwise("FLOOR_DIV")
+    __floordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV)
-    __mod__ = _binary_elwise("MOD")
+    __mod__ = _binary_elwise(_ElwMod.MOD)
    # __divmode__
-    __pow__ = _binary_elwise("POW")
+    __pow__ = _binary_elwise(_ElwMod.POW)
-    __lshift__ = _binary_elwise("SHL")
+    __lshift__ = _binary_elwise(_ElwMod.SHL)
-    __rshift__ = _binary_elwise("SHR")
+    __rshift__ = _binary_elwise(_ElwMod.SHR)
-    __and__ = _logical_binary_elwise("AND")
+    __and__ = _logical_binary_elwise(_ElwMod.AND)
-    __or__ = _logical_binary_elwise("OR")
+    __or__ = _logical_binary_elwise(_ElwMod.OR)
-    __xor__ = _logical_binary_elwise("XOR")
+    __xor__ = _logical_binary_elwise(_ElwMod.XOR)
-    __radd__ = _binary_elwise("ADD", rev=1)
+    __radd__ = _binary_elwise(_ElwMod.ADD, rev=1)
-    __rsub__ = _binary_elwise("SUB", rev=1)
+    __rsub__ = _binary_elwise(_ElwMod.SUB, rev=1)
-    __rmul__ = _binary_elwise("MUL", rev=1)
+    __rmul__ = _binary_elwise(_ElwMod.MUL, rev=1)
    __rmatmul__ = lambda self, other: _matmul(other, self)
-    __rtruediv__ = _binary_elwise("TRUE_DIV", rev=1)
+    __rtruediv__ = _binary_elwise(_ElwMod.TRUE_DIV, rev=1)
-    __rfloordiv__ = _binary_elwise("FLOOR_DIV", rev=1)
+    __rfloordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV, rev=1)
-    __rmod__ = _binary_elwise("MOD", rev=1)
+    __rmod__ = _binary_elwise(_ElwMod.MOD, rev=1)
    # __rdivmode__
-    __rpow__ = _binary_elwise("POW", rev=1)
+    __rpow__ = _binary_elwise(_ElwMod.POW, rev=1)
-    __rlshift__ = _binary_elwise("SHL", rev=1)
+    __rlshift__ = _binary_elwise(_ElwMod.SHL, rev=1)
-    __rrshift__ = _binary_elwise("SHR", rev=1)
+    __rrshift__ = _binary_elwise(_ElwMod.SHR, rev=1)
-    __rand__ = _logical_binary_elwise("AND", rev=1)
+    __rand__ = _logical_binary_elwise(_ElwMod.AND, rev=1)
-    __ror__ = _logical_binary_elwise("OR", rev=1)
+    __ror__ = _logical_binary_elwise(_ElwMod.OR, rev=1)
-    __rxor__ = _logical_binary_elwise("XOR", rev=1)
+    __rxor__ = _logical_binary_elwise(_ElwMod.XOR, rev=1)
    __iadd__ = _inplace(__add__)
    __isub__ = _inplace(__sub__)
......
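The operator table above is generated by small factory helpers whose bodies are not part of this hunk. A self-contained sketch of the pattern, with stub names rather than MegEngine's actual helpers, for readers unfamiliar with the idiom:

    from enum import Enum

    class Mode(Enum):          # stand-in for Elemwise.Mode
        ADD = "ADD"
        SUB = "SUB"

    def _binary_elwise_sketch(mode, rev=0):
        # returns a dunder implementation bound to one elementwise mode;
        # rev=1 swaps operands for the reflected operators
        def f(self, value):
            a, b = (value, self) if rev else (self, value)
            return (mode, a, b)    # stand-in for _elwise(a, b, mode=mode)
        return f

    class Num:
        def __init__(self, v):
            self.v = v
        __add__ = _binary_elwise_sketch(Mode.ADD)
        __rsub__ = _binary_elwise_sketch(Mode.SUB, rev=1)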
@@ -10,6 +10,7 @@
import functools

from ..core.ops import builtin
+from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..device import get_default_device
@@ -72,7 +73,7 @@ __all__ = [
def _elwise(*args, mode):
-    op = builtin.Elemwise(mode=mode)
+    op = builtin.Elemwise(mode)
    tensor_args = list(
        filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
    )
@@ -128,67 +129,67 @@ def add(x, y):
     [ 6. 8. 10.]]

    """
-    return _elwise(x, y, mode="add")
+    return _elwise(x, y, mode=Elemwise.Mode.ADD)


def sub(x, y):
    """Element-wise `subtraction`."""
-    return _elwise(x, y, mode="sub")
+    return _elwise(x, y, mode=Elemwise.Mode.SUB)


def mul(x, y):
    """Element-wise `multiplication`."""
-    return _elwise(x, y, mode="mul")
+    return _elwise(x, y, mode=Elemwise.Mode.MUL)


def div(x, y):
    """Element-wise `(x / y)`."""
-    return _elwise(x, y, mode="true_div")
+    return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)


def floor_div(x, y):
    """Element-wise `floor(x / y)`."""
-    return _elwise(x, y, mode="floor_divide")
+    return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIVIDE)


def neg(x):
    """Element-wise `negation`."""
-    return _elwise(x, mode="negate")
+    return _elwise(x, mode=Elemwise.Mode.NEGATE)


def pow(x, y):
    """Element-wise `power`."""
-    return _elwise(x, y, mode="pow")
+    return _elwise(x, y, mode=Elemwise.Mode.POW)


def mod(x, y):
    """Element-wise `remainder of division`."""
-    return _elwise(x, y, mode="mod")
+    return _elwise(x, y, mode=Elemwise.Mode.MOD)


def abs(x):
    """Element-wise `absolute value`."""
-    return _elwise(x, mode="abs")
+    return _elwise(x, mode=Elemwise.Mode.ABS)


def exp(x):
    """Element-wise `exponential`."""
-    return _elwise(x, mode="exp")
+    return _elwise(x, mode=Elemwise.Mode.EXP)


def expm1(x):
    """Element-wise `exp(x)-1`."""
-    return _elwise(x, mode="expm1")
+    return _elwise(x, mode=Elemwise.Mode.EXPM1)


def log(x):
    """Element-wise `logarithm (base e)`."""
-    return _elwise(x, mode="log")
+    return _elwise(x, mode=Elemwise.Mode.LOG)


def log1p(x):
    """Element-wise `log(x+1) (base e)`."""
-    return _elwise(x, mode="log1p")
+    return _elwise(x, mode=Elemwise.Mode.LOG1P)


def sqrt(x: Tensor) -> Tensor:
@@ -253,27 +254,27 @@ def square(x: Tensor) -> Tensor:
def round(x):
    """Element-wise `rounding to int`."""
-    return _elwise(x, mode="round")
+    return _elwise(x, mode=Elemwise.Mode.ROUND)


def ceil(x):
    """Element-wise `ceiling`."""
-    return _elwise(x, mode="ceil")
+    return _elwise(x, mode=Elemwise.Mode.CEIL)


def floor(x):
    """Element-wise `floor`."""
-    return _elwise(x, mode="floor")
+    return _elwise(x, mode=Elemwise.Mode.FLOOR)


def maximum(x, y):
    """Element-wise `maximum of array elements`."""
-    return _elwise(x, y, mode="max")
+    return _elwise(x, y, mode=Elemwise.Mode.MAX)


def minimum(x, y):
    """Element-wise `minimum of array elements`."""
-    return _elwise(x, y, mode="min")
+    return _elwise(x, y, mode=Elemwise.Mode.MIN)


# trigonometric functions
@@ -305,12 +306,12 @@ def cos(x):
     [-0.99 -0.6536 0.2837]]

    """
-    return _elwise(x, mode="cos")
+    return _elwise(x, mode=Elemwise.Mode.COS)


def sin(x):
    """Element-wise `sine`."""
-    return _elwise(x, mode="sin")
+    return _elwise(x, mode=Elemwise.Mode.SIN)


def tan(x):
@@ -320,22 +321,22 @@ def tan(x):
def acos(x):
    """Element-wise `inverse cosine`."""
-    return _elwise(x, mode="acos")
+    return _elwise(x, mode=Elemwise.Mode.ACOS)


def asin(x):
    """Element-wise `inverse sine`."""
-    return _elwise(x, mode="asin")
+    return _elwise(x, mode=Elemwise.Mode.ASIN)


def atan(x):
    """Element-wise `inverse tangent`."""
-    return _elwise(x, 1, mode="atan2")
+    return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)


def atan2(y, x):
    """Element-wise `2-argument arctangent`."""
-    return _elwise(y, x, mode="atan2")
+    return _elwise(y, x, mode=Elemwise.Mode.ATAN2)


def cosh(x):
@@ -351,7 +352,7 @@ def sinh(x):
def tanh(x):
    r"""Element-wise `hyperbolic tangent`."""
-    return _elwise(x, mode="tanh")
+    return _elwise(x, mode=Elemwise.Mode.TANH)


def asinh(x):
@@ -399,12 +400,12 @@ def left_shift(x, y):
     [12 16 20]]

    """
-    return _elwise(x, y, mode="shl")
+    return _elwise(x, y, mode=Elemwise.Mode.SHL)


def right_shift(x, y):
    """Element-wise `bitwise binary: x >> y`."""
-    return _elwise(x, y, mode="shr")
+    return _elwise(x, y, mode=Elemwise.Mode.SHR)


# logical functions
@@ -412,22 +413,22 @@ def right_shift(x, y):
def logical_and(x, y):
    """Element-wise `logical and: x && y`."""
-    return _elwise(x, y, mode="AND")
+    return _elwise(x, y, mode=Elemwise.Mode.AND)


def logical_not(x):
    """Element-wise `logical not: ~x`."""
-    return _elwise(x, mode="NOT")
+    return _elwise(x, mode=Elemwise.Mode.NOT)


def logical_or(x, y):
    """Element-wise `logical or: x || y`."""
-    return _elwise(x, y, mode="OR")
+    return _elwise(x, y, mode=Elemwise.Mode.OR)


def logical_xor(x, y):
    """Element-wise `logical xor: x ^ y`."""
-    return _elwise(x, y, mode="XOR")
+    return _elwise(x, y, mode=Elemwise.Mode.XOR)


# comparison functions
@@ -461,7 +462,7 @@ def equal(x, y):
     [1. 1. 1.]]

    """
-    return _elwise(x, y, mode="eq")
+    return _elwise(x, y, mode=Elemwise.Mode.EQ)


def not_equal(x, y):
@@ -471,22 +472,22 @@ def not_equal(x, y):
def less(x, y):
    """Element-wise `(x < y)`."""
-    return _elwise(x, y, mode="lt")
+    return _elwise(x, y, mode=Elemwise.Mode.LT)


def less_equal(x, y):
    """Element-wise `(x <= y)`."""
-    return _elwise(x, y, mode="leq")
+    return _elwise(x, y, mode=Elemwise.Mode.LEQ)


def greater(x, y):
    """Element-wise `(x > y)`."""
-    return _elwise(y, x, mode="lt")
+    return _elwise(y, x, mode=Elemwise.Mode.LT)


def greater_equal(x, y):
    """Element-wise `(x >= y)`."""
-    return _elwise(y, x, mode="leq")
+    return _elwise(y, x, mode=Elemwise.Mode.LEQ)


# other functions
@@ -515,7 +516,7 @@ def hswish(x):
     [0. 0.6667 1.6667 3. 4. ]

    """
-    return _elwise(x, mode="h_swish")
+    return _elwise(x, mode=Elemwise.Mode.H_SWISH)


def hsigmoid(x):
@@ -525,7 +526,7 @@ def hsigmoid(x):
def relu(x):
    """Element-wise `max(x, 0)`."""
-    return _elwise(x, mode="relu")
+    return _elwise(x, mode=Elemwise.Mode.RELU)


def relu6(x):
@@ -535,7 +536,7 @@ def relu6(x):
def sigmoid(x):
    """Element-wise `1 / ( 1 + exp( -x ) )`."""
-    return _elwise(x, mode="sigmoid")
+    return _elwise(x, mode=Elemwise.Mode.SIGMOID)


def clip(x: Tensor, lower=None, upper=None) -> Tensor:
......
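The public wrappers keep their behaviour; only the mode argument handed to _elwise changes. A quick sanity check of two of them, assuming a working MegEngine build (values are arbitrary):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.array([-1.0, 0.0, 2.0], dtype="float32"))
    np.testing.assert_allclose(F.relu(x).numpy(), [0.0, 0.0, 2.0])
    np.testing.assert_allclose(F.maximum(x, -x).numpy(), [1.0, 0.0, 2.0])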
@@ -12,6 +12,7 @@ from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
+from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
@@ -643,19 +644,22 @@ def batch_norm(
    if inp.ndim != 4:
        raise NotImplementedError("batch_norm for ndim != 4")

-    def full_value(value):
-        C = inp.shape[1]
-        (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
-        return broadcast_to(x, [1, C, 1, 1])
-
-    def expand_or_full(x, value):
-        if x is None:
-            return full_value(value)
-        return expand_dims(x, [0, 2, 3])
+    C = inp.shape[1]

    def make_full_if_none(x, value):
        if x is None:
-            return full(shape=(1, inp.shape[1], 1, 1), value=value)
+            (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+            shape = utils.astensor1d(
+                (1, C, 1, 1), inp, dtype="int32", device=inp.device
+            )
+            (result,) = apply(builtin.Broadcast(), x, shape)
+            return result
+        elif x.ndim == 1:
+            shape = utils.astensor1d(
+                (1, C, 1, 1), inp, dtype="int32", device=inp.device
+            )
+            (result,) = apply(builtin.Reshape(), x, shape)
+            return result
        return x

    has_mean = running_mean is not None
@@ -674,19 +678,25 @@ def batch_norm(
        inp, weight, bias, running_mean, running_var
    )

-    weight = expand_or_full(weight, 1)
-    bias = expand_or_full(bias, 0)
+    weight = make_full_if_none(weight, 1)
+    bias = make_full_if_none(bias, 0)

    if not training:
-        op = builtin.BatchNorm(fwd_mode="INFERENCE", epsilon=eps, param_dim="DIM_1C11")
+        op = builtin.BatchNorm(
+            BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0
+        )
        ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
        return ret
    else:
        op = builtin.BatchNorm(
-            avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
+            BatchNorm.ParamDim.DIM_1C11,
+            BatchNorm.FwdMode.TRAINING,
+            eps,
+            1.0 - momentum,
+            1.0,
+            0.0,
        )
        if has_mean or has_var:
            running_mean = make_full_if_none(running_mean, 0)
            running_var = make_full_if_none(running_var, 1)
@@ -708,7 +718,7 @@ def batch_norm(
        else:
            return inp, new_mean, new_var
    else:
-        _, _, inp, = apply(op, inp, weight, bias)
+        (_, _, inp,) = apply(op, inp, weight, bias)
        return inp
......
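The reworked make_full_if_none now covers two cases: a missing statistic is materialised as a constant broadcast to (1, C, 1, 1), and a 1-D per-channel statistic is reshaped to the same layout. Its shape contract, illustrated with plain numpy (the real code builds Broadcast/Reshape ops on MegEngine tensors):

    import numpy as np

    C = 3
    filled = np.broadcast_to(np.float32(1.0), (1, C, 1, 1))           # case: x is None
    reshaped = np.arange(C, dtype="float32").reshape(1, C, 1, 1)      # case: x.ndim == 1
    assert filled.shape == reshaped.shape == (1, C, 1, 1)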
@@ -72,14 +72,15 @@ class _BatchNorm(Module):
            self.track_running_stats == False
        ), "track_running_stats can not be initilized to False and changed to True later"

-        _ndims = len(inp.shape)
+        inp_shape = inp.shape
+        _ndims = len(inp_shape)
        if _ndims != 4:
-            origin_shape = inp.shape
+            origin_shape = inp_shape
            if _ndims == 2:
-                n, c = inp.shape[0], inp.shape[1]
+                n, c = inp_shape[0], inp_shape[1]
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
-                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
+                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
                new_shape = (n, c, h, 1)

            inp = inp.reshape(new_shape)
@@ -150,17 +151,18 @@ class SyncBatchNorm(_BatchNorm):
    def forward(self, inp):
        self._check_input_ndim(inp)

-        _ndims = len(inp.shape)
+        inp_shape = inp.shape
+        _ndims = len(inp_shape)
        if _ndims != 4:
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
-            origin_shape = inp.shape
+            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
-                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )

            inp = inp.reshape(new_shape)
......
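Both _BatchNorm.forward and SyncBatchNorm.forward still feed only 4-D tensors to batch_norm; lower-rank inputs are reshaped first and restored afterwards. The shape handling, shown with plain numpy for clarity (the modules do this with Tensor.reshape):

    import numpy as np

    x2d = np.zeros((8, 16), dtype="float32")         # (N, C)
    assert x2d.reshape(8, 16, 1, 1).shape == (8, 16, 1, 1)    # -> (N, C, 1, 1)
    x3d = np.zeros((8, 16, 5), dtype="float32")       # (N, C, H)
    assert x3d.reshape(8, 16, 5, 1).shape == (8, 16, 5, 1)    # -> (N, C, H, 1)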
@@ -19,6 +19,8 @@
#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain/imperative/ops/cond_take.h"
#include "megbrain/imperative/ops/nms.h"
+#include "megbrain/imperative/ops/elemwise.h"
+#include "megbrain/imperative/ops/batch_norm.h"

namespace py = pybind11;
@@ -117,4 +119,91 @@ void init_ops(py::module m) {
        .def(py::init<float, uint32_t>())
        .def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
        .def_readwrite("max_output", &NMSKeep::max_output);

+    py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise");
+    elemwise.def(py::init<Elemwise::Mode>())
+        .def_readwrite("mode", &Elemwise::mode);
+
+#define V(m) .value(#m, Elemwise::Mode::m)
+    py::enum_<Elemwise::Mode>(elemwise, "Mode")
+        V(RELU)
+        V(ABS)
+        V(ACOS)
+        V(ASIN)
+        V(CEIL)
+        V(COS)
+        V(EXP)
+        V(EXPM1)
+        V(FLOOR)
+        V(LOG)
+        V(LOG1P)
+        V(NEGATE)
+        V(SIGMOID)
+        V(SIN)
+        V(TANH)
+        V(ABS_GRAD)
+        V(ADD)
+        V(FLOOR_DIV)
+        V(MAX)
+        V(MIN)
+        V(MOD)
+        V(MUL)
+        V(POW)
+        V(SIGMOID_GRAD)
+        V(SUB)
+        V(SWITCH_GT0)
+        V(TANH_GRAD)
+        V(TRUE_DIV)
+        V(LOG_SUM_EXP)
+        V(LT)
+        V(LEQ)
+        V(EQ)
+        V(SHL)
+        V(SHR)
+        V(COND_LEQ_MOV)
+        V(FUSE_MUL_ADD3)
+        V(FUSE_MUL_ADD4)
+        V(FUSE_ADD_RELU)
+        V(FUSE_ADD_SIGMOID)
+        V(FUSE_ADD_TANH)
+        V(FAST_TANH)
+        V(FAST_TANH_GRAD)
+        V(ROUND)
+        V(RMULH)
+        V(ATAN2)
+        V(ERF)
+        V(ERFINV)
+        V(ERFC)
+        V(ERFCINV)
+        V(H_SWISH)
+        V(H_SWISH_GRAD)
+        V(FUSE_ADD_H_SWISH)
+        V(NOT)
+        V(AND)
+        V(OR)
+        V(XOR);
+#undef V
+
+    py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm");
+    batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&, double, double, float, float>())
+        .def_readwrite("param_dim", &BatchNorm::param_dim)
+        .def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
+        .def_readwrite("epsilon", &BatchNorm::epsilon)
+        .def_readwrite("avg_factor", &BatchNorm::avg_factor)
+        .def_readwrite("scale", &BatchNorm::scale)
+        .def_readwrite("bias", &BatchNorm::bias);
+
+#define V(m) .value(#m, BatchNorm::Param::ParamDim::m)
+    py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim")
+        V(DIM_11HW)
+        V(DIM_1CHW)
+        V(DIM_1C11);
+#undef V
+
+#define V(m) .value(#m, BatchNorm::Param::FwdMode::m)
+    py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode")
+        V(TRAINING)
+        V(INFERENCE);
+#undef V
}
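These bindings are what make the two OpDefs and their enums reachable from Python; a small sketch of what the registration exposes, assuming the usual megengine.core.ops.builtin re-export used elsewhere in this commit:

    from megengine.core.ops.builtin import BatchNorm

    bn = BatchNorm(
        BatchNorm.ParamDim.DIM_1C11,   # enum values registered by the V(...) macro above
        BatchNorm.FwdMode.TRAINING,
        1e-5,                          # epsilon
        0.1,                           # avg_factor (i.e. 1 - momentum)
        1.0,                           # scale
        0.0,                           # bias
    )
    print(bn.fwd_mode, bn.avg_factor)  # attributes exposed via def_readwrite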
@@ -27,7 +27,7 @@ from megengine.functional.distributed import remote_recv, remote_send
def _elwise(mode):
-    op = Elemwise(mode=mode)
+    op = Elemwise(mode)

    def f(*args):
        (result,) = apply(op, *args)
@@ -36,10 +36,10 @@ def _elwise(mode):
    return f


-add = _elwise("add")
-mul = _elwise("mul")
-cos = _elwise("cos")
-relu = _elwise("relu")
+add = _elwise(Elemwise.Mode.ADD)
+mul = _elwise(Elemwise.Mode.MUL)
+cos = _elwise(Elemwise.Mode.COS)
+relu = _elwise(Elemwise.Mode.RELU)


def as_tensor(x):
@@ -255,7 +255,7 @@ def test_elemwise_relu():
def test_elemwise_relu_backward_fn():
-    op = Elemwise(mode="relu").to_c()
+    op = Elemwise(Elemwise.Mode.RELU)
    attr = TensorAttr()
    attr.dtype = "float32"
    attr.comp_node = "xpux"
......
@@ -17,7 +17,7 @@ def elemwise(*args, mode):
    from megengine.core.ops.builtin import Elemwise
    from megengine.core._imperative_rt.imperative import apply_op

-    return apply_op(Elemwise(mode=mode).to_c(), args)
+    return apply_op(Elemwise(mode), args)


def test_basic_interface():
@@ -37,13 +37,15 @@ def test_basic_interface():
def test_opr_attr():
    from megengine.core.ops.builtin import Elemwise

-    assert Elemwise(mode="add") == Elemwise(mode="add")
+    assert Elemwise(Elemwise.Mode.ADD) == Elemwise(Elemwise.Mode.ADD)


def test_simple_arith():
+    from megengine.core.ops.builtin import Elemwise
+
    x = np.random.rand(10).astype("float32")
    xx = megengine.core._imperative_rt.put(x)
-    (yy,) = elemwise(xx, xx, mode="mul")
+    (yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL)
    np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy))
    megengine.core._imperative_rt.delete(xx)
    megengine.core._imperative_rt.delete(yy)
@@ -64,7 +66,7 @@ def test_raw_tensor():
    x = np.random.rand(10).astype("float32")
    xx = as_raw_tensor(x)
-    (yy,) = apply(Elemwise(mode="mul"), xx, xx)
+    (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
    np.testing.assert_allclose(x * x, yy.numpy())
-    (yy,) = apply(Elemwise(mode="mul"), xx, xx)
+    (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
    np.testing.assert_allclose(x * x, yy.numpy())
@@ -17,6 +17,7 @@ import megengine.functional as F
from megengine import cgtools, tensor
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops
+from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.functional import exp, log
@@ -28,7 +29,7 @@ def test_trace():
        @trace(symbolic=symbolic)
        def f(x):
-            op = ops.Elemwise(mode="negate")
+            op = ops.Elemwise(Elemwise.Mode.NEGATE)
            (y,) = apply(op, x)
            return y
@@ -44,7 +45,7 @@ def test_exclude_from_trace():
        @trace(symbolic=symbolic)
        def f(x):
-            neg = ops.Elemwise(mode="negate")
+            neg = ops.Elemwise(Elemwise.Mode.NEGATE)
            (x,) = apply(neg, x)
            with exclude_from_trace():
                if i % 2:
@@ -65,7 +66,7 @@ def test_print_in_trace():
        @trace(symbolic=symbolic)
        def f(x):
            nonlocal buf
-            neg = ops.Elemwise(mode="negate")
+            neg = ops.Elemwise(Elemwise.Mode.NEGATE)
            (x,) = apply(neg, x)
            buf = x.numpy()
            (x,) = apply(neg, x)
@@ -85,7 +86,7 @@ def test_print_in_trace():
def test_dump():
    @trace(symbolic=True, capture_as_const=True)
    def f(a, b):
-        op = ops.Elemwise(mode="add")
+        op = ops.Elemwise(Elemwise.Mode.ADD)
        (y,) = apply(op, a, b)
        return y
@@ -111,7 +112,7 @@ def test_capture_dump():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
-        op = ops.Elemwise(mode="mul")
+        op = ops.Elemwise(Elemwise.Mode.MUL)
        (y,) = apply(op, x, a)
        return y
@@ -133,7 +134,7 @@ def test_dump_volatile():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
-        op = ops.Elemwise(mode="mul")
+        op = ops.Elemwise(Elemwise.Mode.MUL)
        (y,) = apply(op, x, p)
        return y
@@ -159,7 +160,7 @@ def test_trace_profiler():
        @trace(symbolic=symbolic, profiling=True)
        def f(x):
-            op = ops.Elemwise(mode="negate")
+            op = ops.Elemwise(Elemwise.Mode.NEGATE)
            (y,) = apply(op, x)
            return y
......
/**
* \file imperative/src/impl/ops/batch_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/batch_norm.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::BatchNorm>();
auto&& param = node->param();
return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon,
param.avg_factor, param.scale, param.bias);
}
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& bn_opr = def.cast_final_safe<BatchNorm>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 3 || nr_inp == 5,
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
if (nr_inp == 3) {
return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2],
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
.node()->owner_opr();
} else {
return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
.node()->owner_opr();
}
}
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<BatchNorm>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 3 || nr_inp == 5,
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
// need running mean/variance
bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING;
size_t nr_out = need_stat ? 5 : 3;
SmallVector<LogicalTensorDesc> out_shapes(nr_out);
auto&& i0 = inputs[0];
auto&& i1 = inputs[1];
size_t i = 0;
if (!need_stat) {
out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
i = 2;
}
for (; i < nr_out-1; ++ i) {
out_shapes[i] = {i1.layout, i1.comp_node};
}
out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
return out_shapes;
}
OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm);
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/impl/ops/elemwise.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/elemwise.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Elemwise>();
return Elemwise::make(node->param().mode);
}
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
}
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
auto trait = Elemwise::ModeTrait::from_mode(op_def.mode);
mgb_assert(inputs.size() == trait.arity,
"%s expects %u inputs; got %zu actually", trait.name,
trait.arity, inputs.size());
TensorShapeArray inp_shapes;
DType out_dt;
CompNode out_cn;
for (size_t i = 0; i < inputs.size(); ++ i) {
auto &&t = inputs[i];
if (!i) {
out_cn = t.comp_node;
out_dt = t.layout.dtype;
} else {
mgb_assert(t.comp_node == out_cn);
mgb_assert(t.layout.dtype == out_dt);
}
if (t.layout.ndim > 0) {
inp_shapes.push_back(t.layout);
} else {
TensorLayout out_layout;
out_layout.ndim = 0;
out_layout.dtype = out_dt;
return {{out_layout, out_cn}};
}
}
auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}};
}
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise);
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/include/megbrain/imperative/ops/batch_norm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace mgb::imperative {
class BatchNorm : public OpDefImplBase<BatchNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Param = opr::BatchNorm::Param;
Param::ParamDim param_dim;
Param::FwdMode fwd_mode;
double epsilon;
double avg_factor;
float scale;
float bias;
BatchNorm() = default;
BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_,
double epsilon_, double avg_factor_, float scale_, float bias_)
: param_dim(param_dim_),
fwd_mode(fwd_mode_),
epsilon(epsilon_),
avg_factor(avg_factor_),
scale(scale_),
bias(bias_) {}
size_t hash() const override {
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(param_dim);
append(fwd_mode);
append(epsilon);
append(avg_factor);
append(scale);
append(bias);
return xxhash.digest();
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const BatchNorm&>(rhs_);
return rhs.param_dim == param_dim
&& rhs.fwd_mode == fwd_mode
&& rhs.epsilon == epsilon
&& rhs.avg_factor == avg_factor
&& rhs.scale == scale
&& rhs.bias == bias;
}
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/elemwise.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/op_def.h"
namespace mgb::imperative {
class Elemwise : public OpDefImplBase<Elemwise> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = opr::Elemwise::Mode;
using ModeTrait = megdnn::Elemwise::ModeTrait;
Mode mode;
Elemwise() = default;
Elemwise(const Mode& mode_): mode(mode_) {}
size_t hash() const override {
return hash_pair_combine(mgb::hash(mode), reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const Elemwise&>(rhs_);
return rhs.mode == mode;
}
};
} // namespace mgb::imperative