提交 97b1b777 编写于 作者: M Megvii Engine Team

feat(mgb): add megbrain layer norm opr with subgraph

GitOrigin-RevId: 9b7fa821f8e456329723aef3cf0b2a8080de669b
上级 eca6e1d9
......@@ -19,6 +19,7 @@ from ..core.ops.builtin import (
GetVarShape,
Identity,
Reduce,
Reshape,
TypeCvt,
)
from ..core.ops.special import Const
......@@ -1022,6 +1023,92 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return cached / down
@lru_cache(maxsize=None)
def _get_layerNorm(device, dtype, dim, gopt_level=2):
@subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
def layerNormAffine(inputs, f, c):
inp, eps, _flatten_shape, weight, bias = inputs
inp_shape = f(GetVarShape(), inp)
inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
affine_oup = f(Reshape(), oup, inp_shape)
affine_oup = f("fma3", affine_oup, weight, bias)
# NOTE: return oup make backward faster but take more memory
return (affine_oup, oup, mean, x2s), (True, False, False, False)
@subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level)
def layerNorm(inputs, f, c):
inp, eps, _flatten_shape = inputs
inp_shape = f(GetVarShape(), inp)
inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
oup = f(Reshape(), oup, inp_shape)
return (oup,), (True,)
return (layerNorm, layerNormAffine)
def layer_norm(
inp: Tensor,
normalized_shape: tuple,
affine: bool,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
eps_mode="additive",
):
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
eps_mode
)
_device = inp.device
_dtype = inp.dtype
_dim = len(inp.shape) - len(normalized_shape)
_flatten_shape = concat(
(
convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
convert_single_value(-1, dtype="int32", device=inp.device),
)
)
(layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)
eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
if affine:
outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
else:
outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)
return outvar
def batch_norm(
inp: Tensor,
running_mean: Tensor = None,
......
......@@ -132,18 +132,9 @@ class LayerNorm(Module):
zeros_(self.bias)
def forward(self, x):
x_shape = x.shape
dim_delta = len(x_shape) - len(self.normalized_shape)
non_flatten_shape = x_shape[:dim_delta]
x = x.reshape(*non_flatten_shape, -1)
mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + self.eps)
x = x.reshape(x_shape)
if self.affine:
x = self.weight * x + self.bias
x = F.nn.layer_norm(
x, self.normalized_shape, self.affine, self.weight, self.bias, self.eps
)
return x
def _module_info_string(self) -> str:
......
......@@ -24,6 +24,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.device import get_device_count
from megengine.module import LayerNorm
def test_where():
......@@ -862,6 +863,61 @@ def test_conv1d():
)
def test_layer_norm():
def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5):
__layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine)
__layer_norm.weight = weight
__layer_norm.bias = bias
return __layer_norm(x)
def _layer_norm_numpy(
x, normalized_shape, affine, weight=None, bias=None, eps=1e-5
):
x_shape = x.shape
dim_delta = len(x_shape) - len(normalized_shape)
non_flatten_shape = x_shape[:dim_delta]
x = x.reshape(*non_flatten_shape, -1)
mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + eps)
x = x.reshape(x_shape)
if affine:
x = weight * x + bias
return x
normalized_shape = (28, 28)
inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
weight = Tensor(np.random.randn(28, 28), dtype="float32")
bias = Tensor(np.random.randn(28, 28), dtype="float32")
inp_feat = inp_feat + 1
weight = weight + 1
bias = bias
affine = False
outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias)
assert abs(outvar - targetvar).mean() < 1e-7
# no random, affine True
normalized_shape = (28, 28)
inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32")
weight = Tensor(np.ones((28, 28)), dtype="float32")
bias = Tensor(np.zeros((28, 28)), dtype="float32")
affine = True
outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias)
assert abs((outvar - targetvar).mean()) < 1e-7
assert abs(outvar.mean()) < 1e-7
def test_batchnorm2d_io16c32():
amp.enabled = True
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册