Commit df3474ca authored by Megvii Engine Team

perf(functional): rewrite several elemwise ops with jit subgraph

GitOrigin-RevId: 26247e21d9300ffa368c0eef4c76d6c502b684e5
Parent c55fda9a
......@@ -36,6 +36,7 @@ from ..core.tensor.utils import (
convert_single_value,
make_shape_tuple,
subgraph,
subgraph_fn,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
......@@ -824,9 +825,37 @@ def sigmoid(x):
return _elwise(x, mode=Elemwise.Mode.SIGMOID)
@lru_cache(maxsize=None)
def _get_hsigmoid_op(dtype=None, device=None):
@subgraph_fn(
"Hsigmoid",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def hsigmoid(inputs, f, c):
(inp,) = inputs[0:1]
inp = f("+", inp, c(3))
max_0 = f("max", inp, c(0))
min_6 = f("min", max_0, c(6))
oup = f("/", min_6, c(6))
(oup_grad,) = yield (oup,)
inp_grad = f("/", oup_grad, c(6))
inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad)
inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
yield (inp_grad,)
return hsigmoid
def hsigmoid(x):
r"""Element-wise `relu6(x + 3) / 6`."""
return relu6(x + 3) / 6
hsigmoid = _get_hsigmoid_op(x.dtype, x.device)
(x,) = hsigmoid(x)
return x
# return relu6(x + 3) / 6
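The `subgraph_fn`-decorated generators in this commit follow a two-phase protocol: the body builds the forward expression with the `f`/`c` graph builders, `yield`s the outputs, receives the output gradients back, and then `yield`s the input gradients. As a rough illustration of that control flow only (plain Python/NumPy stand-ins, not MegEngine's actual machinery):

```python
import numpy as np

def hsigmoid_subgraph(x):
    """NumPy stand-in for the generator protocol used above:
    first yield -> forward outputs, second yield -> input gradients."""
    shifted = x + 3.0
    max_0 = np.maximum(shifted, 0.0)
    min_6 = np.minimum(max_0, 6.0)
    oup = min_6 / 6.0
    (oup_grad,) = yield (oup,)                          # driver sends d(loss)/d(oup)
    inp_grad = oup_grad / 6.0
    inp_grad = np.where(max_0 <= 6.0, inp_grad, 0.0)    # cond_leq_mov(max_0, 6, .)
    inp_grad = np.where(0.0 <= shifted, inp_grad, 0.0)  # cond_leq_mov(0, inp, .)
    yield (inp_grad,)

x = np.array([-4.0, -1.0, 0.0, 2.0, 5.0], dtype="float32")
gen = hsigmoid_subgraph(x)
(y,) = next(gen)                       # forward pass
(dx,) = gen.send((np.ones_like(x),))   # backward pass with unit output grad
print(y)   # == np.minimum(np.maximum(x + 3, 0), 6) / 6
print(dx)  # 1/6 inside (-3, 3), 0 outside
```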
def relu(x):
......@@ -834,9 +863,60 @@ def relu(x):
return _elwise(x, mode=Elemwise.Mode.RELU)
@lru_cache(maxsize=None)
def _get_relu6_op(dtype=None, device=None):
@subgraph_fn(
"ReLU6",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def relu6(inputs, f, c):
(inp,) = inputs[0:1]
max_0 = f("max", inp, c(0))
min_6 = f("min", max_0, c(6))
oup = min_6
(oup_grad,) = yield (oup,)
inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad)
inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
yield (inp_grad,)
return relu6
def relu6(x):
r"""Element-wise `min(max(x, 0), 6)`."""
return minimum(maximum(x, 0), 6)
relu6 = _get_relu6_op(x.dtype, x.device)
(x,) = relu6(x)
return x
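The backward rule encoded here passes the incoming gradient only where `0 <= x` and `max(x, 0) <= 6`. A quick finite-difference sanity check of that masking rule, evaluated away from the kinks at 0 and 6 (illustrative NumPy only, not the library implementation):

```python
import numpy as np

def relu6_fwd(x):
    return np.minimum(np.maximum(x, 0.0), 6.0)

def relu6_bwd(x, gout):
    # cond_leq_mov-style masking: keep gout where max(x, 0) <= 6, then where 0 <= x
    g = np.where(np.maximum(x, 0.0) <= 6.0, gout, 0.0)
    return np.where(0.0 <= x, g, 0.0)

x = np.array([-2.0, 1.0, 3.0, 7.0])
eps = 1e-4
numerical = (relu6_fwd(x + eps) - relu6_fwd(x - eps)) / (2 * eps)
analytic = relu6_bwd(x, np.ones_like(x))
np.testing.assert_allclose(numerical, analytic, atol=1e-6)
```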
@lru_cache(maxsize=None)
def _get_prelu_op(dtype=None, device=None):
@subgraph_fn(
"PReLU",
dtype=dtype,
device=device,
nr_inputs=2,
jit_fusion=True,
custom_grad=True,
)
def prelu(inputs, f, c):
(inp, weight) = inputs[0:2]
max_0 = f("max", inp, c(0))
min_0 = f("min", inp, c(0))
oup = f("fma3", min_0, weight, max_0)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad)
inp_grad_1 = f("*", oup_grad, weight)
inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
weight_grad = f("*", oup_grad, min_0)
yield (inp_grad, weight_grad)
return prelu
def prelu(inp: Tensor, weight: Tensor) -> Tensor:
......@@ -844,7 +924,34 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
Refer to :class:`~.PReLU` for more information.
"""
return maximum(inp, 0) + weight * minimum(inp, 0)
prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
(oup,) = prelu(inp, weight)
return oup
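For PReLU, `d out/d inp` is `1` where `inp > 0` and `weight` where `inp < 0`, and `d out/d weight` is `min(inp, 0)`; the two `cond_leq_mov` masks above select exactly those two regions. A small NumPy sketch of the same rule (illustrative only):

```python
import numpy as np

def prelu_bwd(x, weight, gout):
    gx_pos = np.where(0.0 <= x, gout, 0.0)           # grad of max(x, 0) branch
    gx_neg = np.where(x <= 0.0, gout * weight, 0.0)  # grad of weight * min(x, 0) branch
    gweight = gout * np.minimum(x, 0.0)
    return gx_pos + gx_neg, gweight

x = np.array([-2.0, 0.5, 3.0])
gx, gw = prelu_bwd(x, 0.25, np.ones_like(x))
print(gx)  # [0.25 1.   1.  ]
print(gw)  # [-2.   0.   0. ]
```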
@lru_cache(maxsize=None)
def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None):
@subgraph_fn(
"LeakyReLU",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def leakyReLU(inputs, f, c):
(inp,) = inputs[0:1]
max_0 = f("max", inp, c(0))
min_0 = f("min", inp, c(0))
oup = f("+", max_0, f("*", min_0, c(negative_slope)))
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, c(negative_slope))
inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
return leakyReLU
def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
......@@ -852,7 +959,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
Refer to :class:`~.LeakyReLU` for more information.
"""
return maximum(inp, 0) + negative_slope * minimum(inp, 0)
leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
(oup,) = leakyReLU(inp)
return oup
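Because `negative_slope` is baked into the subgraph as a constant via `c(negative_slope)`, the factory is memoized per `(negative_slope, dtype, device)` combination by `lru_cache`; each distinct slope builds its own fused op. A minimal sketch of that caching behaviour, using a hypothetical `make_op` stand-in rather than the MegEngine factory:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def make_op(negative_slope, dtype=None, device=None):
    # stand-in for building a fused LeakyReLU subgraph for this exact configuration
    return object()

a = make_op(0.01, dtype="float32", device="xpux")
b = make_op(0.01, dtype="float32", device="xpux")
c = make_op(0.20, dtype="float32", device="xpux")
assert a is b        # same slope/dtype/device -> cached op reused
assert a is not c    # a different slope builds a separate subgraph
```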
def silu(x):
......@@ -871,6 +980,36 @@ def gelu(x):
return _elwise(x, mode=Elemwise.Mode.GELU)
@lru_cache(maxsize=None)
def _get_softplus_op(dtype=None, device=None):
@subgraph_fn(
"Softplus",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
# gopt_level=0,
custom_grad=True,
)
def softplus(inputs, f, c):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup = f("log1p", exp)
oup = f("+", oup, f("relu", inp))
(oup_grad,) = yield (oup,)
inp_grad_0 = f("switch_gt0", inp, oup_grad)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", oup_grad, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
return softplus
def softplus(inp: Tensor) -> Tensor:
r"""Applies the element-wise function:
......@@ -904,7 +1043,9 @@ def softplus(inp: Tensor) -> Tensor:
[0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
"""
return log1p(exp(-abs(inp))) + relu(inp)
softplus = _get_softplus_op(inp.dtype, inp.device)
(oup,) = softplus(inp)
return oup
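The forward uses the numerically stable split `softplus(x) = log1p(exp(-|x|)) + relu(x)`, and the backward chains the gradient through `log1p`/`exp`/`abs_grad` plus a `switch_gt0` term for `relu(x)`; together these reduce to `sigmoid(x)`. A NumPy sanity check of that identity away from `x = 0` (illustrative only):

```python
import numpy as np

x = np.array([-30.0, -5.0, -0.5, 0.5, 5.0, 30.0])
exp_neg_abs = np.exp(-np.abs(x))

# backward as composed above: relu branch + abs/log1p branch
grad_relu = np.where(x > 0, 1.0, 0.0)                       # switch_gt0(inp, gout)
grad_log1p = -np.sign(x) * exp_neg_abs / (exp_neg_abs + 1)  # abs_grad(inp, -((gout/(exp+1))*exp))
grad = grad_relu + grad_log1p

sigmoid = 1.0 / (1.0 + np.exp(-x))
np.testing.assert_allclose(grad, sigmoid, atol=1e-7)
```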
def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
......@@ -944,6 +1085,38 @@ def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return inp - logsumexp(inp, axis, keepdims=True)
@lru_cache(maxsize=None)
def _get_logsigmoid_op(dtype=None, device=None):
@subgraph_fn(
"LogSigmoid",
dtype=dtype,
device=device,
nr_inputs=1,
jit_fusion=True,
custom_grad=True,
)
def logsigmoid(inputs, f, c):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup = f("log1p", exp)
oup = f("+", oup, f("relu", f("-", inp)))
oup = f("-", oup)
(oup_grad,) = yield (oup,)
oup_grad = f("-", oup_grad)
inp_grad_0 = f("switch_gt0", inp, oup_grad)
inp_grad_0 = f("-", inp_grad_0)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", oup_grad, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
return logsigmoid
def logsigmoid(inp: Tensor) -> Tensor:
r"""Applies the element-wise function:
......@@ -972,7 +1145,9 @@ def logsigmoid(inp: Tensor) -> Tensor:
[-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
-0.0181]
"""
return -softplus(-inp)
logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device)
(oup,) = logsigmoid(inp)
return oup
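Writing `logsigmoid(x)` as `-(log1p(exp(-|x|)) + relu(-x))` rather than `log(sigmoid(x))` avoids overflow for large `|x|`. A small NumPy comparison of the two forms (illustrative only, not the fused op):

```python
import numpy as np

x = np.array([-1000.0, -1.0, 1.0])

with np.errstate(over="ignore", divide="ignore"):
    naive = np.log(1.0 / (1.0 + np.exp(-x)))  # exp(1000) overflows -> -inf
stable = -(np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.0))

print(naive)   # [  -inf   -1.3133 -0.3133]
print(stable)  # [-1000.   -1.3133 -0.3133]
```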
def logsumexp(
......
......@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import lru_cache
from typing import Iterable, Optional, Sequence, Tuple, Union
import numpy as np
......@@ -17,7 +18,14 @@ from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import astensor1d, convert_inputs, get_device
from ..core.tensor.utils import (
astensor1d,
convert_inputs,
get_device,
isscalar,
setscalar,
subgraph_fn,
)
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
......@@ -731,6 +739,29 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
return inp
@lru_cache(maxsize=None)
def _get_where_op(dtype=None, device=None):
@subgraph_fn(
"Where",
dtype=dtype,
device=device,
nr_inputs=3,
jit_fusion=True,
custom_grad=True,
)
def where(inputs, f, c):
(mask, x, y) = inputs[0:3]
oup = f("switch_gt0", mask, x)
ksam = f("-", c(1), mask)
oup = f("+", oup, f("switch_gt0", ksam, y))
(oup_grad,) = yield (oup,)
x_grad = f("switch_gt0", mask, oup_grad)
y_grad = f("switch_gt0", ksam, oup_grad)
yield (None, x_grad, y_grad)
return where
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
r"""Selects elements either from Tensor x or Tensor y, according to mask.
......@@ -780,20 +811,19 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
dtype = dtype_promotion(x, y)
device = x.device
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
y = y.astype(dtype)
mask = mask.astype(dtype)
v0, index0 = cond_take(mask, x)
v1, index1 = cond_take(~mask, y)
out = concat([v0, v1])
out[index0] = v0
out[index1] = v1
out = out.reshape(x.shape)
return out
where = _get_where_op(dtype=dtype, device=device)
(oup,) = where(mask, x, y)
if isscalar(mask):
setscalar(oup)
return oup
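The rewritten `where` builds `out = switch_gt0(mask, x) + switch_gt0(1 - mask, y)` and routes the output gradient through the same two masks, so neither `cond_take` nor scatter-style indexing is needed. A NumPy sketch of that forward/backward rule (illustrative only):

```python
import numpy as np

def where_fwd_bwd(mask, x, y, gout):
    m = mask.astype(x.dtype)
    out = np.where(m > 0, x, 0) + np.where((1 - m) > 0, y, 0)  # switch_gt0 pair
    gx = np.where(m > 0, gout, 0)                              # gradient routed to x
    gy = np.where((1 - m) > 0, gout, 0)                        # gradient routed to y
    return out, gx, gy

mask = np.array([[True, False], [False, True]])
x = np.array([[1.0, 2.0], [3.0, 4.0]])
y = np.array([[5.0, 6.0], [7.0, 8.0]])
out, gx, gy = where_fwd_bwd(mask, x, y, np.ones_like(x))
assert (out == np.where(mask, x, y)).all()
```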
def cond_take(mask: Tensor, x: Tensor) -> Tensor:
......
......@@ -166,7 +166,7 @@ def test_hsigmoid():
x = np.random.randn(100).astype("float32")
y_np = np.minimum(np.maximum(x + 3, 0), 6) / 6
y_mge = F.hsigmoid(tensor(x)).numpy()
np.testing.assert_equal(y_np, y_mge)
np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
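Relaxing `assert_equal` to `assert_almost_equal(..., decimal=6)` reflects that the fused subgraph may round float32 intermediates differently from the NumPy reference, so bitwise equality is too strict. A tiny illustration of the kind of reassociation involved (hypothetical example, not the actual kernel):

```python
import numpy as np

x = np.float32(0.1)
a = (x + np.float32(3)) / np.float32(6)          # one rounding order
b = x / np.float32(6) + np.float32(0.5)          # algebraically equal, rounded differently
np.testing.assert_almost_equal(a, b, decimal=6)  # passes
# np.testing.assert_equal(a, b) could fail on such reassociated float32 arithmetic
```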
def test_logical_oprs():
......
......@@ -27,6 +27,8 @@ from megengine.core.tensor.utils import make_shape_tuple
from megengine.device import get_device_count
from megengine.module import LayerNorm
_assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
def test_where():
maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
......@@ -627,6 +629,7 @@ def test_binary_cross_entropy():
{"input": [data1, label1], "output": expect1,},
{"input": [data2, label2], "output": expect2,},
]
opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)
cases = [
......