Commit ca2deebc authored by Megvii Engine Team

fix(imperative/tensor): make the @ operator have the same functionality as the matmul functional
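In effect, this change routes `Tensor.__matmul__` through the same `_matmul` helper that `functional.matmul` uses, so both follow NumPy-style matmul semantics for 1-D, 2-D, and batched operands. A minimal usage sketch of the intended equivalence (illustrative only, not part of this commit):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

a = mge.tensor(np.random.rand(3, 10, 4).astype("float32"))
b = mge.tensor(np.random.rand(3, 4, 10).astype("float32"))

# Both forms should now take the same dispatch path
# (Dot / MatrixMul / BatchedMatrixMul, depending on operand ranks).
np.testing.assert_almost_equal((a @ b).numpy(), F.matmul(a, b).numpy())

# 1-D operands behave as in NumPy: vector @ vector reduces to Dot.
v = mge.tensor(np.random.rand(4).astype("float32"))
np.testing.assert_almost_equal((v @ v).numpy(), F.matmul(v, v).numpy())
```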

GitOrigin-RevId: bf6136cc1a7b5cc00103fd0eee27cb8bca8c6f99
Parent e860a083
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import collections
from functools import lru_cache
from typing import Union

import numpy as np
@@ -24,8 +25,8 @@ from .utils import (
    astype,
    cast_tensors,
    convert_inputs,
    isscalar,
    make_shape_tuple,
    subgraph,
)

_ElwMod = builtin.Elemwise.Mode

@@ -73,23 +74,292 @@ def _elwise(*args, mode):
    return _elwise_apply(args, mode)
def _matmul(inp1, inp2):


@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(builtin.GetVarShape(), result)
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp
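The subgraph above normalizes operand ranks instead of looping: a 1-D operand gets a unit axis, and an operand with more than two dims has its leading dims flattened (via the `Reduce`/`Concat`/`Reshape` chain) so a single 2-D `MatrixMul` suffices, after which the leading dims are restored. A NumPy sketch of the same shape logic, assuming only the lhs has extra leading dims (`extended_matmul_2d` is a hypothetical name):

```python
import numpy as np

def extended_matmul_2d(x, y):
    head = x.shape[:-1]                     # leading dims, set aside
    out = x.reshape(-1, x.shape[-1]) @ y    # collapse to one 2-D matmul
    return out.reshape(*head, y.shape[-1])  # restore leading dims

x = np.random.rand(3, 10, 4)
y = np.random.rand(4, 7)
np.testing.assert_allclose(extended_matmul_2d(x, y), x @ y)
```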
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(builtin.GetVarShape(), inp1)
        shape2 = f(builtin.GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(builtin.GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
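The batched variant first broadcasts the lower-rank operand up to the other operand's batch shape, then collapses all batch dims into one so a rank-3 `BatchedMatrixMul` can handle it, and finally restores the batch dims. A NumPy sketch of that shape logic, assuming the rhs has the lower rank as in the `_dim1 > _dim2` branch (`extended_batched_matmul` is a hypothetical name):

```python
import numpy as np

def extended_batched_matmul(x, y):
    batch = x.shape[:-2]                          # batch dims of the lhs
    y = np.broadcast_to(y, batch + y.shape[-2:])  # broadcast the rhs up
    x3 = x.reshape(-1, *x.shape[-2:])             # compress to rank 3
    y3 = y.reshape(-1, *y.shape[-2:])
    out = np.matmul(x3, y3)                       # one batched matmul
    return out.reshape(*batch, *out.shape[-2:])   # restore batch dims

x = np.random.rand(2, 3, 5, 4)
y = np.random.rand(4, 6)
np.testing.assert_allclose(extended_batched_matmul(x, y), x @ y)
```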
class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value
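`_Hashable` exists because `_get_extentedMatrixMulOp` and its batched twin are memoized with `lru_cache`, so every argument must serve as a cache key; wrapping the strategy keys the cache by its string form. A minimal sketch of the caching behavior (reusing the `_Hashable` class above; `cached_factory` is a hypothetical stand-in for the subgraph builders):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def cached_factory(strategy):
    # stands in for compiling an expensive subgraph keyed by the strategy
    return object()

s1, s2 = _Hashable(3), _Hashable(3)
assert s1 == s2 and hash(s1) == hash(s2)
assert cached_factory(s1) is cached_factory(s2)  # equal keys -> cache hit
```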
def _matmul(
    inp1,
    inp2,
    transpose_a=False,
    transpose_b=False,
    compute_mode="default",
    format="default",
):
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        compute_mode = "default"
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    op = builtin.MatrixMul(
        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
    )
    (result,) = apply(op, inp1, inp2)
    return result
    Strategy = builtin.ops.MatrixMul.Strategy
    strategy = Strategy(0)
    if _config._benchmark_kernel:
        strategy |= Strategy.PROFILE
    else:
        strategy |= Strategy.HEURISTIC
    if _config._deterministic_kernel:
        strategy |= Strategy.REPRODUCIBLE
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        (result,) = apply(builtin.Dot(), inp1, inp2)
        return result
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(strategy),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result
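The dispatch rule above is purely rank-based: two vectors go to `Dot`; if the rhs is at most 2-D (or both operands are), the extended `MatrixMul` subgraph is used; anything else falls through to the extended `BatchedMatrixMul`. A small sketch of that rule (`_expected_kernel` is a hypothetical helper for illustration):

```python
def _expected_kernel(dim1, dim2):
    if dim1 == 1 and dim2 == 1:
        return "Dot"
    if max(dim1, dim2) <= 2 or dim2 <= 2:
        return "MatrixMul"
    return "BatchedMatrixMul"

assert _expected_kernel(1, 1) == "Dot"
assert _expected_kernel(3, 2) == "MatrixMul"         # rhs rank <= 2
assert _expected_kernel(3, 3) == "BatchedMatrixMul"  # true batched case
```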
def _transpose(data, axes):
...
@@ -8,24 +8,18 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import math
from functools import lru_cache
from typing import Iterable, Optional, Sequence, Tuple, Union

from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.special import Const
from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, subgraph
from ..jit import exclude_from_trace
from ..core.tensor.array_method import _matmul
from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze
from .elemwise import clip
from .tensor import expand_dims, squeeze

__all__ = [
    "argmax",

@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
    return result
class _Hashable:
    def __init__(self, value) -> None:
        self.value = value

    def __hash__(self) -> int:
        return hash(str(self.value))

    def __eq__(self, o: object) -> bool:
        if not isinstance(o, _Hashable):
            return False
        return self.value == o.value
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(GetVarShape(), result)
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(GetVarShape(), result)),
                ),
            )
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
def matmul(
    inp1: Tensor,
    inp2: Tensor,

@@ -1067,50 +838,7 @@ def matmul(
        [[10. 13.]
         [28. 40.]]
    """
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)
    dim1, dim2 = inp1.ndim, inp2.ndim
    assert dim1 > 0 and dim2 > 0
    maxdim = dim1 if dim1 > dim2 else dim2
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
        return dot(inp1, inp2)
    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
        extentedMatrixMulOp = _get_extentedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(get_execution_strategy()),
        )
        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
        return result
    else:  # dispatch to BatchedMatrixMul
        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
            inp1.device,
            inp1.dtype,
            dim1,
            dim2,
            transpose_a,
            transpose_b,
            compute_mode,
            format,
            strategy=_Hashable(get_execution_strategy()),
        )
        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
        return result
    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
...
@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
@pytest.mark.parametrize(
    "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
)
def test_matmul(is_varnode, shape_a, shape_b):
    if is_varnode:
        network = Network()
    else:
        network = None

    A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
    B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
    A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
    B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
    C = A @ B

    if is_varnode:
        np.testing.assert_almost_equal(
...
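The three new shape pairs line up one-to-one with the dispatch paths in `_matmul`: vector @ vector (`Dot`), matrix @ matrix (`MatrixMul`), and batched @ batched (`BatchedMatrixMul`). A quick NumPy check of the expected result shapes (illustrative, not part of the test):

```python
import numpy as np

for shape_a, shape_b in [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10))]:
    a = np.random.rand(*shape_a).astype("float32")
    b = np.random.rand(*shape_b).astype("float32")
    print(shape_a, "@", shape_b, "->", np.shape(a @ b))
```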