From ca2deebc0f122ae6735e7018eba386694e70785f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 11 Feb 2022 17:12:29 +0800 Subject: [PATCH] fix(imperative/tensor): make @ operator has the same functionality as matmul functional GitOrigin-RevId: bf6136cc1a7b5cc00103fd0eee27cb8bca8c6f99 --- .../megengine/core/tensor/array_method.py | 286 +++++++++++++++++- .../python/megengine/functional/math.py | 282 +---------------- .../test/unit/core/test_tensor_wrapper.py | 9 +- 3 files changed, 289 insertions(+), 288 deletions(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index a7c086c8d..818dd3783 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -8,6 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import abc import collections +from functools import lru_cache from typing import Union import numpy as np @@ -24,8 +25,8 @@ from .utils import ( astype, cast_tensors, convert_inputs, - isscalar, make_shape_tuple, + subgraph, ) _ElwMod = builtin.Elemwise.Mode @@ -73,23 +74,292 @@ def _elwise(*args, mode): return _elwise_apply(args, mode) -def _matmul(inp1, inp2): +@lru_cache(maxsize=None) +def _get_extentedMatrixMulOp( + device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, +): + @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2) + def extentedMatrixMulOp(inputs, f, c): + assert len(inputs) == 2 + inp1, inp2 = inputs + _dim1, _dim2 = dim1, dim2 + + def build_shape_head(shape, idx=-1): + # shape[:idx] + return f( + builtin.Subtensor(items=[[0, False, True, False, False]]), + shape, + c(idx, "int32"), + ) + + def build_shape_tail(shape, idx=-1): + # shape[idx:] + return f( + builtin.Subtensor(items=[[0, True, False, False, False]]), + shape, + c(idx, "int32"), + ) + + remove_row, remove_col = False, False + if _dim1 == 1: + _dim1 = 2 + remove_row = True + if _dim2 == 1: + _dim2 = 2 + remove_col = True + + if remove_row: + inp1 = f(builtin.AddAxis(axis=[0,]), inp1) + if remove_col: + inp2 = f(builtin.AddAxis(axis=[1,]), inp2) + + shape1 = f(builtin.GetVarShape(), inp1) + shape2 = f(builtin.GetVarShape(), inp2) + if _dim1 > 2: + inp1 = f( + builtin.Reshape(), + inp1, + f( + builtin.Concat(axis=0, comp_node=device), + f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)), + build_shape_tail(shape1), + ), + ) + if _dim2 > 2: + inp2 = f( + builtin.Reshape(), + inp2, + f( + builtin.Concat(axis=0, comp_node=device), + f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)), + build_shape_tail(shape2), + ), + ) + op = builtin.MatrixMul( + transposeA=transpose_a, + transposeB=transpose_b, + compute_mode=compute_mode, + format=format, + strategy=strategy.value, + ) + result = f(op, inp1, inp2) + result_shape = f(builtin.GetVarShape(), result) + if _dim1 > 2: + result = f( + builtin.Reshape(), + result, + f( + builtin.Concat(axis=0, comp_node=device), + build_shape_head(shape1), + build_shape_tail(result_shape), + ), + ) + if _dim2 > 2: + result = f( + builtin.Reshape(), + result, + f( + builtin.Concat(axis=0, comp_node=device), + build_shape_head(shape2), + build_shape_tail(result_shape), + ), + ) + maxdim = _dim1 if _dim1 > _dim2 else _dim2 + if remove_row: + result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) + if remove_col: + result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) + return (result,), (True,) + + 
return extentedMatrixMulOp + + +@lru_cache(maxsize=None) +def _get_extentedBatchedMatrixMulOp( + device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, +): + @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2) + def extentedBatchedMatrixMulOp(inputs, f, c): + assert len(inputs) == 2 + inp1, inp2 = inputs + _dim1, _dim2 = dim1, dim2 + + def build_shape_head(shape, idx=-2): + # shape[:idx] + return f( + builtin.Subtensor(items=[[0, False, True, False, False]]), + shape, + c(idx, "int32"), + ) + + def build_shape_tail(shape, idx=-2): + # shape[idx:] + return f( + builtin.Subtensor(items=[[0, True, False, False, False]]), + shape, + c(idx, "int32"), + ) + + remove_row, remove_col = False, False + if _dim1 == 1: + _dim1 = 2 + remove_row = True + if _dim2 == 1: + _dim2 = 2 + remove_col = True + + if remove_row: + inp1 = f(builtin.AddAxis(axis=[0,]), inp1) + if remove_col: + inp2 = f(builtin.AddAxis(axis=[1,]), inp2) + shape1 = f(builtin.GetVarShape(), inp1) + shape2 = f(builtin.GetVarShape(), inp2) + maxdim = _dim1 if _dim1 > _dim2 else _dim2 + if _dim1 > _dim2: + # broadcast + shape2 = f( + builtin.Concat(axis=0, comp_node=device), + build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2] + shape2, + ) + inp2 = f(builtin.Broadcast(), inp2, shape2) + batch_shape = build_shape_head(shape1) + if _dim2 > _dim1: + # broadcast + shape1 = f( + builtin.Concat(axis=0, comp_node=device), + build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1] + shape1, + ) + inp1 = f(builtin.Broadcast(), inp1, shape1) + batch_shape = build_shape_head(shape2) + if _dim1 == _dim2: + batch_shape = build_shape_head(shape1) + + # compress inputs to 3d + if maxdim > 3: + inp1 = f( + builtin.Reshape(), + inp1, + f( + builtin.Concat(axis=0, comp_node=device), + f(builtin.Reduce(mode="product", axis=0), batch_shape), + build_shape_tail(shape1), + ), + ) + inp2 = f( + builtin.Reshape(), + inp2, + f( + builtin.Concat(axis=0, comp_node=device), + f(builtin.Reduce(mode="product", axis=0), batch_shape), + build_shape_tail(shape2), + ), + ) + op = builtin.BatchedMatrixMul( + transposeA=transpose_a, + transposeB=transpose_b, + compute_mode=compute_mode, + format=format, + strategy=strategy.value, + ) + result = f(op, inp1, inp2) + + if maxdim > 3: + result = f( + builtin.Reshape(), + result, + f( + builtin.Concat(axis=0, comp_node=device), + batch_shape, + build_shape_tail(f(builtin.GetVarShape(), result)), + ), + ) + if remove_row: + result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) + if remove_col: + result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) + return (result,), (True,) + + return extentedBatchedMatrixMulOp + + +class _Hashable: + def __init__(self, value) -> None: + self.value = value + + def __hash__(self) -> int: + return hash(str(self.value)) + + def __eq__(self, o: object) -> bool: + if not isinstance(o, _Hashable): + return False + return self.value == o.value + + +def _matmul( + inp1, + inp2, + transpose_a=False, + transpose_b=False, + compute_mode="default", + format="default", +): if amp._enabled: compute_mode = "float32" inp1, inp2 = cast_tensors(inp1, inp2) else: - compute_mode = "default" dtype = dtype_promotion(inp1, inp2) if inp1.dtype != dtype: inp1 = inp1.astype(dtype) if inp2.dtype != dtype: inp2 = inp2.astype(dtype) + + dim1, dim2 = inp1.ndim, inp2.ndim + assert dim1 > 0 and dim2 > 0 + maxdim = dim1 if dim1 > dim2 else dim2 compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) - op = builtin.MatrixMul( - 
transposeA=False, transposeB=False, compute_mode=compute_mode, format="default" - ) - (result,) = apply(op, inp1, inp2) - return result + + Strategy = builtin.ops.MatrixMul.Strategy + strategy = Strategy(0) + if _config._benchmark_kernel: + strategy |= Strategy.PROFILE + else: + strategy |= Strategy.HEURISTIC + if _config._deterministic_kernel: + strategy |= Strategy.REPRODUCIBLE + + if dim1 == 1 and dim2 == 1: # dispatch to Dot + (result,) = apply(builtin.Dot(), inp1, inp2) + return result + elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul + extentedMatrixMulOp = _get_extentedMatrixMulOp( + inp1.device, + inp1.dtype, + dim1, + dim2, + transpose_a, + transpose_b, + compute_mode, + format, + strategy=_Hashable(strategy), + ) + (result,) = apply(extentedMatrixMulOp(), inp1, inp2) + return result + else: # dispath to BatchedMatrixMul + extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( + inp1.device, + inp1.dtype, + dim1, + dim2, + transpose_a, + transpose_b, + compute_mode, + format, + strategy=_Hashable(strategy), + ) + (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) + return result def _transpose(data, axes): diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index b0fcc7ce9..f1cc41c70 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -8,24 +8,18 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections import math -from functools import lru_cache from typing import Iterable, Optional, Sequence, Tuple, Union -from ..core import _config from ..core._imperative_rt.core2 import apply, dtype_promotion from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder -from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin -from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt from ..core.ops.special import Const -from ..core.tensor import amp -from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph -from ..jit import exclude_from_trace +from ..core.tensor.array_method import _matmul +from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default -from .debug_param import get_execution_strategy -from .elemwise import clip, minimum -from .tensor import broadcast_to, concat, expand_dims, squeeze +from .elemwise import clip +from .tensor import expand_dims, squeeze __all__ = [ "argmax", @@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor: return result -class _Hashable: - def __init__(self, value) -> None: - self.value = value - - def __hash__(self) -> int: - return hash(str(self.value)) - - def __eq__(self, o: object) -> bool: - if not isinstance(o, _Hashable): - return False - return self.value == o.value - - -@lru_cache(maxsize=None) -def _get_extentedMatrixMulOp( - device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, -): - @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2) - def extentedMatrixMulOp(inputs, f, c): - assert len(inputs) == 2 - inp1, inp2 = inputs - _dim1, _dim2 = dim1, dim2 - - def build_shape_head(shape, idx=-1): - # shape[:idx] - return f( - builtin.Subtensor(items=[[0, False, True, False, False]]), - shape, - c(idx, "int32"), - ) - - def build_shape_tail(shape, idx=-1): - # shape[idx:] - return f( - builtin.Subtensor(items=[[0, True, False, False, False]]), - shape, - c(idx, 
"int32"), - ) - - remove_row, remove_col = False, False - if _dim1 == 1: - _dim1 = 2 - remove_row = True - if _dim2 == 1: - _dim2 = 2 - remove_col = True - - if remove_row: - inp1 = f(builtin.AddAxis(axis=[0,]), inp1) - if remove_col: - inp2 = f(builtin.AddAxis(axis=[1,]), inp2) - - shape1 = f(GetVarShape(), inp1) - shape2 = f(GetVarShape(), inp2) - if _dim1 > 2: - inp1 = f( - builtin.Reshape(), - inp1, - f( - builtin.Concat(axis=0, comp_node=device), - f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)), - build_shape_tail(shape1), - ), - ) - if _dim2 > 2: - inp2 = f( - builtin.Reshape(), - inp2, - f( - builtin.Concat(axis=0, comp_node=device), - f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)), - build_shape_tail(shape2), - ), - ) - op = builtin.MatrixMul( - transposeA=transpose_a, - transposeB=transpose_b, - compute_mode=compute_mode, - format=format, - strategy=strategy.value, - ) - result = f(op, inp1, inp2) - result_shape = f(GetVarShape(), result) - if _dim1 > 2: - result = f( - builtin.Reshape(), - result, - f( - builtin.Concat(axis=0, comp_node=device), - build_shape_head(shape1), - build_shape_tail(result_shape), - ), - ) - if _dim2 > 2: - result = f( - builtin.Reshape(), - result, - f( - builtin.Concat(axis=0, comp_node=device), - build_shape_head(shape2), - build_shape_tail(result_shape), - ), - ) - maxdim = _dim1 if _dim1 > _dim2 else _dim2 - if remove_row: - result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) - if remove_col: - result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) - return (result,), (True,) - - return extentedMatrixMulOp - - -@lru_cache(maxsize=None) -def _get_extentedBatchedMatrixMulOp( - device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, -): - @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2) - def extentedBatchedMatrixMulOp(inputs, f, c): - assert len(inputs) == 2 - inp1, inp2 = inputs - _dim1, _dim2 = dim1, dim2 - - def build_shape_head(shape, idx=-2): - # shape[:idx] - return f( - builtin.Subtensor(items=[[0, False, True, False, False]]), - shape, - c(idx, "int32"), - ) - - def build_shape_tail(shape, idx=-2): - # shape[idx:] - return f( - builtin.Subtensor(items=[[0, True, False, False, False]]), - shape, - c(idx, "int32"), - ) - - remove_row, remove_col = False, False - if _dim1 == 1: - _dim1 = 2 - remove_row = True - if _dim2 == 1: - _dim2 = 2 - remove_col = True - - if remove_row: - inp1 = f(builtin.AddAxis(axis=[0,]), inp1) - if remove_col: - inp2 = f(builtin.AddAxis(axis=[1,]), inp2) - shape1 = f(GetVarShape(), inp1) - shape2 = f(GetVarShape(), inp2) - maxdim = _dim1 if _dim1 > _dim2 else _dim2 - if _dim1 > _dim2: - # broadcast - shape2 = f( - builtin.Concat(axis=0, comp_node=device), - build_shape_head(shape1, idx=-_dim2), # shape1[:-_dim2] - shape2, - ) - inp2 = f(builtin.Broadcast(), inp2, shape2) - batch_shape = build_shape_head(shape1) - if _dim2 > _dim1: - # broadcast - shape1 = f( - builtin.Concat(axis=0, comp_node=device), - build_shape_head(shape2, idx=-_dim1), # shape2[:-_dim1] - shape1, - ) - inp1 = f(builtin.Broadcast(), inp1, shape1) - batch_shape = build_shape_head(shape2) - if _dim1 == _dim2: - batch_shape = build_shape_head(shape1) - - # compress inputs to 3d - if maxdim > 3: - inp1 = f( - builtin.Reshape(), - inp1, - f( - builtin.Concat(axis=0, comp_node=device), - f(builtin.Reduce(mode="product", axis=0), batch_shape), - build_shape_tail(shape1), - ), - ) - inp2 = f( - builtin.Reshape(), - inp2, - f( - builtin.Concat(axis=0, 
comp_node=device), - f(builtin.Reduce(mode="product", axis=0), batch_shape), - build_shape_tail(shape2), - ), - ) - op = builtin.BatchedMatrixMul( - transposeA=transpose_a, - transposeB=transpose_b, - compute_mode=compute_mode, - format=format, - strategy=strategy.value, - ) - result = f(op, inp1, inp2) - - if maxdim > 3: - result = f( - builtin.Reshape(), - result, - f( - builtin.Concat(axis=0, comp_node=device), - batch_shape, - build_shape_tail(f(GetVarShape(), result)), - ), - ) - if remove_row: - result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result) - if remove_col: - result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result) - return (result,), (True,) - - return extentedBatchedMatrixMulOp - - def matmul( inp1: Tensor, inp2: Tensor, @@ -1067,50 +838,7 @@ def matmul( [[10. 13.] [28. 40.]] """ - if amp._enabled: - compute_mode = "float32" - inp1, inp2 = cast_tensors(inp1, inp2) - else: - dtype = dtype_promotion(inp1, inp2) - if inp1.dtype != dtype: - inp1 = inp1.astype(dtype) - if inp2.dtype != dtype: - inp2 = inp2.astype(dtype) - - dim1, dim2 = inp1.ndim, inp2.ndim - assert dim1 > 0 and dim2 > 0 - maxdim = dim1 if dim1 > dim2 else dim2 - compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) - if dim1 == 1 and dim2 == 1: # dispatch to Dot - return dot(inp1, inp2) - elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul - extentedMatrixMulOp = _get_extentedMatrixMulOp( - inp1.device, - inp1.dtype, - dim1, - dim2, - transpose_a, - transpose_b, - compute_mode, - format, - strategy=_Hashable(get_execution_strategy()), - ) - (result,) = apply(extentedMatrixMulOp(), inp1, inp2) - return result - else: # dispath to BatchedMatrixMul - extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( - inp1.device, - inp1.dtype, - dim1, - dim2, - transpose_a, - transpose_b, - compute_mode, - format, - strategy=_Hashable(get_execution_strategy()), - ) - (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) - return result + return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format) def dot(inp1: Tensor, inp2: Tensor) -> Tensor: diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index 18dcada79..ebdbf8ed3 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -46,14 +46,17 @@ def test_literal_arith(is_varnode): @pytest.mark.parametrize("is_varnode", [True, False]) -def test_matmul(is_varnode): +@pytest.mark.parametrize( + "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),], +) +def test_matmul(is_varnode, shape_a, shape_b): if is_varnode: network = Network() else: network = None - A = make_tensor(np.random.rand(5, 7).astype("float32"), network) - B = make_tensor(np.random.rand(7, 10).astype("float32"), network) + A = make_tensor(np.random.rand(*shape_a).astype("float32"), network) + B = make_tensor(np.random.rand(*shape_b).astype("float32"), network) C = A @ B if is_varnode: np.testing.assert_almost_equal( -- GitLab
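
After this change, Tensor.__matmul__ and the functional megengine.functional.matmul share the same _matmul implementation in array_method.py: 1-D x 1-D operands dispatch to Dot, cases where the right operand has at most 2 dimensions go through the cached extentedMatrixMulOp subgraph, and the remaining N-D cases go through extentedBatchedMatrixMulOp with batch-dimension broadcasting. The execution strategy (HEURISTIC or PROFILE, plus REPRODUCIBLE when deterministic kernels are requested) is built from the _config flags and wrapped in _Hashable so it can pass through lru_cache. A minimal usage sketch of the resulting behavior (assuming MegEngine built at this revision; the shapes mirror the parametrized test cases):

import numpy as np
import megengine as mge
import megengine.functional as F

def check(shape_a, shape_b):
    # The infix operator and the functional API now go through the same dispatch,
    # so both should agree with each other and with NumPy's matmul semantics.
    a_np = np.random.rand(*shape_a).astype("float32")
    b_np = np.random.rand(*shape_b).astype("float32")
    a, b = mge.Tensor(a_np), mge.Tensor(b_np)
    out_op = a @ b              # Tensor.__matmul__ -> _matmul
    out_fn = F.matmul(a, b)     # functional API    -> _matmul
    np.testing.assert_allclose(out_op.numpy(), out_fn.numpy(), rtol=1e-5)
    np.testing.assert_allclose(out_op.numpy(), a_np @ b_np, rtol=1e-5)

check((4,), (4,))              # 1-D x 1-D -> Dot
check((10, 4), (4, 10))        # 2-D x 2-D -> MatrixMul
check((3, 10, 4), (3, 4, 10))  # batched   -> BatchedMatrixMul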
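
Most of the two cached subgraphs is shape bookkeeping around the raw kernels: a 1-D left operand temporarily gains a leading axis and a 1-D right operand a trailing axis (removed again from the result), leading dimensions are flattened before MatrixMul / BatchedMatrixMul and restored afterwards, and unequal batch ranks are broadcast against each other. A rough NumPy rendering of that bookkeeping, ignoring transpose_a/transpose_b, compute_mode, format and strategy (_np_matmul_like is a hypothetical name used only for illustration, not MegEngine code):

import numpy as np

def _np_matmul_like(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Sketch of the dim normalization done by the extented*MatrixMulOp subgraphs."""
    dim1, dim2 = a.ndim, b.ndim
    assert dim1 > 0 and dim2 > 0
    if dim1 == 1 and dim2 == 1:                  # dispatched to Dot
        return np.dot(a, b)
    remove_row = remove_col = False
    if a.ndim == 1:                              # promote vector to a 1xK matrix
        a, remove_row = a[None, :], True
    if b.ndim == 1:                              # promote vector to a Kx1 matrix
        b, remove_col = b[:, None], True
    if b.ndim <= 2:                              # MatrixMul path: collapse leading dims of a
        lead = a.shape[:-1]
        out = (a.reshape(-1, a.shape[-1]) @ b).reshape(lead + (b.shape[-1],))
    else:                                        # BatchedMatrixMul path: broadcast batch dims
        batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
        a = np.broadcast_to(a, batch + a.shape[-2:]).reshape((-1,) + a.shape[-2:])
        b = np.broadcast_to(b, batch + b.shape[-2:]).reshape((-1,) + b.shape[-2:])
        out = (a @ b).reshape(batch + (a.shape[-2], b.shape[-1]))
    if remove_row:                               # drop the axes added for 1-D operands
        out = np.squeeze(out, axis=-2)
    if remove_col:
        out = np.squeeze(out, axis=-1)
    return out

In the patch itself the same steps are expressed as GetVarShape / Subtensor / Reshape / Broadcast / AddAxis / RemoveAxis operators inside a @subgraph, so the whole sequence is built once and cached per (device, dtype, ndim, strategy) key.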