提交 df976782 编写于 作者: M Megvii Engine Team

feat(mge/functional): add matinv

GitOrigin-RevId: d4fa8a82778abec33d2791f0a867d12482fad51a
上级 3bda3347
......@@ -52,6 +52,7 @@ __all__ = [
"logsigmoid",
"logsumexp",
"logsoftmax",
"matinv",
"matmul",
"max_pool2d",
"one_hot",
......@@ -1002,6 +1003,38 @@ def remap(
return result
def matinv(inp: Tensor) -> Tensor:
"""
Computes the inverse of a batch of matrices; input must has shape [..., n, n].
:param inp: input tensor.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data = tensor([[1.0, 0.0], [1.0, 1.0]])
out = F.matinv(data)
print(out.numpy())
Outputs:
.. testoutput::
[[ 1. 0.]
[-1. 1.]]
"""
(result,) = apply(builtin.MatrixInverse(), inp)
return result
def matmul(
inp1: Tensor,
inp2: Tensor,
......
......@@ -60,6 +60,25 @@ def test_dropout():
assert out.numpy().sum() >= 0.0
def test_matinv():
shape1 = (5, 5)
shape2 = (3, 9, 9)
data1 = np.random.random(shape1).astype("float32")
data2 = np.random.random(shape2).astype("float32")
cases = [
{"input": data1},
{"input": data2},
]
opr_test(
cases,
F.matinv,
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
ref_fn=np.linalg.inv,
)
def test_matmul():
shape1 = 3
shape2 = 3
......
/**
* \file imperative/src/impl/ops/matrix_inverse.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/blas.h"
namespace mgb{
namespace imperative {
namespace {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 1);
return opr::MatrixInverse::make(inputs[0]);
}
OP_TRAIT_REG(MatrixInverse, MatrixInverse)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // anonymous namespace
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -34,6 +34,8 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let results = (outs AnyType);
}
def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册