提交 5507a29b 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/functional): add normalize opr

GitOrigin-RevId: 572a32f5633bfd775e5cc2209e4c9005f49e3761
上级 000663c3
......@@ -48,7 +48,7 @@ from .loss import (
square_loss,
triplet_margin_loss,
)
from .math import argmax, argmin, max, mean, min, norm, prod, sqrt, sum
from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum
from .nn import (
assert_equal,
avg_pool2d,
......
......@@ -11,6 +11,7 @@ from typing import Optional
import megengine._internal as mgb
from ..core import Tensor, wrap_io_tensor
from .elemwise import clamp
@wrap_io_tensor
......@@ -199,8 +200,7 @@ def sqrt(inp: Tensor) -> Tensor:
return mgb.opr.sqrt(inp)
@wrap_io_tensor
def norm(inp: Tensor, p=2, axis: Optional[int] = None, keepdims=False):
def norm(inp: Tensor, p: int = 2, axis: Optional[int] = None, keepdims=False):
"""Calculate ``p``-norm of input tensor along certain axis.
:param inp: The input tensor
......@@ -271,3 +271,28 @@ def argmax(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> T
"""
return mgb.opr.argmax(inp, axis, keepdims)
def normalize(
    inp: Tensor, p: int = 2, axis: Optional[int] = None, eps: float = 1e-12
) -> Tensor:
    r"""Perform :math:`L_p` normalization of input tensor along certain axis.

    For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is
    transformed as:

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: the input tensor
    :param p: power of value ``p`` applied to ``inp``. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: the normalized output tensor
    """
    # Full-tensor norm when no axis is given; otherwise keep the reduced
    # axis (keepdims=True) so the division broadcasts back over ``inp``.
    if axis is None:
        denominator = norm(inp, p)
    else:
        denominator = norm(inp, p, axis, keepdims=True)
    # Clamp the denominator from below at ``eps`` to avoid division by zero
    # for all-zero vectors.
    return inp / clamp(denominator, lower=eps)
......@@ -83,3 +83,32 @@ def test_sqrt():
cases = [{"input": d1}, {"input": d2}]
opr_test(cases, F.sqrt, ref_fn=np.sqrt)
def test_normalize():
    """Check F.normalize against a NumPy reference for several norms/axes,
    including vectors whose norm is exactly zero (exercising the eps clamp)."""
    from functools import partial

    cases = [
        {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)}
        for _ in range(2)
    ]

    def np_normalize(x, p=2, axis=None, eps=1e-12):
        # NumPy reference implementation of L_p normalization.
        # Use |x|**p so this is a true L_p norm even for odd p on negative
        # inputs; np.random.random yields non-negative values, so current
        # expected results are unchanged.
        if axis is None:
            norm = np.sum(np.abs(x) ** p) ** (1.0 / p)
        else:
            norm = np.sum(np.abs(x) ** p, axis=axis, keepdims=True) ** (1.0 / p)
        # Clamp from below at eps, mirroring F.normalize's division guard.
        return x / np.clip(norm, a_min=eps, a_max=np.inf)

    # Test L-2 norm along all dimensions
    opr_test(cases, F.normalize, ref_fn=np_normalize)

    # Test L-1 norm along all dimensions
    opr_test(cases, partial(F.normalize, p=1), ref_fn=partial(np_normalize, p=1))

    # Test L-2 norm along the second dimension
    opr_test(cases, partial(F.normalize, axis=1), ref_fn=partial(np_normalize, axis=1))

    # Test some norm == 0: zero out one vector along the last axis so the
    # eps clamp (rather than a divide-by-zero) determines the output.
    cases[0]["input"][0, 0, 0, :] = 0
    cases[1]["input"][0, 0, 0, :] = 0
    opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册