From df3474ca1d3e29b0ffa7787fd3ec5b80532b10e4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 26 Sep 2021 19:56:45 +0800
Subject: [PATCH] perf(functional): rewrite several elemwise ops with jit subgraph

GitOrigin-RevId: 26247e21d9300ffa368c0eef4c76d6c502b684e5
---
 imperative/python/megengine/functional/nn.py  | 187 +++++++++++++++++-
 .../python/megengine/functional/tensor.py     |  50 ++++-
 .../test/unit/functional/test_elemwise.py     |   2 +-
 .../test/unit/functional/test_functional.py   |   3 +
 4 files changed, 225 insertions(+), 17 deletions(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index f498460f3..42b108eb7 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -36,6 +36,7 @@ from ..core.tensor.utils import (
     convert_single_value,
     make_shape_tuple,
     subgraph,
+    subgraph_fn,
 )
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
@@ -824,9 +825,37 @@ def sigmoid(x):
     return _elwise(x, mode=Elemwise.Mode.SIGMOID)
 
 
+@lru_cache(maxsize=None)
+def _get_hsigmoid_op(dtype=None, device=None):
+    @subgraph_fn(
+        "Hsigmoid",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def hsigmoid(inputs, f, c):
+        (inp,) = inputs[0:1]
+        inp = f("+", inp, c(3))
+        max_0 = f("max", inp, c(0))
+        min_6 = f("min", max_0, c(6))
+        oup = f("/", min_6, c(6))
+        (oup_grad,) = yield (oup,)
+        inp_grad = f("/", oup_grad, c(6))
+        inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad)
+        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
+        yield (inp_grad,)
+
+    return hsigmoid
+
+
 def hsigmoid(x):
     r"""Element-wise `relu6(x + 3) / 6`."""
-    return relu6(x + 3) / 6
+    hsigmoid = _get_hsigmoid_op(x.dtype, x.device)
+    (x,) = hsigmoid(x)
+    return x
+    # return relu6(x + 3) / 6
 
 
 def relu(x):
@@ -834,9 +863,60 @@ def relu(x):
     return _elwise(x, mode=Elemwise.Mode.RELU)
 
 
+@lru_cache(maxsize=None)
+def _get_relu6_op(dtype=None, device=None):
+    @subgraph_fn(
+        "ReLU6",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def relu6(inputs, f, c):
+        (inp,) = inputs[0:1]
+        max_0 = f("max", inp, c(0))
+        min_6 = f("min", max_0, c(6))
+        oup = min_6
+        (oup_grad,) = yield (oup,)
+        inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad)
+        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
+        yield (inp_grad,)
+
+    return relu6
+
+
 def relu6(x):
     r"""Element-wise `min(max(x, 0), 6)`."""
-    return minimum(maximum(x, 0), 6)
+    relu6 = _get_relu6_op(x.dtype, x.device)
+    (x,) = relu6(x)
+    return x
+
+
+@lru_cache(maxsize=None)
+def _get_prelu_op(dtype=None, device=None):
+    @subgraph_fn(
+        "PReLU",
+        dtype=dtype,
+        device=device,
+        nr_inputs=2,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def prelu(inputs, f, c):
+        (inp, weight) = inputs[0:2]
+        max_0 = f("max", inp, c(0))
+        min_0 = f("min", inp, c(0))
+        oup = f("fma3", min_0, weight, max_0)
+        (oup_grad,) = yield (oup,)
+        inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad)
+        inp_grad_1 = f("*", oup_grad, weight)
+        inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1)
+        inp_grad = f("+", inp_grad_0, inp_grad_1)
+        weight_grad = f("*", oup_grad, min_0)
+        yield (inp_grad, weight_grad)
+
+    return prelu
 
 
 def prelu(inp: Tensor, weight: Tensor) -> Tensor:
@@ -844,7 +924,34 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
 
     Refer to :class:`~.PReLU` for more information.
""" - return maximum(inp, 0) + weight * minimum(inp, 0) + prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device) + (oup,) = prelu(inp, weight) + return oup + + +@lru_cache(maxsize=None) +def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None): + @subgraph_fn( + "LeakyReLU", + dtype=dtype, + device=device, + nr_inputs=1, + jit_fusion=True, + custom_grad=True, + ) + def leakyReLU(inputs, f, c): + (inp,) = inputs[0:1] + max_0 = f("max", inp, c(0)) + min_0 = f("min", inp, c(0)) + oup = f("+", max_0, f("*", min_0, c(negative_slope))) + (oup_grad,) = yield (oup,) + inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) + inp_grad_1 = f("*", oup_grad, c(negative_slope)) + inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1) + inp_grad = f("+", inp_grad_0, inp_grad_1) + yield (inp_grad,) + + return leakyReLU def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: @@ -852,7 +959,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: Refer to :class:`~.LeakyReLU` for more information. """ - return maximum(inp, 0) + negative_slope * minimum(inp, 0) + leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device) + (oup,) = leakyReLU(inp) + return oup def silu(x): @@ -871,6 +980,36 @@ def gelu(x): return _elwise(x, mode=Elemwise.Mode.GELU) +@lru_cache(maxsize=None) +def _get_softplus_op(dtype=None, device=None): + @subgraph_fn( + "Softplus", + dtype=dtype, + device=device, + nr_inputs=1, + jit_fusion=True, + # gopt_level=0, + custom_grad=True, + ) + def softplus(inputs, f, c): + (inp,) = inputs[0:1] + neg_abs = f("-", f("abs", inp)) + exp = f("exp", neg_abs) + oup = f("log1p", exp) + oup = f("+", oup, f("relu", inp)) + (oup_grad,) = yield (oup,) + inp_grad_0 = f("switch_gt0", inp, oup_grad) + inp_grad_1 = oup_grad + inp_grad_1 = f("/", oup_grad, f("+", exp, c(1))) + inp_grad_1 = f("*", oup_grad, exp) + inp_grad_1 = f("-", inp_grad_1) + inp_grad_1 = f("abs_grad", inp, inp_grad_1) + inp_grad = f("+", inp_grad_0, inp_grad_1) + yield (inp_grad,) + + return softplus + + def softplus(inp: Tensor) -> Tensor: r"""Applies the element-wise function: @@ -904,7 +1043,9 @@ def softplus(inp: Tensor) -> Tensor: [0.0486 0.1269 0.3133 0.6931 1.3133 2.1269] """ - return log1p(exp(-abs(inp))) + relu(inp) + softplus = _get_softplus_op(inp.dtype, inp.device) + (oup,) = softplus(inp) + return oup def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: @@ -944,6 +1085,38 @@ def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: return inp - logsumexp(inp, axis, keepdims=True) +@lru_cache(maxsize=None) +def _get_logsigmoid_op(dtype=None, device=None): + @subgraph_fn( + "LogSigmoid", + dtype=dtype, + device=device, + nr_inputs=1, + jit_fusion=True, + custom_grad=True, + ) + def logsigmoid(inputs, f, c): + (inp,) = inputs[0:1] + neg_abs = f("-", f("abs", inp)) + exp = f("exp", neg_abs) + oup = f("log1p", exp) + oup = f("+", oup, f("relu", f("-", inp))) + oup = f("-", oup) + (oup_grad,) = yield (oup,) + oup_grad = f("-", oup_grad) + inp_grad_0 = f("switch_gt0", inp, oup_grad) + inp_grad_0 = f("-", inp_grad_0) + inp_grad_1 = oup_grad + inp_grad_1 = f("/", oup_grad, f("+", exp, c(1))) + inp_grad_1 = f("*", oup_grad, exp) + inp_grad_1 = f("-", inp_grad_1) + inp_grad_1 = f("abs_grad", inp, inp_grad_1) + inp_grad = f("+", inp_grad_0, inp_grad_1) + yield (inp_grad,) + + return logsigmoid + + def logsigmoid(inp: Tensor) -> Tensor: r"""Applies the element-wise function: @@ -972,7 +1145,9 @@ def logsigmoid(inp: Tensor) -> 
         [-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486 -0.0181]
 
     """
-    return -softplus(-inp)
+    logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device)
+    (oup,) = logsigmoid(inp)
+    return oup
 
 
 def logsumexp(
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index adc781b03..529ba499e 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from functools import lru_cache
 from typing import Iterable, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -17,7 +18,14 @@ from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
 from ..core.tensor.array_method import _broadcast, _remove_axis
-from ..core.tensor.utils import astensor1d, convert_inputs, get_device
+from ..core.tensor.utils import (
+    astensor1d,
+    convert_inputs,
+    get_device,
+    isscalar,
+    setscalar,
+    subgraph_fn,
+)
 from ..device import get_default_device
 from ..tensor import Tensor
 from .elemwise import ceil
@@ -731,6 +739,29 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
     return inp
 
 
+@lru_cache(maxsize=None)
+def _get_where_op(dtype=None, device=None):
+    @subgraph_fn(
+        "Where",
+        dtype=dtype,
+        device=device,
+        nr_inputs=3,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def where(inputs, f, c):
+        (mask, x, y) = inputs[0:3]
+        oup = f("switch_gt0", mask, x)
+        ksam = f("-", c(1), mask)
+        oup = f("+", oup, f("switch_gt0", ksam, y))
+        (oup_grad,) = yield (oup,)
+        x_grad = f("switch_gt0", mask, oup_grad)
+        y_grad = f("switch_gt0", ksam, oup_grad)
+        yield (None, x_grad, y_grad)
+
+    return where
+
+
 def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
     r"""Selects elements either from Tensor x or Tensor y, according to mask.
 
@@ -780,20 +811,19 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
         raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
 
     dtype = dtype_promotion(x, y)
+    device = x.device
+
     if x.dtype != dtype:
         x = x.astype(dtype)
     if y.dtype != dtype:
         y = y.astype(dtype)
+    mask = mask.astype(dtype)
 
-    v0, index0 = cond_take(mask, x)
-    v1, index1 = cond_take(~mask, y)
-
-    out = concat([v0, v1])
-
-    out[index0] = v0
-    out[index1] = v1
-    out = out.reshape(x.shape)
-    return out
+    where = _get_where_op(dtype=dtype, device=device)
+    (oup,) = where(mask, x, y)
+    if isscalar(mask):
+        setscalar(oup)
+    return oup
 
 
 def cond_take(mask: Tensor, x: Tensor) -> Tensor:
diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py
index f4c1788b5..d2fbaf08a 100644
--- a/imperative/python/test/unit/functional/test_elemwise.py
+++ b/imperative/python/test/unit/functional/test_elemwise.py
@@ -166,7 +166,7 @@ def test_hsigmoid():
     x = np.random.randn(100).astype("float32")
     y_np = np.minimum(np.maximum(x + 3, 0), 6) / 6
     y_mge = F.hsigmoid(tensor(x)).numpy()
-    np.testing.assert_equal(y_np, y_mge)
+    np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
 
 
 def test_logical_oprs():
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 2a96c1d09..ec1ff2fab 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -27,6 +27,8 @@ from megengine.core.tensor.utils import make_shape_tuple
 from megengine.device import get_device_count
 from megengine.module import LayerNorm
 
+_assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
+
 
 def test_where():
     maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
@@ -627,6 +629,7 @@ def test_binary_cross_entropy():
         {"input": [data1, label1], "output": expect1,},
         {"input": [data2, label2], "output": expect2,},
     ]
+
     opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)
 
     cases = [
--
GitLab
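
The rewritten ops above replace Python-level compositions such as relu6(x + 3) / 6 with single fused subgraphs that also carry hand-written backward rules; this is why test_hsigmoid now compares with assert_almost_equal rather than exact equality, since the fused kernel may differ from the NumPy reference in the last few bits. For readers who want to sanity-check the hand-derived Hsigmoid gradient outside MegEngine, the short NumPy sketch below mirrors the same forward/backward math. It assumes cond_leq_mov(a, b, g) passes g through where a <= b and blocks it elsewhere (the reading implied by how the subgraph uses it); the sketch is illustrative only and uses no MegEngine API.

    import numpy as np


    def hsigmoid_fwd(x):
        # Same chain as the Hsigmoid subgraph: (x + 3) -> clamp to [0, 6] -> divide by 6.
        return np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0


    def hsigmoid_bwd(x, grad_out):
        # The subgraph's backward rule: grad_out / 6, kept only where both
        # cond_leq_mov masks fire, i.e. where 0 <= x + 3 and max(x + 3, 0) <= 6.
        shifted = x + 3.0
        mask = (shifted >= 0.0) & (np.maximum(shifted, 0.0) <= 6.0)
        return grad_out / 6.0 * mask


    rng = np.random.default_rng(0)
    x = rng.standard_normal(100) * 4.0  # cover both the clamped and the linear regions
    g = np.ones_like(x)

    # Forward agrees with the expression the old Python-level implementation used.
    np.testing.assert_allclose(hsigmoid_fwd(x), np.minimum(np.maximum(x + 3, 0), 6) / 6)

    # Backward agrees with a central finite difference away from the kinks at x = -3 and x = 3.
    eps = 1e-4
    fd = (hsigmoid_fwd(x + eps) - hsigmoid_fwd(x - eps)) / (2.0 * eps)
    interior = (np.abs(x + 3.0) > 1e-2) & (np.abs(x - 3.0) > 1e-2)
    np.testing.assert_allclose(hsigmoid_bwd(x, g)[interior], fd[interior], atol=1e-9)

The same clamp-and-mask pattern underlies the ReLU6, PReLU and LeakyReLU subgraphs, so the check extends to those ops by swapping in the corresponding forward expression and masks.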