提交 05692d05 编写于 作者: M Megvii Engine Team

refactor(functional): move matmul dot svd to math pkg

GitOrigin-RevId: 15eb08bacb5f047b9fe869b4672d03ae45eaa9c5
上级 02df634d
...@@ -13,19 +13,23 @@ import numbers ...@@ -13,19 +13,23 @@ import numbers
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import utils from ..core.tensor import utils
from ..tensor import Tensor from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy
from .elemwise import clip, exp, log, log1p from .elemwise import clip, exp, log, log1p
from .tensor import reshape, squeeze from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze
__all__ = [ __all__ = [
"argmax", "argmax",
"argmin", "argmin",
"argsort", "argsort",
"dot",
"isinf", "isinf",
"isnan", "isnan",
"matmul",
"max", "max",
"mean", "mean",
"min", "min",
...@@ -36,6 +40,7 @@ __all__ = [ ...@@ -36,6 +40,7 @@ __all__ = [
"sort", "sort",
"std", "std",
"sum", "sum",
"svd",
"topk", "topk",
"var", "var",
] ]
...@@ -663,7 +668,7 @@ def topk( ...@@ -663,7 +668,7 @@ def topk(
no_sort: bool = False, no_sort: bool = False,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
r""" r"""
Selects the ``Top-K``(by default) smallest elements of 2d matrix by row. Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
:param inp: input tensor. If input tensor is 2d, each row will be sorted. :param inp: input tensor. If input tensor is 2d, each row will be sorted.
:param k: number of elements needed. :param k: number of elements needed.
...@@ -722,3 +727,204 @@ def topk( ...@@ -722,3 +727,204 @@ def topk(
if descending: if descending:
tns = -tns tns = -tns
return tns, ind return tns, ind
def matmul(
inp1: Tensor,
inp2: Tensor,
transpose_a=False,
transpose_b=False,
compute_mode="DEFAULT",
format="DEFAULT",
) -> Tensor:
"""
Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
With different inputs dim, this function behaves differently:
- Both 1-D tensor, simply forward to ``dot``.
- Both 2-D tensor, normal matrix multiplication.
- If one input tensor is 1-D, matrix vector multiplication.
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted. For example:
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
- inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
:param inp1: first matrix to be multiplied.
:param inp2: second matrix to be multiplied.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
out = F.matmul(data1, data2)
print(out.numpy())
Outputs:
.. testoutput::
[[10. 13.]
[28. 40.]]
"""
remove_row, remove_col = False, False
inp1, inp2 = utils.convert_inputs(inp1, inp2)
dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
if dim1 == 1 and dim2 == 1:
return dot(inp1, inp2)
# the underlying matmul op requires input dims to be at least 2
if dim1 == 1:
inp1 = expand_dims(inp1, 0)
dim1 = 2
remove_row = True
if dim2 == 1:
inp2 = expand_dims(inp2, 1)
dim2 = 2
remove_col = True
batch_shape = None
shape1 = inp1.shape
shape2 = inp2.shape
maxdim = dim1 if dim1 > dim2 else dim2
if dim1 >= 3 or dim2 >= 3:
if use_symbolic_shape():
if dim1 > dim2:
shape2 = concat([shape1[:-2], shape2[-2:]])
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = concat([shape2[:-2], shape1[-2:]])
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
(inp1,) = apply(
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
)
(inp2,) = apply(
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
)
else:
if dim1 > dim2:
shape2 = shape1[:-2] + shape2[-2:]
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = shape2[:-2] + shape1[-2:]
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
else:
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
(result,) = apply(op, inp1, inp2)
if maxdim > 3:
if use_symbolic_shape():
(result,) = apply(
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
)
else:
result = result.reshape(batch_shape + result.shape[-2:])
if remove_row:
result = squeeze(result, axis=-2)
if remove_col:
result = squeeze(result, axis=-1)
return result
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
Refer to :func:`~.matmul` for more general usage.
:param inp1: first vector.
:param inp2: second vector.
:return: output value.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data1 = tensor(np.arange(0, 6, dtype=np.float32))
data2 = tensor(np.arange(0, 6, dtype=np.float32))
out = F.dot(data1, data2)
print(out.numpy())
Outputs:
.. testoutput::
55.
"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
utils.setscalar(result)
return result
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
"""
Computes the singular value decompositions of input matrix.
:param inp: input matrix, must has shape `[..., M, N]`.
:return: output matrices, `(U, sigma, V)`.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
_, y, _ = F.svd(x)
print(y.numpy().round(decimals=3))
Outputs:
.. testoutput::
[7.348 1. ]
"""
op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
U, sigma, V = apply(op, inp)
return U, sigma, V
...@@ -25,7 +25,7 @@ from ..utils.tuple_function import _pair, _pair_nonzero ...@@ -25,7 +25,7 @@ from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, prod, sum from .math import argsort, matmul, max, prod, sum
from .tensor import ( from .tensor import (
broadcast_to, broadcast_to,
concat, concat,
...@@ -46,7 +46,6 @@ __all__ = [ ...@@ -46,7 +46,6 @@ __all__ = [
"conv_transpose2d", "conv_transpose2d",
"deformable_conv2d", "deformable_conv2d",
"deformable_psroi_pooling", "deformable_psroi_pooling",
"dot",
"dropout", "dropout",
"indexing_one_hot", "indexing_one_hot",
"leaky_relu", "leaky_relu",
...@@ -55,7 +54,6 @@ __all__ = [ ...@@ -55,7 +54,6 @@ __all__ = [
"logsumexp", "logsumexp",
"logsoftmax", "logsoftmax",
"matinv", "matinv",
"matmul",
"max_pool2d", "max_pool2d",
"one_hot", "one_hot",
"prelu", "prelu",
...@@ -63,7 +61,6 @@ __all__ = [ ...@@ -63,7 +61,6 @@ __all__ = [
"resize", "resize",
"softmax", "softmax",
"softplus", "softplus",
"svd",
"warp_affine", "warp_affine",
"warp_perspective", "warp_perspective",
"conv1d", "conv1d",
...@@ -1221,207 +1218,6 @@ def matinv(inp: Tensor) -> Tensor: ...@@ -1221,207 +1218,6 @@ def matinv(inp: Tensor) -> Tensor:
return result return result
def matmul(
inp1: Tensor,
inp2: Tensor,
transpose_a=False,
transpose_b=False,
compute_mode="DEFAULT",
format="DEFAULT",
) -> Tensor:
"""
Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
With different inputs dim, this function behaves differently:
- Both 1-D tensor, simply forward to ``dot``.
- Both 2-D tensor, normal matrix multiplication.
- If one input tensor is 1-D, matrix vector multiplication.
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will
be broadcasted. For example:
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
- inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
:param inp1: first matrix to be multiplied.
:param inp2: second matrix to be multiplied.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
out = F.matmul(data1, data2)
print(out.numpy())
Outputs:
.. testoutput::
[[10. 13.]
[28. 40.]]
"""
remove_row, remove_col = False, False
inp1, inp2 = utils.convert_inputs(inp1, inp2)
dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
if dim1 == 1 and dim2 == 1:
return dot(inp1, inp2)
# the underlying matmul op requires input dims to be at least 2
if dim1 == 1:
inp1 = expand_dims(inp1, 0)
dim1 = 2
remove_row = True
if dim2 == 1:
inp2 = expand_dims(inp2, 1)
dim2 = 2
remove_col = True
batch_shape = None
shape1 = inp1.shape
shape2 = inp2.shape
maxdim = dim1 if dim1 > dim2 else dim2
if dim1 >= 3 or dim2 >= 3:
if use_symbolic_shape():
if dim1 > dim2:
shape2 = concat([shape1[:-2], shape2[-2:]])
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = concat([shape2[:-2], shape1[-2:]])
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
(inp1,) = apply(
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
)
(inp2,) = apply(
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
)
else:
if dim1 > dim2:
shape2 = shape1[:-2] + shape2[-2:]
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = shape2[:-2] + shape1[-2:]
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
else:
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
(result,) = apply(op, inp1, inp2)
if maxdim > 3:
if use_symbolic_shape():
(result,) = apply(
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
)
else:
result = result.reshape(batch_shape + result.shape[-2:])
if remove_row:
result = squeeze(result, axis=-2)
if remove_col:
result = squeeze(result, axis=-1)
return result
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
Refer to :func:`~.matmul` for more general usage.
:param inp1: first vector.
:param inp2: second vector.
:return: output value.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data1 = tensor(np.arange(0, 6, dtype=np.float32))
data2 = tensor(np.arange(0, 6, dtype=np.float32))
out = F.dot(data1, data2)
print(out.numpy())
Outputs:
.. testoutput::
55.
"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
setscalar(result)
return result
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
"""
Computes the singular value decompositions of input matrix.
:param inp: input matrix, must has shape `[..., M, N]`.
:return: output matrices, `(U, sigma, V)`.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
_, y, _ = F.svd(x)
print(y.numpy().round(decimals=3))
Outputs:
.. testoutput::
[7.348 1. ]
"""
op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
U, sigma, V = apply(op, inp)
return U, sigma, V
def interpolate( def interpolate(
inp: Tensor, inp: Tensor,
size: Optional[Union[int, Tuple[int, int]]] = None, size: Optional[Union[int, Tuple[int, int]]] = None,
......
...@@ -707,7 +707,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: ...@@ -707,7 +707,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
:param inp: input tensor. :param inp: input tensor.
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, :param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:
* (``'x'``) -> make a 0d (scalar) into a 1d vector * (``'x'``) -> make a 0d (scalar) into a 1d vector
* (0, 1) -> identity for 2d vectors * (0, 1) -> identity for 2d vectors
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册