def matmul(
    inp1: Tensor,
    inp2: Tensor,
    transpose_a: bool = False,
    transpose_b: bool = False,
    compute_mode: str = "DEFAULT",
    format: str = "DEFAULT",
) -> Tensor:
    """
    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

    With different inputs dim, this function behaves differently:

    - Both 1-D tensor, simply forward to ``dot``.
    - Both 2-D tensor, normal matrix multiplication.
    - If one input tensor is 1-D, matrix vector multiplication.
    - If at least one tensor is 3-dimensional or higher, the other tensor should
      have dim >= 2; the batched matrix-matrix product is returned, and the tensor
      with the smaller dimension is broadcast. For example:

        - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
        - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
        - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

    :param inp1: first matrix to be multiplied.
    :param inp2: second matrix to be multiplied.
    :param transpose_a: forwarded to the underlying op as ``transposeA``. Default: False
    :param transpose_b: forwarded to the underlying op as ``transposeB``. Default: False
    :param compute_mode: forwarded unchanged to the underlying op. Default: "DEFAULT"
    :param format: forwarded unchanged to the underlying op. Default: "DEFAULT"
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
        out = F.matmul(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[10. 13.]
         [28. 40.]]

    """
    remove_row, remove_col = False, False
    inp1, inp2 = utils.convert_inputs(inp1, inp2)

    dim1, dim2 = inp1.ndim, inp2.ndim
    # handle dim=1 cases, dot and matrix-vector multiplication
    if dim1 == 1 and dim2 == 1:
        return dot(inp1, inp2)
    # the underlying matmul op requires input dims to be at least 2, so a 1-D
    # operand gets a temporary unit axis that is squeezed away before returning
    if dim1 == 1:
        inp1 = expand_dims(inp1, 0)
        dim1 = 2
        remove_row = True
    if dim2 == 1:
        inp2 = expand_dims(inp2, 1)
        dim2 = 2
        remove_col = True

    batch_shape = None
    shape1 = inp1.shape
    shape2 = inp2.shape

    # NOTE: deliberately not ``max(dim1, dim2)`` -- this module exports its own
    # tensor-reduction ``max``, which shadows the builtin here.
    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 >= 3 or dim2 >= 3:
        if use_symbolic_shape():
            # shapes are handled with tensor ops (concat / Reshape) in this mode
            if dim1 > dim2:
                shape2 = concat([shape1[:-2], shape2[-2:]])
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = concat([shape2[:-2], shape1[-2:]])
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                (inp1,) = apply(
                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
                )
                (inp2,) = apply(
                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
                )
        else:
            # shapes are combined with ``+`` (sequence concatenation) in this mode
            if dim1 > dim2:
                shape2 = shape1[:-2] + shape2[-2:]
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = shape2[:-2] + shape1[-2:]
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
        op_cls = builtin.BatchedMatrixMul
    else:
        op_cls = builtin.MatrixMul

    # both op types take the same keyword set, so it is spelled out only once
    op = op_cls(
        transposeA=transpose_a,
        transposeB=transpose_b,
        compute_mode=compute_mode,
        format=format,
        strategy=get_conv_execution_strategy(),
    )

    (result,) = apply(op, inp1, inp2)
    if maxdim > 3:
        # restore the leading batch dims that were flattened to a single axis above
        if use_symbolic_shape():
            (result,) = apply(
                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
            )
        else:
            result = result.reshape(batch_shape + result.shape[-2:])
    if remove_row:
        result = squeeze(result, axis=-2)
    if remove_col:
        result = squeeze(result, axis=-1)
    return result


def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
    """
    Computes dot-product of two vectors ``inp1`` and ``inp2``.
    inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
    Refer to :func:`~.matmul` for more general usage.

    :param inp1: first vector.
    :param inp2: second vector.
    :return: output value.
    :raises AssertionError: if either input has more than one dimension.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32))
        data2 = tensor(np.arange(0, 6, dtype=np.float32))
        out = F.dot(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        55.

    """
    op = builtin.Dot()
    inp1, inp2 = utils.convert_inputs(inp1, inp2)
    # Raised explicitly (rather than via ``assert``) so the check is not stripped
    # under ``python -O``; the exception type is kept for backward compatibility.
    if not (inp1.ndim <= 1 and inp2.ndim <= 1):
        raise AssertionError("Input tensors for dot must be 1-dimensional or scalar")
    (result,) = apply(op, inp1, inp2)
    utils.setscalar(result)
    return result


def svd(
    inp: Tensor, full_matrices: bool = False, compute_uv: bool = True
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Computes the singular value decomposition of input matrix.

    :param inp: input matrix, must have shape `[..., M, N]`.
    :param full_matrices: forwarded unchanged to the underlying SVD op. Default: False
    :param compute_uv: forwarded unchanged to the underlying SVD op. Default: True
    :return: output matrices, `(U, sigma, V)`.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
        _, y, _ = F.svd(x)
        print(y.numpy().round(decimals=3))

    Outputs:

    .. testoutput::

        [7.348 1.   ]

    """
    op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
    # NOTE(review): three outputs are unpacked regardless of ``compute_uv`` --
    # confirm the op still yields three tensors when ``compute_uv`` is False.
    U, sigma, V = apply(op, inp)
    return U, sigma, V
-1221,207 +1218,6 @@ def matinv(inp: Tensor) -> Tensor: return result -def matmul( - inp1: Tensor, - inp2: Tensor, - transpose_a=False, - transpose_b=False, - compute_mode="DEFAULT", - format="DEFAULT", -) -> Tensor: - """ - Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``. - - With different inputs dim, this function behaves differently: - - - Both 1-D tensor, simply forward to ``dot``. - - Both 2-D tensor, normal matrix multiplication. - - If one input tensor is 1-D, matrix vector multiplication. - - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will - be broadcasted. For example: - - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` - - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` - - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)` - - :param inp1: first matrix to be multiplied. - :param inp2: second matrix to be multiplied. - :return: output tensor. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) - data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2)) - out = F.matmul(data1, data2) - print(out.numpy()) - - Outputs: - - .. testoutput:: - - [[10. 13.] - [28. 
40.]] - - """ - remove_row, remove_col = False, False - inp1, inp2 = utils.convert_inputs(inp1, inp2) - - dim1, dim2 = inp1.ndim, inp2.ndim - # handle dim=1 cases, dot and matrix-vector multiplication - if dim1 == 1 and dim2 == 1: - return dot(inp1, inp2) - # the underlying matmul op requires input dims to be at least 2 - if dim1 == 1: - inp1 = expand_dims(inp1, 0) - dim1 = 2 - remove_row = True - if dim2 == 1: - inp2 = expand_dims(inp2, 1) - dim2 = 2 - remove_col = True - - batch_shape = None - shape1 = inp1.shape - shape2 = inp2.shape - - maxdim = dim1 if dim1 > dim2 else dim2 - if dim1 >= 3 or dim2 >= 3: - if use_symbolic_shape(): - if dim1 > dim2: - shape2 = concat([shape1[:-2], shape2[-2:]]) - inp2 = broadcast_to(inp2, shape2) - if dim1 < dim2: - shape1 = concat([shape2[:-2], shape1[-2:]]) - inp1 = broadcast_to(inp1, shape1) - if maxdim > 3: - batch_shape = shape1[:-2] - # compress inputs to 3d - (inp1,) = apply( - builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) - ) - (inp2,) = apply( - builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) - ) - else: - if dim1 > dim2: - shape2 = shape1[:-2] + shape2[-2:] - inp2 = broadcast_to(inp2, shape2) - if dim1 < dim2: - shape1 = shape2[:-2] + shape1[-2:] - inp1 = broadcast_to(inp1, shape1) - if maxdim > 3: - batch_shape = shape1[:-2] - # compress inputs to 3d - inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) - inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) - - op = builtin.BatchedMatrixMul( - transposeA=transpose_a, - transposeB=transpose_b, - compute_mode=compute_mode, - format=format, - strategy=get_conv_execution_strategy(), - ) - else: - op = builtin.MatrixMul( - transposeA=transpose_a, - transposeB=transpose_b, - compute_mode=compute_mode, - format=format, - strategy=get_conv_execution_strategy(), - ) - - (result,) = apply(op, inp1, inp2) - if maxdim > 3: - if use_symbolic_shape(): - (result,) = apply( - builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) - ) - 
else: - result = result.reshape(batch_shape + result.shape[-2:]) - if remove_row: - result = squeeze(result, axis=-2) - if remove_col: - result = squeeze(result, axis=-1) - return result - - -def dot(inp1: Tensor, inp2: Tensor) -> Tensor: - """ - Computes dot-product of two vectors ``inp1`` and ``inp2``. - inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted. - Refer to :func:`~.matmul` for more general usage. - - :param inp1: first vector. - :param inp2: second vector. - :return: output value. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - data1 = tensor(np.arange(0, 6, dtype=np.float32)) - data2 = tensor(np.arange(0, 6, dtype=np.float32)) - out = F.dot(data1, data2) - print(out.numpy()) - - Outputs: - - .. testoutput:: - - 55. - - """ - op = builtin.Dot() - inp1, inp2 = utils.convert_inputs(inp1, inp2) - assert ( - inp1.ndim <= 1 and inp2.ndim <= 1 - ), "Input tensors for dot must be 1-dimensional or scalar" - (result,) = apply(op, inp1, inp2) - setscalar(result) - return result - - -def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: - """ - Computes the singular value decompositions of input matrix. - - :param inp: input matrix, must has shape `[..., M, N]`. - :return: output matrices, `(U, sigma, V)`. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3)) - _, y, _ = F.svd(x) - print(y.numpy().round(decimals=3)) - - Outputs: - - .. testoutput:: - - [7.348 1. 
] - - """ - op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) - U, sigma, V = apply(op, inp) - return U, sigma, V - - def interpolate( inp: Tensor, size: Optional[Union[int, Tuple[int, int]]] = None, diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 0266da1336153325fe66ff44b8e0e2b4f7feab8a..487863de0969a03141fb1dd382c4991098931bf4 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -707,7 +707,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: :param inp: input tensor. :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, - and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: + and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: * (``'x'``) -> make a 0d (scalar) into a 1d vector * (0, 1) -> identity for 2d vectors