提交 98d6ba7e 编写于 作者: M Megvii Engine Team

refactor(mge/functional): move matinv from nn to math

GitOrigin-RevId: 75c48ce32729ea82fbc86e2f7dc0427709c2e978
上级 17323dbd
......@@ -29,6 +29,7 @@ __all__ = [
"dot",
"isinf",
"isnan",
"matinv",
"matmul",
"max",
"mean",
......@@ -729,6 +730,38 @@ def topk(
return tns, ind
def matinv(inp: Tensor) -> Tensor:
"""
Computes the inverse of a batch of matrices; input must has shape [..., n, n].
:param inp: input tensor.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data = tensor([[1.0, 0.0], [1.0, 1.0]])
out = F.matinv(data)
print(out.numpy())
Outputs:
.. testoutput::
[[ 1. 0.]
[-1. 1.]]
"""
(result,) = apply(builtin.MatrixInverse(), inp)
return result
def matmul(
inp1: Tensor,
inp2: Tensor,
......
......@@ -53,7 +53,6 @@ __all__ = [
"logsigmoid",
"logsumexp",
"logsoftmax",
"matinv",
"max_pool2d",
"one_hot",
"prelu",
......@@ -1183,38 +1182,6 @@ def remap(
return result
def matinv(inp: Tensor) -> Tensor:
"""
Computes the inverse of a batch of matrices; input must has shape [..., n, n].
:param inp: input tensor.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data = tensor([[1.0, 0.0], [1.0, 1.0]])
out = F.matinv(data)
print(out.numpy())
Outputs:
.. testoutput::
[[ 1. 0.]
[-1. 1.]]
"""
(result,) = apply(builtin.MatrixInverse(), inp)
return result
def interpolate(
inp: Tensor,
size: Optional[Union[int, Tuple[int, int]]] = None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册