Commit 3206af9d authored by Megvii Engine Team

perf(functional/matmul): reimplement matmul with subgraph

GitOrigin-RevId: 456b2a51d35852152c46d31548dac96f977d5b41
Parent 8c47c1f1
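This commit replaces the eager Python shape handling in `matmul` with two cached subgraph builders (`_get_extentedMatrixMulOp`, `_get_extentedBatchedMatrixMulOp`), so the broadcast/reshape plumbing is built once per (device, dtype, rank, attribute) combination via `@lru_cache` instead of being re-executed on every call. The new dispatch rule, distilled from the diff below into a standalone sketch (illustration only, not code from the commit):

    # Which underlying op the reimplemented matmul selects, by operand rank.
    def dispatch_kind(dim1: int, dim2: int) -> str:
        assert dim1 > 0 and dim2 > 0
        maxdim = max(dim1, dim2)
        if dim1 == 1 and dim2 == 1:
            return "Dot"                # vector . vector
        elif maxdim <= 2 or dim2 <= 2:
            return "MatrixMul"          # plain matrix product; lhs may carry extra leading dims
        else:
            return "BatchedMatrixMul"   # batch dims on both sides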
@@ -8,17 +8,21 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import math
+from functools import lru_cache
 from typing import Optional, Sequence, Tuple, Union
 
 from ..core._imperative_rt.core2 import apply, dtype_promotion
+from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core._trace_option import use_symbolic_shape
 from ..core.ops import builtin
+from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
 from ..core.ops.special import Const
 from ..core.tensor import amp
-from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar
+from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
+from ..jit import exclude_from_trace
 from ..tensor import Tensor
 from .debug_param import get_execution_strategy
-from .elemwise import clip
+from .elemwise import clip, minimum
 from .tensor import broadcast_to, concat, expand_dims, squeeze
 
 __all__ = [
@@ -763,6 +767,216 @@ def matinv(inp: Tensor) -> Tensor:
     return result
 
 
+@lru_cache(maxsize=None)
+def _get_extentedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    def extentedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-1):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-1):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+
+        shape1 = f(GetVarShape(), inp1)
+        shape2 = f(GetVarShape(), inp2)
+        if _dim1 > 2:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
+                    build_shape_tail(shape1),
+                ),
+            )
+        if _dim2 > 2:
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
+                    build_shape_tail(shape2),
+                ),
+            )
+
+        op = builtin.MatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy,
+        )
+        result = f(op, inp1, inp2)
+        result_shape = f(GetVarShape(), result)
+        if _dim1 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape1),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        if _dim2 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape2),
+                    build_shape_tail(result_shape),
+                ),
+            )
+
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedMatrixMulOp
+
+
+@lru_cache(maxsize=None)
+def _get_extentedBatchedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    def extentedBatchedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-2):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-2):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+
+        shape1 = f(GetVarShape(), inp1)
+        shape2 = f(GetVarShape(), inp2)
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if _dim1 > _dim2:
+            # broadcast
+            shape2 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
+                shape2,
+            )
+            inp2 = f(builtin.Broadcast(), inp2, shape2)
+            batch_shape = build_shape_head(shape1)
+        if _dim2 > _dim1:
+            # broadcast
+            shape1 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
+                shape1,
+            )
+            inp1 = f(builtin.Broadcast(), inp1, shape1)
+            batch_shape = build_shape_head(shape2)
+        if _dim1 == _dim2:
+            batch_shape = build_shape_head(shape1)
+
+        # compress inputs to 3d
+        if maxdim > 3:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape1),
+                ),
+            )
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape2),
+                ),
+            )
+
+        op = builtin.BatchedMatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy,
+        )
+        result = f(op, inp1, inp2)
+        if maxdim > 3:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    batch_shape,
+                    build_shape_tail(f(GetVarShape(), result)),
+                ),
+            )
+
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedBatchedMatrixMulOp
+
+
 def matmul(
     inp1: Tensor,
     inp2: Tensor,
@@ -822,85 +1036,39 @@ def matmul(
         if inp2.dtype != dtype:
             inp2 = inp2.astype(dtype)
 
-    remove_row, remove_col = False, False
     dim1, dim2 = inp1.ndim, inp2.ndim
-    # handle dim=1 cases, dot and matrix-vector multiplication
-    if dim1 == 1 and dim2 == 1:
-        return dot(inp1, inp2)
-    # the underlying matmul op requires input dims to be at least 2
-    if dim1 == 1:
-        inp1 = expand_dims(inp1, 0)
-        dim1 = 2
-        remove_row = True
-    if dim2 == 1:
-        inp2 = expand_dims(inp2, 1)
-        dim2 = 2
-        remove_col = True
-
-    batch_shape = None
-    shape1 = inp1.shape
-    shape2 = inp2.shape
-
+    assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
-    if dim1 >= 3 or dim2 >= 3:
-        if use_symbolic_shape():
-            if dim1 > dim2:
-                shape2 = concat([shape1[:-2], shape2[-2:]])
-                inp2 = broadcast_to(inp2, shape2)
-            if dim1 < dim2:
-                shape1 = concat([shape2[:-2], shape1[-2:]])
-                inp1 = broadcast_to(inp1, shape1)
-            if maxdim > 3:
-                batch_shape = shape1[:-2]
-                # compress inputs to 3d
-                (inp1,) = apply(
-                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
-                )
-                (inp2,) = apply(
-                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
-                )
-        else:
-            if dim1 > dim2:
-                shape2 = shape1[:-2] + shape2[-2:]
-                inp2 = broadcast_to(inp2, shape2)
-            if dim1 < dim2:
-                shape1 = shape2[:-2] + shape1[-2:]
-                inp1 = broadcast_to(inp1, shape1)
-            if maxdim > 3:
-                batch_shape = shape1[:-2]
-                # compress inputs to 3d
-                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
-                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
-
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=get_execution_strategy(),
-        )
-    else:
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=get_execution_strategy(),
-        )
-
-    (result,) = apply(op, inp1, inp2)
-    if maxdim > 3:
-        if use_symbolic_shape():
-            (result,) = apply(
-                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
-            )
-        else:
-            result = result.reshape(batch_shape + result.shape[-2:])
-    if remove_row:
-        result = squeeze(result, axis=-2)
-    if remove_col:
-        result = squeeze(result, axis=-1)
-    return result
+    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
+        return dot(inp1, inp2)
+    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
+        extentedMatrixMulOp = _get_extentedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=get_execution_strategy(),
+        )
+        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        return result
+    else:  # dispatch to BatchedMatrixMul
+        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=get_execution_strategy(),
+        )
+        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        return result
 
 
 def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
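A note on the `subgraph` mechanics used by both builders above: the decorated function receives the subgraph inputs plus two helpers, `f` (apply a builtin op to values) and `c` (embed a constant), and returns the output tuple together with a per-output flag tuple (`(True,)` here). The shape arithmetic in the MatrixMul builder boils down to folding every leading axis of the lhs into one so that `MatrixMul` sees a 2d operand. A NumPy analogy of that fold (an assumed illustration; the real code does this symbolically with GetVarShape/Subtensor/Reduce/Concat):

    import numpy as np

    def fold_leading_dims(x: np.ndarray) -> np.ndarray:
        # Mirrors Reshape(Concat([Reduce("product", shape[:-1]), shape[-1:]])):
        # collapse all leading axes into one, keeping the last axis intact.
        head, last = x.shape[:-1], x.shape[-1]
        return x.reshape(int(np.prod(head)), last)

    print(fold_leading_dims(np.zeros((2, 3, 4))).shape)  # (6, 4)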
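Public behavior is intended to be unchanged by the rewrite. A quick shape check against the two non-trivial dispatch paths (ordinary MegEngine API; shapes chosen so the rank-difference broadcasting in the batched builder applies):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.random.rand(10, 2, 3).astype("float32"))
    b = tensor(np.random.rand(3, 5).astype("float32"))
    print(F.matmul(a, b).shape)  # (10, 2, 5): dim2 <= 2, MatrixMul path

    c = tensor(np.random.rand(8, 10, 2, 3).astype("float32"))
    d = tensor(np.random.rand(10, 3, 5).astype("float32"))
    print(F.matmul(c, d).shape)  # (8, 10, 2, 5): BatchedMatrixMul path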